diff --git a/cmd/main.go b/cmd/main.go index fc61816e1..a8daa4acd 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -8,14 +8,12 @@ import ( "fmt" "net" "os" - "os/signal" "path" "path/filepath" "reflect" "runtime" "strings" "sync" - "syscall" "time" "github.com/mitchellh/mapstructure" @@ -210,6 +208,11 @@ func (x *cmds) run(ctx context.Context) error { ctx, cancel := context.WithCancelCause(ctx) + go func() { + <-ctx.Done() + log.ZError(ctx, "context server exit cause", context.Cause(ctx)) + }() + if prometheus := x.config.API.Prometheus; prometheus.Enable { var ( port int @@ -247,16 +250,16 @@ func (x *cmds) run(ctx context.Context) error { }() } - go func() { - sigs := make(chan os.Signal, 1) - signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL) - select { - case <-ctx.Done(): - return - case val := <-sigs: - cancel(fmt.Errorf("signal %s", val.String())) - } - }() + //go func() { + // sigs := make(chan os.Signal, 1) + // signal.Notify(sigs, syscall.SIGTERM, syscall.SIGINT, syscall.SIGKILL) + // select { + // case <-ctx.Done(): + // return + // case val := <-sigs: + // cancel(fmt.Errorf("signal %s", val.String())) + // } + //}() for i := range x.cmds { cmd := x.cmds[i] diff --git a/internal/api/init.go b/internal/api/init.go index 69ade8aaf..4a1404ffc 100644 --- a/internal/api/init.go +++ b/internal/api/init.go @@ -17,15 +17,14 @@ package api import ( "context" "errors" + "fmt" "net" "net/http" "strconv" "time" conf "github.com/openimsdk/open-im-server/v3/pkg/common/config" - disetcd "github.com/openimsdk/open-im-server/v3/pkg/common/discovery/etcd" "github.com/openimsdk/tools/discovery" - "github.com/openimsdk/tools/discovery/etcd" "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/utils/datautil" "github.com/openimsdk/tools/utils/network" @@ -51,7 +50,7 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, service g return err } - ctx, cancel := context.WithCancelCause(ctx) + apiCtx, apiCancel := context.WithCancelCause(context.Background()) done := make(chan struct{}) go func() { httpServer := &http.Server{ @@ -60,7 +59,11 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, service g } go func() { defer close(done) - <-ctx.Done() + select { + case <-ctx.Done(): + apiCancel(fmt.Errorf("recv ctx %w", context.Cause(ctx))) + case <-apiCtx.Done(): + } log.ZDebug(ctx, "api server is shutting down") if err := httpServer.Shutdown(context.Background()); err != nil { log.ZWarn(ctx, "api server shutdown err", err) @@ -71,13 +74,13 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, service g if err == nil { err = errors.New("api done") } - cancel(err) + apiCancel(err) }() - if config.Discovery.Enable == conf.ETCD { - cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames()) - cm.Watch(ctx) - } + //if config.Discovery.Enable == conf.ETCD { + // cm := disetcd.NewConfigManager(client.(*etcd.SvcDiscoveryRegistryImpl).GetClient(), config.GetConfigNames()) + // cm.Watch(ctx) + //} //sigs := make(chan os.Signal, 1) //signal.Notify(sigs, syscall.SIGTERM) //select { @@ -86,6 +89,7 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, service g // cancel(fmt.Errorf("signal %s", val.String())) //case <-ctx.Done(): //} + <-apiCtx.Done() exitCause := context.Cause(ctx) log.ZWarn(ctx, "api server exit", exitCause) timer := time.NewTimer(time.Second * 15) diff --git a/internal/msgtransfer/init.go b/internal/msgtransfer/init.go index d6c8a0797..e02940853 100644 --- a/internal/msgtransfer/init.go +++ b/internal/msgtransfer/init.go @@ -21,6 +21,8 @@ import ( "syscall" disetcd "github.com/openimsdk/open-im-server/v3/pkg/common/discovery/etcd" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/redis" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database/mgo" "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" @@ -101,7 +103,16 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, server gr if err != nil { return err } - msgModel := redis.NewMsgCache(rdb, msgDocModel) + var msgModel cache.MsgCache + if rdb == nil { + cm, err := mgo.NewCacheMgo(mgocli.GetDB()) + if err != nil { + return err + } + msgModel = mcache.NewMsgCache(cm, msgDocModel) + } else { + msgModel = redis.NewMsgCache(rdb, msgDocModel) + } seqConversation, err := mgo.NewSeqConversationMongo(mgocli.GetDB()) if err != nil { return err diff --git a/internal/rpc/msg/server.go b/internal/rpc/msg/server.go index d0b228156..7737f7e7f 100644 --- a/internal/rpc/msg/server.go +++ b/internal/rpc/msg/server.go @@ -17,6 +17,8 @@ package msg import ( "context" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" "github.com/openimsdk/open-im-server/v3/pkg/dbbuild" "github.com/openimsdk/open-im-server/v3/pkg/mqbuild" "github.com/openimsdk/open-im-server/v3/pkg/rpcli" @@ -95,7 +97,16 @@ func Start(ctx context.Context, config *Config, client discovery.Conn, server gr if err != nil { return err } - msgModel := redis.NewMsgCache(rdb, msgDocModel) + var msgModel cache.MsgCache + if rdb == nil { + cm, err := mgo.NewCacheMgo(mgocli.GetDB()) + if err != nil { + return err + } + msgModel = mcache.NewMsgCache(cm, msgDocModel) + } else { + msgModel = redis.NewMsgCache(rdb, msgDocModel) + } seqConversation, err := mgo.NewSeqConversationMongo(mgocli.GetDB()) if err != nil { return err diff --git a/pkg/common/storage/cache/mcache/minio.go b/pkg/common/storage/cache/mcache/minio.go index ecee54aa5..f07203cc2 100644 --- a/pkg/common/storage/cache/mcache/minio.go +++ b/pkg/common/storage/cache/mcache/minio.go @@ -2,12 +2,10 @@ package mcache import ( "context" - "encoding/json" "time" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" - "github.com/openimsdk/tools/log" "github.com/openimsdk/tools/s3/minio" ) @@ -50,56 +48,3 @@ func (g *minioCache) GetImageObjectKeyInfo(ctx context.Context, key string, fn f func (g *minioCache) GetThumbnailKey(ctx context.Context, key string, format string, width int, height int, minioCache func(ctx context.Context) (string, error)) (string, error) { return getCache[string](ctx, g.cache, g.getMinioImageThumbnailKey(key, format, width, height), g.expireTime, minioCache) } - -func getCache[V any](ctx context.Context, cache database.Cache, key string, expireTime time.Duration, fn func(ctx context.Context) (V, error)) (V, error) { - getDB := func() (V, bool, error) { - res, err := cache.Get(ctx, []string{key}) - if err != nil { - var val V - return val, false, err - } - var val V - if str, ok := res[key]; ok { - if json.Unmarshal([]byte(str), &val) != nil { - return val, false, err - } - return val, true, nil - } - return val, false, nil - } - dbVal, ok, err := getDB() - if err != nil { - return dbVal, err - } - if ok { - return dbVal, nil - } - lockValue, err := cache.Lock(ctx, key, time.Minute) - if err != nil { - return dbVal, err - } - defer func() { - if err := cache.Unlock(ctx, key, lockValue); err != nil { - log.ZError(ctx, "unlock cache key", err, "key", key, "value", lockValue) - } - }() - dbVal, ok, err = getDB() - if err != nil { - return dbVal, err - } - if ok { - return dbVal, nil - } - val, err := fn(ctx) - if err != nil { - return val, err - } - data, err := json.Marshal(val) - if err != nil { - return val, err - } - if err := cache.Set(ctx, key, string(data), expireTime); err != nil { - return val, err - } - return val, nil -} diff --git a/pkg/common/storage/cache/mcache/msg_cache.go b/pkg/common/storage/cache/mcache/msg_cache.go new file mode 100644 index 000000000..3846be3f8 --- /dev/null +++ b/pkg/common/storage/cache/mcache/msg_cache.go @@ -0,0 +1,132 @@ +package mcache + +import ( + "context" + "strconv" + "sync" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" + "github.com/openimsdk/open-im-server/v3/pkg/localcache" + "github.com/openimsdk/open-im-server/v3/pkg/localcache/lru" + "github.com/openimsdk/tools/errs" + "github.com/openimsdk/tools/utils/datautil" + "github.com/redis/go-redis/v9" +) + +var ( + memMsgCache lru.LRU[string, *model.MsgInfoModel] + initMemMsgCache sync.Once +) + +func NewMsgCache(cache database.Cache, msgDocDatabase database.Msg) cache.MsgCache { + initMemMsgCache.Do(func() { + memMsgCache = lru.NewLayLRU[string, *model.MsgInfoModel](1024*8, time.Hour, time.Second*10, localcache.EmptyTarget{}, nil) + }) + return &msgCache{ + cache: cache, + msgDocDatabase: msgDocDatabase, + memMsgCache: memMsgCache, + } +} + +type msgCache struct { + cache database.Cache + msgDocDatabase database.Msg + memMsgCache lru.LRU[string, *model.MsgInfoModel] +} + +func (x *msgCache) getSendMsgKey(id string) string { + return cachekey.GetSendMsgKey(id) +} + +func (x *msgCache) SetSendMsgStatus(ctx context.Context, id string, status int32) error { + return x.cache.Set(ctx, x.getSendMsgKey(id), strconv.Itoa(int(status)), time.Hour*24) +} + +func (x *msgCache) GetSendMsgStatus(ctx context.Context, id string) (int32, error) { + key := x.getSendMsgKey(id) + res, err := x.cache.Get(ctx, []string{key}) + if err != nil { + return 0, err + } + val, ok := res[key] + if !ok { + return 0, errs.Wrap(redis.Nil) + } + status, err := strconv.Atoi(val) + if err != nil { + return 0, errs.WrapMsg(err, "GetSendMsgStatus strconv.Atoi error", "val", val) + } + return int32(status), nil +} + +func (x *msgCache) getMsgCacheKey(conversationID string, seq int64) string { + return cachekey.GetMsgCacheKey(conversationID, seq) + +} + +func (x *msgCache) GetMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) ([]*model.MsgInfoModel, error) { + if len(seqs) == 0 { + return nil, nil + } + keys := make([]string, 0, len(seqs)) + keySeq := make(map[string]int64, len(seqs)) + for _, seq := range seqs { + key := x.getMsgCacheKey(conversationID, seq) + keys = append(keys, key) + keySeq[key] = seq + } + res, err := x.memMsgCache.GetBatch(keys, func(keys []string) (map[string]*model.MsgInfoModel, error) { + findSeqs := make([]int64, 0, len(keys)) + for _, key := range keys { + seq, ok := keySeq[key] + if !ok { + continue + } + findSeqs = append(findSeqs, seq) + } + res, err := x.msgDocDatabase.FindSeqs(ctx, conversationID, seqs) + if err != nil { + return nil, err + } + kv := make(map[string]*model.MsgInfoModel) + for i := range res { + msg := res[i] + if msg == nil || msg.Msg == nil || msg.Msg.Seq <= 0 { + continue + } + key := x.getMsgCacheKey(conversationID, msg.Msg.Seq) + kv[key] = msg + } + return kv, nil + }) + if err != nil { + return nil, err + } + return datautil.Values(res), nil +} + +func (x msgCache) DelMessageBySeqs(ctx context.Context, conversationID string, seqs []int64) error { + if len(seqs) == 0 { + return nil + } + for _, seq := range seqs { + x.memMsgCache.Del(x.getMsgCacheKey(conversationID, seq)) + } + return nil +} + +func (x *msgCache) SetMessageBySeqs(ctx context.Context, conversationID string, msgs []*model.MsgInfoModel) error { + for i := range msgs { + msg := msgs[i] + if msg == nil || msg.Msg == nil || msg.Msg.Seq <= 0 { + continue + } + x.memMsgCache.Set(x.getMsgCacheKey(conversationID, msg.Msg.Seq), msg) + } + return nil +} diff --git a/pkg/common/storage/cache/mcache/seq_conversation.go b/pkg/common/storage/cache/mcache/seq_conversation.go new file mode 100644 index 000000000..27f00de15 --- /dev/null +++ b/pkg/common/storage/cache/mcache/seq_conversation.go @@ -0,0 +1,79 @@ +package mcache + +import ( + "context" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" +) + +func NewSeqConversationCache(sc database.SeqConversation) cache.SeqConversationCache { + return &seqConversationCache{ + sc: sc, + } +} + +type seqConversationCache struct { + sc database.SeqConversation +} + +func (x *seqConversationCache) Malloc(ctx context.Context, conversationID string, size int64) (int64, error) { + return x.sc.Malloc(ctx, conversationID, size) +} + +func (x *seqConversationCache) SetMinSeq(ctx context.Context, conversationID string, seq int64) error { + return x.sc.SetMinSeq(ctx, conversationID, seq) +} + +func (x *seqConversationCache) GetMinSeq(ctx context.Context, conversationID string) (int64, error) { + return x.sc.GetMinSeq(ctx, conversationID) +} + +func (x *seqConversationCache) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { + res := make(map[string]int64) + for _, conversationID := range conversationIDs { + seq, err := x.GetMinSeq(ctx, conversationID) + if err != nil { + return nil, err + } + res[conversationID] = seq + } + return res, nil +} + +func (x *seqConversationCache) GetMaxSeqsWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + res := make(map[string]database.SeqTime) + for _, conversationID := range conversationIDs { + seq, err := x.GetMinSeq(ctx, conversationID) + if err != nil { + return nil, err + } + res[conversationID] = database.SeqTime{Seq: seq} + } + return res, nil +} + +func (x *seqConversationCache) GetMaxSeq(ctx context.Context, conversationID string) (int64, error) { + return x.sc.GetMaxSeq(ctx, conversationID) +} + +func (x *seqConversationCache) GetMaxSeqWithTime(ctx context.Context, conversationID string) (database.SeqTime, error) { + seq, err := x.GetMinSeq(ctx, conversationID) + if err != nil { + return database.SeqTime{}, err + } + return database.SeqTime{Seq: seq}, nil +} + +func (x *seqConversationCache) SetMinSeqs(ctx context.Context, seqs map[string]int64) error { + for conversationID, seq := range seqs { + if err := x.sc.SetMinSeq(ctx, conversationID, seq); err != nil { + return err + } + } + return nil +} + +func (x *seqConversationCache) GetCacheMaxSeqWithTime(ctx context.Context, conversationIDs []string) (map[string]database.SeqTime, error) { + return x.GetMaxSeqsWithTime(ctx, conversationIDs) +} diff --git a/pkg/common/storage/cache/mcache/tools.go b/pkg/common/storage/cache/mcache/tools.go new file mode 100644 index 000000000..f3c4265cd --- /dev/null +++ b/pkg/common/storage/cache/mcache/tools.go @@ -0,0 +1,63 @@ +package mcache + +import ( + "context" + "encoding/json" + "time" + + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" + "github.com/openimsdk/tools/log" +) + +func getCache[V any](ctx context.Context, cache database.Cache, key string, expireTime time.Duration, fn func(ctx context.Context) (V, error)) (V, error) { + getDB := func() (V, bool, error) { + res, err := cache.Get(ctx, []string{key}) + if err != nil { + var val V + return val, false, err + } + var val V + if str, ok := res[key]; ok { + if json.Unmarshal([]byte(str), &val) != nil { + return val, false, err + } + return val, true, nil + } + return val, false, nil + } + dbVal, ok, err := getDB() + if err != nil { + return dbVal, err + } + if ok { + return dbVal, nil + } + lockValue, err := cache.Lock(ctx, key, time.Minute) + if err != nil { + return dbVal, err + } + defer func() { + if err := cache.Unlock(ctx, key, lockValue); err != nil { + log.ZError(ctx, "unlock cache key", err, "key", key, "value", lockValue) + } + }() + dbVal, ok, err = getDB() + if err != nil { + return dbVal, err + } + if ok { + return dbVal, nil + } + val, err := fn(ctx) + if err != nil { + return val, err + } + data, err := json.Marshal(val) + if err != nil { + return val, err + } + if err := cache.Set(ctx, key, string(data), expireTime); err != nil { + return val, err + } + return val, nil +} diff --git a/pkg/common/storage/cache/redis/seq_conversation.go b/pkg/common/storage/cache/redis/seq_conversation.go index 2ba69a7d6..604826598 100644 --- a/pkg/common/storage/cache/redis/seq_conversation.go +++ b/pkg/common/storage/cache/redis/seq_conversation.go @@ -9,6 +9,7 @@ import ( "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/mcache" "github.com/openimsdk/open-im-server/v3/pkg/common/storage/database" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/tools/errs" @@ -17,6 +18,9 @@ import ( ) func NewSeqConversationCacheRedis(rdb redis.UniversalClient, mgo database.SeqConversation) cache.SeqConversationCache { + if rdb == nil { + return mcache.NewSeqConversationCache(mgo) + } return &seqConversationCacheRedis{ mgo: mgo, lockTime: time.Second * 3, diff --git a/pkg/common/storage/cache/redis/seq_user.go b/pkg/common/storage/cache/redis/seq_user.go index ad289be07..af9cbef5a 100644 --- a/pkg/common/storage/cache/redis/seq_user.go +++ b/pkg/common/storage/cache/redis/seq_user.go @@ -72,6 +72,9 @@ func (s *seqUserCacheRedis) GetUserReadSeq(ctx context.Context, conversationID s } func (s *seqUserCacheRedis) SetUserReadSeq(ctx context.Context, conversationID string, userID string, seq int64) error { + if s.rocks.GetRedis() == nil { + return s.SetUserReadSeqToDB(ctx, conversationID, userID, seq) + } dbSeq, err := s.GetUserReadSeq(ctx, conversationID, userID) if err != nil { return err