Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,11 @@ func New[K comparable, V any](opts ...Option[K, V]) *Cache[K, V] {

// updateExpirations updates the expiration queue and notifies
// the cache auto cleaner if needed.
// 'oldExpiresAt' should reflect the front of the expiration queue
// before any item mutations.
// Not safe for concurrent use by multiple goroutines without additional
// locking.
func (c *Cache[K, V]) updateExpirations(fresh bool, elem *list.Element) {
var oldExpiresAt time.Time

if !c.items.expQueue.isEmpty() {
oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt
}
func (c *Cache[K, V]) updateExpirations(fresh bool, elem *list.Element, oldExpiresAt time.Time) {

if fresh {
c.items.expQueue.push(elem)
Expand Down Expand Up @@ -151,9 +148,14 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] {
item := elem.Value.(*Item[K, V])
oldItemCost := item.cost

var oldExpiresAt time.Time
if !c.items.expQueue.isEmpty() {
oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt
}

item.update(value, ttl)

c.updateExpirations(false, elem)
c.updateExpirations(false, elem, oldExpiresAt)

if c.options.maxCost != 0 {
c.cost = c.cost - oldItemCost + item.cost
Expand Down Expand Up @@ -185,11 +187,16 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] {
ttl = c.options.ttl
}

var oldExpiresAt time.Time
if !c.items.expQueue.isEmpty() {
oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt
}

// create a new item
item := NewItemWithOpts(key, value, ttl, c.options.itemOpts...)
elem = c.items.lru.PushFront(item)
c.items.values[key] = elem
c.updateExpirations(true, elem)
c.updateExpirations(true, elem, oldExpiresAt)

if c.options.maxCost != 0 {
c.cost += item.cost
Expand Down Expand Up @@ -231,8 +238,13 @@ func (c *Cache[K, V]) get(key K, touch bool, includeExpired bool) *list.Element
c.items.lru.MoveToFront(elem)

if touch && item.ttl > 0 {
var oldExpiresAt time.Time
if !c.items.expQueue.isEmpty() {
oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt
}

item.touch()
c.updateExpirations(false, elem)
c.updateExpirations(false, elem, oldExpiresAt)
}

return elem
Expand Down
125 changes: 95 additions & 30 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func Test_Cache_updateExpirations(t *testing.T) {
TimerChValue time.Duration
Fresh bool
EmptyQueue bool
SingleItem bool
OldExpiresAt time.Time
NewExpiresAt time.Time
Result time.Duration
Expand Down Expand Up @@ -115,6 +116,12 @@ func Test_Cache_updateExpirations(t *testing.T) {
NewExpiresAt: newExp,
Result: time.Until(newExp),
},
"Update with non fresh item, single item in queue and shortened expiresAt field": {
SingleItem: true,
OldExpiresAt: oldExp,
NewExpiresAt: newExp,
Result: time.Until(newExp),
},
}

for cn, c := range cc {
Expand All @@ -136,11 +143,13 @@ func Test_Cache_updateExpirations(t *testing.T) {
}

if !c.EmptyQueue {
cache.items.expQueue.push(&list.Element{
Value: &Item[string, string]{
expiresAt: c.OldExpiresAt,
},
})
if !c.SingleItem {
cache.items.expQueue.push(&list.Element{
Value: &Item[string, string]{
expiresAt: c.OldExpiresAt,
},
})
}

if !c.Fresh {
elem = &list.Element{
Expand All @@ -154,7 +163,11 @@ func Test_Cache_updateExpirations(t *testing.T) {
}
}

cache.updateExpirations(c.Fresh, elem)
var oldExpiresAt time.Time
if !c.EmptyQueue {
oldExpiresAt = c.OldExpiresAt
}
cache.updateExpirations(c.Fresh, elem, oldExpiresAt)

var res time.Duration

Expand All @@ -172,13 +185,14 @@ func Test_Cache_set(t *testing.T) {
const newKey, existingKey, evictedKey = "newKey123", "existingKey", "evicted"

cc := map[string]struct {
Capacity uint64
MaxCost uint64
Key string
TTL time.Duration
Metrics Metrics
InsertCalled bool
UpdateCalled bool
Capacity uint64
MaxCost uint64
Key string
TTL time.Duration
Metrics Metrics
InsertCalled bool
UpdateCalled bool
ExpectedTimerNotification time.Duration
}{
"Set with existing key and custom TTL": {
Key: existingKey,
Expand Down Expand Up @@ -294,6 +308,24 @@ func Test_Cache_set(t *testing.T) {
Evictions: 1,
},
},
"Set with existing key and shortened TTL": {
Key: existingKey,
TTL: time.Minute,
Metrics: Metrics{
Updates: 1,
},
UpdateCalled: true,
ExpectedTimerNotification: time.Minute,
},
"Set with new key and shortened TTL": {
Key: newKey,
TTL: time.Minute,
Metrics: Metrics{
Insertions: 1,
},
InsertCalled: true,
ExpectedTimerNotification: time.Minute,
},
}

for cn, c := range cc {
Expand Down Expand Up @@ -384,6 +416,16 @@ func Test_Cache_set(t *testing.T) {
assert.Zero(t, item.expiresAt)
assert.NotEqual(t, c.Key, cache.items.expQueue[0].Value.(*Item[string, string]).key)
}

if c.ExpectedTimerNotification > 0 {
var res time.Duration
select {
case res = <-cache.items.timerCh:
default:
t.Fatal("expected timer notification but channel was empty")
}
assert.InDelta(t, c.ExpectedTimerNotification, res, float64(time.Second))
}
})
}

Expand Down Expand Up @@ -422,28 +464,37 @@ func Test_Cache_get(t *testing.T) {
const existingKey, notFoundKey, expiredKey = "existing", "notfound", "expired"

cc := map[string]struct {
Key string
Touch bool
WithTTL bool
Key string
Touch bool
TTL time.Duration
AddExpiredKey bool
ExpectedTimerNotification time.Duration
}{
"Retrieval of non-existent item": {
Key: notFoundKey,
},
"Retrieval of expired item": {
Key: expiredKey,
Key: expiredKey,
AddExpiredKey: true,
},
"Retrieval of existing item without update": {
Key: existingKey,
},
"Retrieval of existing item with touch and non zero TTL": {
Key: existingKey,
Touch: true,
WithTTL: true,
Key: existingKey,
Touch: true,
TTL: time.Hour * 30,
},
"Retrieval of existing item with touch and zero TTL": {
Key: existingKey,
Touch: true,
},
"Retrieval of existing item with touch and shortened TTL": {
Key: existingKey,
Touch: true,
TTL: time.Millisecond,
ExpectedTimerNotification: time.Millisecond,
},
}

for cn, c := range cc {
Expand All @@ -453,18 +504,16 @@ func Test_Cache_get(t *testing.T) {
t.Parallel()

cache := prepCache(0, time.Hour, existingKey, "test2", "test3")
addExpiredCacheItems(cache, expiredKey)
time.Sleep(time.Millisecond) // force expiration
if c.AddExpiredKey {
addExpiredCacheItems(cache, expiredKey)
time.Sleep(time.Millisecond) // force expiration
}

oldItem := cache.items.values[existingKey].Value.(*Item[string, string])
oldQueueIndex := oldItem.queueIndex
oldExpiresAt := oldItem.expiresAt

if c.WithTTL {
oldItem.ttl = time.Hour * 30
} else {
oldItem.ttl = 0
}
oldItem.ttl = c.TTL

elem := cache.get(c.Key, c.Touch, false)

Expand All @@ -482,14 +531,30 @@ func Test_Cache_get(t *testing.T) {
require.NotNil(t, elem)
item := elem.Value.(*Item[string, string])

if c.Touch && c.WithTTL {
assert.True(t, item.expiresAt.After(oldExpiresAt))
assert.NotEqual(t, oldQueueIndex, item.queueIndex)
if c.Touch && c.TTL > 0 {
if item.expiresAt.Before(oldExpiresAt) {
assert.Equal(t, oldQueueIndex, item.queueIndex)
} else {
assert.True(t, item.expiresAt.After(oldExpiresAt))
assert.NotEqual(t, oldQueueIndex, item.queueIndex)
}
} else {
assert.True(t, item.expiresAt.Equal(oldExpiresAt))
assert.Equal(t, oldQueueIndex, item.queueIndex)
}

select {
case res := <-cache.items.timerCh:
if c.ExpectedTimerNotification == 0 {
t.Fatalf("unexpected timer notification: %v", res)
}
assert.InDelta(t, c.ExpectedTimerNotification, res, float64(time.Second))
default:
if c.ExpectedTimerNotification > 0 {
t.Fatal("expected timer notification but channel was empty")
}
}

assert.Equal(t, c.Key, cache.items.lru.Front().Value.(*Item[string, string]).key)
})
}
Expand Down
Loading