diff --git a/cache.go b/cache.go index 1cfcaf9..d860225 100644 --- a/cache.go +++ b/cache.go @@ -88,14 +88,11 @@ func New[K comparable, V any](opts ...Option[K, V]) *Cache[K, V] { // updateExpirations updates the expiration queue and notifies // the cache auto cleaner if needed. +// 'oldExpiresAt' should reflect the front of the expiration queue +// before any item mutations. // Not safe for concurrent use by multiple goroutines without additional // locking. -func (c *Cache[K, V]) updateExpirations(fresh bool, elem *list.Element) { - var oldExpiresAt time.Time - - if !c.items.expQueue.isEmpty() { - oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt - } +func (c *Cache[K, V]) updateExpirations(fresh bool, elem *list.Element, oldExpiresAt time.Time) { if fresh { c.items.expQueue.push(elem) @@ -151,9 +148,14 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] { item := elem.Value.(*Item[K, V]) oldItemCost := item.cost + var oldExpiresAt time.Time + if !c.items.expQueue.isEmpty() { + oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt + } + item.update(value, ttl) - c.updateExpirations(false, elem) + c.updateExpirations(false, elem, oldExpiresAt) if c.options.maxCost != 0 { c.cost = c.cost - oldItemCost + item.cost @@ -185,11 +187,16 @@ func (c *Cache[K, V]) set(key K, value V, ttl time.Duration) *Item[K, V] { ttl = c.options.ttl } + var oldExpiresAt time.Time + if !c.items.expQueue.isEmpty() { + oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt + } + // create a new item item := NewItemWithOpts(key, value, ttl, c.options.itemOpts...) elem = c.items.lru.PushFront(item) c.items.values[key] = elem - c.updateExpirations(true, elem) + c.updateExpirations(true, elem, oldExpiresAt) if c.options.maxCost != 0 { c.cost += item.cost @@ -231,8 +238,13 @@ func (c *Cache[K, V]) get(key K, touch bool, includeExpired bool) *list.Element c.items.lru.MoveToFront(elem) if touch && item.ttl > 0 { + var oldExpiresAt time.Time + if !c.items.expQueue.isEmpty() { + oldExpiresAt = c.items.expQueue[0].Value.(*Item[K, V]).expiresAt + } + item.touch() - c.updateExpirations(false, elem) + c.updateExpirations(false, elem, oldExpiresAt) } return elem diff --git a/cache_test.go b/cache_test.go index d436aaa..8124ec5 100644 --- a/cache_test.go +++ b/cache_test.go @@ -43,6 +43,7 @@ func Test_Cache_updateExpirations(t *testing.T) { TimerChValue time.Duration Fresh bool EmptyQueue bool + SingleItem bool OldExpiresAt time.Time NewExpiresAt time.Time Result time.Duration @@ -115,6 +116,12 @@ func Test_Cache_updateExpirations(t *testing.T) { NewExpiresAt: newExp, Result: time.Until(newExp), }, + "Update with non fresh item, single item in queue and shortened expiresAt field": { + SingleItem: true, + OldExpiresAt: oldExp, + NewExpiresAt: newExp, + Result: time.Until(newExp), + }, } for cn, c := range cc { @@ -136,11 +143,13 @@ func Test_Cache_updateExpirations(t *testing.T) { } if !c.EmptyQueue { - cache.items.expQueue.push(&list.Element{ - Value: &Item[string, string]{ - expiresAt: c.OldExpiresAt, - }, - }) + if !c.SingleItem { + cache.items.expQueue.push(&list.Element{ + Value: &Item[string, string]{ + expiresAt: c.OldExpiresAt, + }, + }) + } if !c.Fresh { elem = &list.Element{ @@ -154,7 +163,11 @@ func Test_Cache_updateExpirations(t *testing.T) { } } - cache.updateExpirations(c.Fresh, elem) + var oldExpiresAt time.Time + if !c.EmptyQueue { + oldExpiresAt = c.OldExpiresAt + } + cache.updateExpirations(c.Fresh, elem, oldExpiresAt) var res time.Duration @@ -172,13 +185,14 @@ func Test_Cache_set(t *testing.T) { const newKey, existingKey, evictedKey = "newKey123", "existingKey", "evicted" cc := map[string]struct { - Capacity uint64 - MaxCost uint64 - Key string - TTL time.Duration - Metrics Metrics - InsertCalled bool - UpdateCalled bool + Capacity uint64 + MaxCost uint64 + Key string + TTL time.Duration + Metrics Metrics + InsertCalled bool + UpdateCalled bool + ExpectedTimerNotification time.Duration }{ "Set with existing key and custom TTL": { Key: existingKey, @@ -294,6 +308,24 @@ func Test_Cache_set(t *testing.T) { Evictions: 1, }, }, + "Set with existing key and shortened TTL": { + Key: existingKey, + TTL: time.Minute, + Metrics: Metrics{ + Updates: 1, + }, + UpdateCalled: true, + ExpectedTimerNotification: time.Minute, + }, + "Set with new key and shortened TTL": { + Key: newKey, + TTL: time.Minute, + Metrics: Metrics{ + Insertions: 1, + }, + InsertCalled: true, + ExpectedTimerNotification: time.Minute, + }, } for cn, c := range cc { @@ -384,6 +416,16 @@ func Test_Cache_set(t *testing.T) { assert.Zero(t, item.expiresAt) assert.NotEqual(t, c.Key, cache.items.expQueue[0].Value.(*Item[string, string]).key) } + + if c.ExpectedTimerNotification > 0 { + var res time.Duration + select { + case res = <-cache.items.timerCh: + default: + t.Fatal("expected timer notification but channel was empty") + } + assert.InDelta(t, c.ExpectedTimerNotification, res, float64(time.Second)) + } }) } @@ -422,28 +464,37 @@ func Test_Cache_get(t *testing.T) { const existingKey, notFoundKey, expiredKey = "existing", "notfound", "expired" cc := map[string]struct { - Key string - Touch bool - WithTTL bool + Key string + Touch bool + TTL time.Duration + AddExpiredKey bool + ExpectedTimerNotification time.Duration }{ "Retrieval of non-existent item": { Key: notFoundKey, }, "Retrieval of expired item": { - Key: expiredKey, + Key: expiredKey, + AddExpiredKey: true, }, "Retrieval of existing item without update": { Key: existingKey, }, "Retrieval of existing item with touch and non zero TTL": { - Key: existingKey, - Touch: true, - WithTTL: true, + Key: existingKey, + Touch: true, + TTL: time.Hour * 30, }, "Retrieval of existing item with touch and zero TTL": { Key: existingKey, Touch: true, }, + "Retrieval of existing item with touch and shortened TTL": { + Key: existingKey, + Touch: true, + TTL: time.Millisecond, + ExpectedTimerNotification: time.Millisecond, + }, } for cn, c := range cc { @@ -453,18 +504,16 @@ func Test_Cache_get(t *testing.T) { t.Parallel() cache := prepCache(0, time.Hour, existingKey, "test2", "test3") - addExpiredCacheItems(cache, expiredKey) - time.Sleep(time.Millisecond) // force expiration + if c.AddExpiredKey { + addExpiredCacheItems(cache, expiredKey) + time.Sleep(time.Millisecond) // force expiration + } oldItem := cache.items.values[existingKey].Value.(*Item[string, string]) oldQueueIndex := oldItem.queueIndex oldExpiresAt := oldItem.expiresAt - if c.WithTTL { - oldItem.ttl = time.Hour * 30 - } else { - oldItem.ttl = 0 - } + oldItem.ttl = c.TTL elem := cache.get(c.Key, c.Touch, false) @@ -482,14 +531,30 @@ func Test_Cache_get(t *testing.T) { require.NotNil(t, elem) item := elem.Value.(*Item[string, string]) - if c.Touch && c.WithTTL { - assert.True(t, item.expiresAt.After(oldExpiresAt)) - assert.NotEqual(t, oldQueueIndex, item.queueIndex) + if c.Touch && c.TTL > 0 { + if item.expiresAt.Before(oldExpiresAt) { + assert.Equal(t, oldQueueIndex, item.queueIndex) + } else { + assert.True(t, item.expiresAt.After(oldExpiresAt)) + assert.NotEqual(t, oldQueueIndex, item.queueIndex) + } } else { assert.True(t, item.expiresAt.Equal(oldExpiresAt)) assert.Equal(t, oldQueueIndex, item.queueIndex) } + select { + case res := <-cache.items.timerCh: + if c.ExpectedTimerNotification == 0 { + t.Fatalf("unexpected timer notification: %v", res) + } + assert.InDelta(t, c.ExpectedTimerNotification, res, float64(time.Second)) + default: + if c.ExpectedTimerNotification > 0 { + t.Fatal("expected timer notification but channel was empty") + } + } + assert.Equal(t, c.Key, cache.items.lru.Front().Value.(*Item[string, string]).key) }) }