diff --git a/pkg/localcache/cache.go b/pkg/localcache/cache.go
index ba849f892..b2376d6f1 100644
--- a/pkg/localcache/cache.go
+++ b/pkg/localcache/cache.go
@@ -49,7 +49,7 @@ func New[V any](opts ...Option) Cache[V] {
 		if opt.expirationEvict {
 			return lru.NewExpirationLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
 		} else {
-			return lru.NewLayLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
+			return lru.NewLazyLRU[string, V](opt.localSlotSize, opt.localSuccessTTL, opt.localFailedTTL, opt.target, c.onEvict)
 		}
 	}
 	if opt.localSlotNum == 1 {
@@ -72,11 +72,18 @@ type cache[V any] struct {
 
 func (c *cache[V]) onEvict(key string, value V) {
 	if c.link != nil {
-		lks := c.link.Del(key)
-		for k := range lks {
-			if key != k { // prevent deadlock
-				c.local.Del(k)
-			}
+		// Do not delete other keys while the underlying LRU still holds its lock;
+		// defer linked deletions to avoid re-entering the same slot and deadlocking.
+		if lks := c.link.Del(key); len(lks) > 0 {
+			go c.delLinked(key, lks)
+		}
+	}
+}
+
+func (c *cache[V]) delLinked(src string, keys map[string]struct{}) {
+	for k := range keys {
+		if src != k {
+			c.local.Del(k)
 		}
 	}
 }
@@ -103,7 +110,7 @@ func (c *cache[V]) Get(ctx context.Context, key string, fetch func(ctx context.C
 func (c *cache[V]) GetLink(ctx context.Context, key string, fetch func(ctx context.Context) (V, error), link ...string) (V, error) {
 	if c.local != nil {
 		return c.local.Get(key, func() (V, error) {
-			if len(link) > 0 {
+			if len(link) > 0 && c.link != nil {
 				c.link.Link(key, link...)
 			}
 			return fetch(ctx)
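Note: the onEvict change above works around a lock-ordering hazard rather than a logic bug: simplelru invokes the eviction callback synchronously, while the slot that triggered the eviction still holds its mutex, so deleting a linked key that hashes back to the same slot re-enters that mutex and deadlocks. Below is a minimal self-contained sketch of the hazard and the deferred-deletion workaround; the names are illustrative, not project code.

```go
package main

import "sync"

// cacheSlot stands in for one LRU slot; as with simplelru, the eviction
// callback fires while the slot mutex is still held.
type cacheSlot struct {
	mu      sync.Mutex
	items   map[string]string
	onEvict func(key string)
}

func (s *cacheSlot) add(key, val string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	for k := range s.items { // capacity 1: evict the old entry synchronously
		delete(s.items, k)
		s.onEvict(k) // runs with s.mu held
	}
	s.items[key] = val
}

func (s *cacheSlot) del(key string) {
	s.mu.Lock() // self-deadlock if reached from inside onEvict
	defer s.mu.Unlock()
	delete(s.items, key)
}

func main() {
	s := &cacheSlot{items: map[string]string{}}
	s.onEvict = func(linked string) {
		go s.del(linked) // the patch's approach: take the linked deletion off the locked path
	}
	s.add("k1", "v1")
	s.add("k2", "v2") // evicts k1 without deadlocking
	// main may exit before the deferred delete runs; the point is add() no longer blocks.
}
```

Calling s.del(linked) directly instead of via the goroutine reproduces the hang that TestCacheEvictDeadlock below guards against.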
diff --git a/pkg/localcache/cache_test.go b/pkg/localcache/cache_test.go
index c206e6799..13eb20797 100644
--- a/pkg/localcache/cache_test.go
+++ b/pkg/localcache/cache_test.go
@@ -22,6 +22,8 @@ import (
 	"sync/atomic"
 	"testing"
 	"time"
+
+	"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
 )
 
 func TestName(t *testing.T) {
@@ -91,3 +93,68 @@
 	t.Log("del", del.Load())
 	// 137.35s
 }
+
+// Regression test: eviction must not deadlock when the callback deletes a linked key that hashes to the same slot.
+func TestCacheEvictDeadlock(t *testing.T) {
+	ctx := context.Background()
+	c := New[string](WithLocalSlotNum(1), WithLocalSlotSize(1), WithLazy())
+
+	if _, err := c.GetLink(ctx, "k1", func(ctx context.Context) (string, error) {
+		return "v1", nil
+	}, "k2"); err != nil {
+		t.Fatalf("seed cache failed: %v", err)
+	}
+
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		_, _ = c.GetLink(ctx, "k2", func(ctx context.Context) (string, error) {
+			return "v2", nil
+		}, "k1")
+	}()
+
+	select {
+	case <-done:
+		// expected to finish quickly; without the deferred linked deletion this call deadlocked
+	case <-time.After(time.Second):
+		t.Fatal("GetLink deadlocked during eviction of linked key")
+	}
+}
+
+func TestExpirationLRUGetBatch(t *testing.T) {
+	l := lru.NewExpirationLRU[string, string](2, time.Minute, time.Second*5, EmptyTarget{}, nil)
+
+	keys := []string{"a", "b"}
+	values, err := l.GetBatch(keys, func(keys []string) (map[string]string, error) {
+		res := make(map[string]string)
+		for _, k := range keys {
+			res[k] = k + "_v"
+		}
+		return res, nil
+	})
+	if err != nil {
+		t.Fatalf("unexpected error: %v", err)
+	}
+	if len(values) != len(keys) {
+		t.Fatalf("expected %d values, got %d", len(keys), len(values))
+	}
+	for _, k := range keys {
+		if v, ok := values[k]; !ok || v != k+"_v" {
+			t.Fatalf("unexpected value for %s: %q, ok=%v", k, v, ok)
+		}
+	}
+
+	// second batch should hit cache
+	values, err = l.GetBatch(keys, func(keys []string) (map[string]string, error) {
+		t.Fatalf("should not fetch on cache hit")
+		return nil, nil
+	})
+	if err != nil {
+		t.Fatalf("unexpected error on cache hit: %v", err)
+	}
+	for _, k := range keys {
+		if v, ok := values[k]; !ok || v != k+"_v" {
+			t.Fatalf("unexpected cached value for %s: %q, ok=%v", k, v, ok)
+		}
+	}
+}
diff --git a/pkg/localcache/lru/lru_expiration.go b/pkg/localcache/lru/lru_expiration.go
index df6bacbf4..4197cacec 100644
--- a/pkg/localcache/lru/lru_expiration.go
+++ b/pkg/localcache/lru/lru_expiration.go
@@ -52,8 +52,53 @@ type ExpirationLRU[K comparable, V any] struct {
 }
 
 func (x *ExpirationLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
-	//TODO implement me
-	panic("implement me")
+	var (
+		err     error
+		results = make(map[K]V)
+		misses  = make([]K, 0, len(keys))
+	)
+
+	for _, key := range keys {
+		x.lock.Lock()
+		v, ok := x.core.Get(key)
+		x.lock.Unlock()
+		if ok {
+			x.target.IncrGetHit()
+			v.lock.RLock()
+			results[key] = v.value
+			if v.err != nil && err == nil {
+				err = v.err
+			}
+			v.lock.RUnlock()
+			continue
+		}
+		misses = append(misses, key)
+	}
+
+	if len(misses) == 0 {
+		return results, err
+	}
+
+	fetchValues, fetchErr := fetch(misses)
+	if fetchErr != nil && err == nil {
+		err = fetchErr
+	}
+
+	for key, val := range fetchValues {
+		results[key] = val
+		if fetchErr != nil {
+			x.target.IncrGetFailed()
+			continue
+		}
+		x.target.IncrGetSuccess()
+		item := &expirationLruItem[V]{value: val}
+		x.lock.Lock()
+		x.core.Add(key, item)
+		x.lock.Unlock()
+	}
+
+	// any keys not returned from fetch remain absent (no cache write)
+	return results, err
 }
 
 func (x *ExpirationLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
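Note: the GetBatch implementation above serves hits under the LRU lock, collects the misses, and resolves them with a single fetch; only a successful fetch writes entries back. One caveat worth recording: Get appears to collapse concurrent fetches of one key through the placeholder item's lock, while this batch path inserts items only after the fetch returns, so overlapping GetBatch calls may fetch the same missing key twice (a benign inefficiency, not a correctness issue). The hit/miss split as a standalone sketch; loadAll is a hypothetical stand-in for the batched data source:

```go
package main

import "fmt"

// getBatch splits keys into cache hits and misses, then resolves all
// misses with one batched fetch; keys the fetch omits stay absent.
func getBatch(cached map[string]string, keys []string,
	loadAll func(missing []string) (map[string]string, error)) (map[string]string, error) {
	results := make(map[string]string, len(keys))
	misses := make([]string, 0, len(keys))
	for _, k := range keys {
		if v, ok := cached[k]; ok {
			results[k] = v
			continue
		}
		misses = append(misses, k)
	}
	if len(misses) == 0 {
		return results, nil // all hits: skip the fetch round-trip entirely
	}
	fetched, err := loadAll(misses)
	for k, v := range fetched {
		results[k] = v
	}
	return results, err
}

func main() {
	cached := map[string]string{"a": "a_v"}
	got, _ := getBatch(cached, []string{"a", "b"}, func(missing []string) (map[string]string, error) {
		fmt.Println("fetching:", missing) // only [b] reaches the data source
		return map[string]string{"b": "b_v"}, nil
	})
	fmt.Println(got) // map[a:a_v b:b_v]
}
```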
diff --git a/pkg/localcache/lru/lru_lazy.go b/pkg/localcache/lru/lru_lazy.go
index b4f0377a7..4a3db46c9 100644
--- a/pkg/localcache/lru/lru_lazy.go
+++ b/pkg/localcache/lru/lru_lazy.go
@@ -21,25 +21,25 @@ import (
 	"github.com/hashicorp/golang-lru/v2/simplelru"
 )
 
-type layLruItem[V any] struct {
+type lazyLruItem[V any] struct {
 	lock    sync.Mutex
 	expires int64
 	err     error
 	value   V
 }
 
-func NewLayLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LayLRU[K, V] {
-	var cb simplelru.EvictCallback[K, *layLruItem[V]]
+func NewLazyLRU[K comparable, V any](size int, successTTL, failedTTL time.Duration, target Target, onEvict EvictCallback[K, V]) *LazyLRU[K, V] {
+	var cb simplelru.EvictCallback[K, *lazyLruItem[V]]
 	if onEvict != nil {
-		cb = func(key K, value *layLruItem[V]) {
+		cb = func(key K, value *lazyLruItem[V]) {
 			onEvict(key, value.value)
 		}
 	}
-	core, err := simplelru.NewLRU[K, *layLruItem[V]](size, cb)
+	core, err := simplelru.NewLRU[K, *lazyLruItem[V]](size, cb)
 	if err != nil {
 		panic(err)
 	}
-	return &LayLRU[K, V]{
+	return &LazyLRU[K, V]{
 		core:       core,
 		successTTL: successTTL,
 		failedTTL:  failedTTL,
@@ -47,15 +47,15 @@ func NewLayLRU[K comparable, V any](size int, successTTL, failedTTL time.Duratio
 	}
 }
 
-type LayLRU[K comparable, V any] struct {
+type LazyLRU[K comparable, V any] struct {
 	lock       sync.Mutex
-	core       *simplelru.LRU[K, *layLruItem[V]]
+	core       *simplelru.LRU[K, *lazyLruItem[V]]
 	successTTL time.Duration
 	failedTTL  time.Duration
 	target     Target
 }
 
-func (x *LayLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
+func (x *LazyLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
 	x.lock.Lock()
 	v, ok := x.core.Get(key)
 	if ok {
@@ -68,7 +68,7 @@ func (x *LayLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
 			return value, err
 		}
 	} else {
-		v = &layLruItem[V]{}
+		v = &lazyLruItem[V]{}
 		x.core.Add(key, v)
 		v.lock.Lock()
 		x.lock.Unlock()
@@ -88,15 +88,15 @@ func (x *LayLRU[K, V]) Get(key K, fetch func() (V, error)) (V, error) {
 	return v.value, v.err
 }
 
-func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
+func (x *LazyLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error)) (map[K]V, error) {
 	var (
 		err  error
 		once sync.Once
 	)
 
 	res := make(map[K]V)
-	queries := make([]K, 0)
-	setVs := make(map[K]*layLruItem[V])
+	queries := make([]K, 0, len(keys))
+
 	for _, key := range keys {
 		x.lock.Lock()
 		v, ok := x.core.Get(key)
@@ -118,14 +118,20 @@ func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error))
 		}
 		queries = append(queries, key)
 	}
-	values, err1 := fetch(queries)
-	if err1 != nil {
+
+	if len(queries) == 0 {
+		return res, err
+	}
+
+	values, fetchErr := fetch(queries)
+	if fetchErr != nil {
 		once.Do(func() {
-			err = err1
+			err = fetchErr
 		})
 	}
+
 	for key, val := range values {
-		v := &layLruItem[V]{}
+		v := &lazyLruItem[V]{}
 		v.value = val
 
 		if err == nil {
@@ -135,7 +141,7 @@ func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error))
 			v.expires = time.Now().Add(x.failedTTL).UnixMilli()
 			x.target.IncrGetFailed()
 		}
-		setVs[key] = v
+
 		x.lock.Lock()
 		x.core.Add(key, v)
 		x.lock.Unlock()
@@ -145,29 +151,29 @@ func (x *LayLRU[K, V]) GetBatch(keys []K, fetch func(keys []K) (map[K]V, error))
 	return res, err
 }
 
-//func (x *LayLRU[K, V]) Has(key K) bool {
+//func (x *LazyLRU[K, V]) Has(key K) bool {
 //	x.lock.Lock()
 //	defer x.lock.Unlock()
 //	return x.core.Contains(key)
 //}
 
-func (x *LayLRU[K, V]) Set(key K, value V) {
+func (x *LazyLRU[K, V]) Set(key K, value V) {
 	x.lock.Lock()
 	defer x.lock.Unlock()
-	x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
+	x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
 }
 
-func (x *LayLRU[K, V]) SetHas(key K, value V) bool {
+func (x *LazyLRU[K, V]) SetHas(key K, value V) bool {
 	x.lock.Lock()
 	defer x.lock.Unlock()
 	if x.core.Contains(key) {
-		x.core.Add(key, &layLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
+		x.core.Add(key, &lazyLruItem[V]{value: value, expires: time.Now().Add(x.successTTL).UnixMilli()})
 		return true
 	}
 	return false
 }
 
-func (x *LayLRU[K, V]) Del(key K) bool {
+func (x *LazyLRU[K, V]) Del(key K) bool {
 	x.lock.Lock()
 	ok := x.core.Remove(key)
 	x.lock.Unlock()
@@ -179,6 +185,6 @@ func (x *LayLRU[K, V]) Del(key K) bool {
 	return ok
 }
 
-func (x *LayLRU[K, V]) Stop() {
+func (x *LazyLRU[K, V]) Stop() {
 
 }
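Note: the LayLRU to LazyLRU rename corrects a long-standing typo, and "lazy" is descriptive: expiry is a timestamp checked on read, so a stale entry lingers until the next Get (or an LRU capacity eviction) rather than being reaped in the background the way hashicorp's expirable LRU does for ExpirationLRU. A single-threaded sketch of that read path, with illustrative names only:

```go
package main

import (
	"fmt"
	"time"
)

type entry struct {
	value   string
	expires int64 // UnixMilli deadline, checked lazily on read
}

// lazyCache expires entries only when they are read, mirroring the
// LazyLRU approach: no background timer, just a timestamp comparison.
type lazyCache struct {
	items map[string]entry
	ttl   time.Duration
}

func (c *lazyCache) get(key string, fetch func() string) string {
	if e, ok := c.items[key]; ok && e.expires > time.Now().UnixMilli() {
		return e.value // fresh hit: served without fetching
	}
	v := fetch() // miss or stale: refetch and overwrite in place
	c.items[key] = entry{value: v, expires: time.Now().Add(c.ttl).UnixMilli()}
	return v
}

func main() {
	c := &lazyCache{items: map[string]entry{}, ttl: 50 * time.Millisecond}
	fmt.Println(c.get("k", func() string { return "v1" })) // fetches "v1"
	fmt.Println(c.get("k", func() string { return "v2" })) // fresh hit: still "v1"
	time.Sleep(60 * time.Millisecond)
	fmt.Println(c.get("k", func() string { return "v3" })) // stale: refetches "v3"
}
```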
diff --git a/pkg/rpccache/online.go b/pkg/rpccache/online.go
index b5308bbe8..87823c1c0 100644
--- a/pkg/rpccache/online.go
+++ b/pkg/rpccache/online.go
@@ -3,15 +3,16 @@ package rpccache
 import (
 	"context"
 	"fmt"
-	"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
-	"github.com/openimsdk/protocol/constant"
-	"github.com/openimsdk/protocol/user"
 	"math/rand"
 	"strconv"
 	"sync"
 	"sync/atomic"
 	"time"
 
+	"github.com/openimsdk/open-im-server/v3/pkg/rpcli"
+	"github.com/openimsdk/protocol/constant"
+	"github.com/openimsdk/protocol/user"
+
 	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
 	"github.com/openimsdk/open-im-server/v3/pkg/localcache"
 	"github.com/openimsdk/open-im-server/v3/pkg/localcache/lru"
@@ -46,7 +47,7 @@ func NewOnlineCache(client *rpcli.UserClient, group *GroupLocalCache, rdb redis.
 	case false:
 		log.ZDebug(ctx, "fullUserCache is false")
 		x.lruCache = lru.NewSlotLRU(1024, localcache.LRUStringHash, func() lru.LRU[string, []int32] {
-			return lru.NewLayLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {})
+			return lru.NewLazyLRU[string, []int32](2048, cachekey.OnlineExpire/2, time.Second*3, localcache.EmptyTarget{}, func(key string, value []int32) {})
 		})
 		x.CurrentPhase.Store(DoSubscribeOver)
 		x.Cond.Broadcast()
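Note: this call site only picks up the goimports-style import regrouping (stdlib, then third-party) and the renamed constructor; behavior is unchanged. For context on why the eviction fix cares about slots at all: NewSlotLRU shards keys across many independent LRUs, each with its own lock, which is what makes same-slot re-entry from an eviction callback possible in the first place. A hypothetical sketch of the sharding idea:

```go
package main

import (
	"fmt"
	"hash/fnv"
)

// slotted shards keys across n independent maps, the same idea behind
// lru.NewSlotLRU: each slot gets its own lock, cutting contention.
type slotted struct {
	slots []map[string]string
}

func newSlotted(n int) *slotted {
	s := &slotted{slots: make([]map[string]string, n)}
	for i := range s.slots {
		s.slots[i] = map[string]string{}
	}
	return s
}

func (s *slotted) slotOf(key string) int {
	h := fnv.New32a()
	_, _ = h.Write([]byte(key))
	return int(h.Sum32()) % len(s.slots)
}

func (s *slotted) set(key, val string) { s.slots[s.slotOf(key)][key] = val }

func (s *slotted) get(key string) (string, bool) {
	v, ok := s.slots[s.slotOf(key)][key]
	return v, ok
}

func main() {
	s := newSlotted(4)
	s.set("user:online:u1", "ios")
	fmt.Println(s.get("user:online:u1")) // ios true
}
```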