From 351c5bacfad36c5b122199e43e8e203bcb99765c Mon Sep 17 00:00:00 2001 From: wangchuxiao Date: Thu, 11 May 2023 15:30:06 +0800 Subject: [PATCH] pipeline --- pkg/common/db/cache/msg.go | 11 ++++------- pkg/common/db/unrelation/msg.go | 27 +++++++++++++++++++++++---- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/pkg/common/db/cache/msg.go b/pkg/common/db/cache/msg.go index 9fd552773..9c700d6e0 100644 --- a/pkg/common/db/cache/msg.go +++ b/pkg/common/db/cache/msg.go @@ -3,7 +3,6 @@ package cache import ( "context" "errors" - "fmt" "strconv" "time" @@ -290,7 +289,7 @@ func (c *msgCache) GetMessagesBySeq(ctx context.Context, conversationID string, func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (int, error) { pipe := c.rdb.Pipeline() - var failedMsgs []sdkws.MsgData + var failedMsgs []*sdkws.MsgData for _, msg := range msgs { key := c.getMessageCacheKey(conversationID, msg.Seq) s, err := utils.Pb2String(msg) @@ -299,14 +298,12 @@ func (c *msgCache) SetMessageToCache(ctx context.Context, conversationID string, } err = pipe.Set(ctx, key, s, time.Duration(config.Config.MsgCacheTimeout)*time.Second).Err() if err != nil { - return 0, errs.Wrap(err) + failedMsgs = append(failedMsgs, msg) + log.ZWarn(ctx, "set msg 2 cache failed", err, "msg", failedMsgs) } } - if len(failedMsgs) != 0 { - return len(failedMsgs), fmt.Errorf("set msg to msgCache failed, failed lists: %v, %s", failedMsgs, conversationID) - } _, err := pipe.Exec(ctx) - return 0, err + return len(failedMsgs), err } func (c *msgCache) DeleteMessageFromCache(ctx context.Context, userID string, msgList []*sdkws.MsgData) error { diff --git a/pkg/common/db/unrelation/msg.go b/pkg/common/db/unrelation/msg.go index aed86a06e..f3cc11db8 100644 --- a/pkg/common/db/unrelation/msg.go +++ b/pkg/common/db/unrelation/msg.go @@ -137,12 +137,31 @@ func (m *MsgMongoDriver) UpdateOneDoc(ctx context.Context, msg *table.MsgDocMode func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(ctx context.Context, docID string, beginSeq, endSeq int64) (msgs []*sdkws.MsgData, seqs []int64, err error) { beginIndex := m.msg.GetMsgIndex(beginSeq) num := endSeq - beginSeq + 1 - result, err := m.MsgCollection.Find(ctx, bson.M{"doc_id": docID, "msgs": bson.M{"$slice": []int64{beginIndex, num}}}) - if err != nil { - return nil, nil, err + + pipeline := bson.A{ + bson.M{ + "$match": bson.M{"doc_id": docID}, + }, + bson.M{ + "$project": bson.M{ + "doc_id": 1, + "msgs": bson.M{ + "$slice": []interface{}{"$msgs", beginIndex, num}, + }, + }, + }, } + cursor, err := m.MsgCollection.Aggregate(ctx, pipeline) + if err != nil { + return nil, nil, errs.Wrap(err) + } + + // result, err := m.MsgCollection.Find(ctx, bson.M{"doc_id": docID, "msgs": bson.M{"$slice": []int64{beginIndex, num}}}) + // if err != nil { + // return nil, nil, err + // } var msgInfos []table.MsgInfoModel - if err := result.Decode(&msgInfos); err != nil { + if err := cursor.All(ctx, &msgInfos); err != nil { return nil, nil, err } if len(msgInfos) < 1 {