From 863c52c7bb360e1b09ff282fbf3596755ced561c Mon Sep 17 00:00:00 2001
From: pluto <2631223275@qq.com>
Date: Fri, 4 Aug 2023 21:35:33 +0800
Subject: [PATCH] Improve user subscription and unsubscription (#770)

* Improve user subscription and unsubscription

* Modification only does not delete all subscribed documents when unsubscribing

* fix build

* update

* update

* update

* update

* add log

* update

* update

* update

* delete simple log
---
 internal/rpc/user/user.go              |  23 ++++--
 pkg/common/db/controller/user.go       |  40 +++++----
 pkg/common/db/table/unrelation/user.go |   4 +
 pkg/common/db/unrelation/user.go       | 107 +++++++++++++++++--------
 4 files changed, 116 insertions(+), 58 deletions(-)

diff --git a/internal/rpc/user/user.go b/internal/rpc/user/user.go
index 53598db0b..5425a9c13 100644
--- a/internal/rpc/user/user.go
+++ b/internal/rpc/user/user.go
@@ -250,18 +250,27 @@ func (s *userServer) GetAllUserID(ctx context.Context, req *pbuser.GetAllUserIDR
 
 // SubscribeOrCancelUsersStatus Subscribe online or cancel online users.
 func (s *userServer) SubscribeOrCancelUsersStatus(ctx context.Context, req *pbuser.SubscribeOrCancelUsersStatusReq) (resp *pbuser.SubscribeOrCancelUsersStatusResp, err error) {
-	err = s.UserDatabase.SubscribeOrCancelUsersStatus(ctx, req.UserID, req.UserIDs, req.Genre)
-	if err != nil {
-		return nil, err
+	if req.Genre == constant.SubscriberUser {
+		err = s.UserDatabase.SubscribeUsersStatus(ctx, req.UserID, req.UserIDs)
+		if err != nil {
+			return nil, err
+		}
+		var status []*pbuser.OnlineStatus
+		status, err = s.UserDatabase.GetUserStatus(ctx, req.UserIDs)
+		if err != nil {
+			return nil, err
+		}
+		return &pbuser.SubscribeOrCancelUsersStatusResp{StatusList: status}, nil
+	} else if req.Genre == constant.Unsubscribe {
+		err = s.UserDatabase.UnsubscribeUsersStatus(ctx, req.UserID, req.UserIDs)
+		if err != nil {
+			return nil, err
+		}
 	}
-	//var status map[string][]string
-	//TODO 获取用户在线列表,返回订阅的用户的在线列表
-
 	return &pbuser.SubscribeOrCancelUsersStatusResp{}, nil
 }
 
 func (s *userServer) GetUserStatus(ctx context.Context, req *pbuser.GetUserStatusReq) (resp *pbuser.GetUserStatusResp, err error) {
-	//TODO 是否加一个参数校验-判断req.userID的数量,每一个获取加一个限制,一次请求限制500?
 	onlineStatusList, err := s.UserDatabase.GetUserStatus(ctx, req.UserIDs)
 	if err != nil {
 		return nil, err
diff --git a/pkg/common/db/controller/user.go b/pkg/common/db/controller/user.go
index 5b303ebd7..a867899d9 100644
--- a/pkg/common/db/controller/user.go
+++ b/pkg/common/db/controller/user.go
@@ -17,7 +17,6 @@ package controller
 import (
 	"context"
 	unRelationTb "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/unrelation"
-	"github.com/OpenIMSDK/protocol/constant"
 	"github.com/OpenIMSDK/protocol/user"
 	"time"
 
@@ -51,8 +50,10 @@ type UserDatabase interface {
 	CountTotal(ctx context.Context, before *time.Time) (int64, error)
 	// CountRangeEverydayTotal Get the user increment in the range
 	CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
-	//SubscribeOrCancelUsersStatus Subscribe or unsubscribe a user's presence status
-	SubscribeOrCancelUsersStatus(ctx context.Context, userID string, userIDs []string, genre int32) error
+	//SubscribeUsersStatus Subscribe a user's presence status
+	SubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error
+	// UnsubscribeUsersStatus unsubscribe a user's presence status
+	UnsubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error
 	// GetAllSubscribeList Get a list of all subscriptions
 	GetAllSubscribeList(ctx context.Context, userID string) ([]string, error)
 	// GetSubscribedList Get all subscribed lists
@@ -176,29 +177,34 @@ func (u *userDatabase) CountRangeEverydayTotal(ctx context.Context, start time.T
 	return u.userDB.CountRangeEverydayTotal(ctx, start, end)
 }
 
-//SubscribeOrCancelUsersStatus Subscribe or unsubscribe a user's presence status
-func (u *userDatabase) SubscribeOrCancelUsersStatus(ctx context.Context, userID string, userIDs []string, genre int32) error {
-	var err error
-	if genre == constant.SubscriberUser {
-		err = u.mongoDB.AddSubscriptionList(ctx, userID, userIDs)
-	} else if genre == constant.Unsubscribe {
-		err = u.mongoDB.UnsubscriptionList(ctx, userID, userIDs)
-	}
+// SubscribeUsersStatus Subscribe or unsubscribe a user's presence status
+func (u *userDatabase) SubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error {
+	err := u.mongoDB.AddSubscriptionList(ctx, userID, userIDs)
+	return err
+}
+
+// UnsubscribeUsersStatus unsubscribe a user's presence status
+func (u *userDatabase) UnsubscribeUsersStatus(ctx context.Context, userID string, userIDs []string) error {
+	err := u.mongoDB.UnsubscriptionList(ctx, userID, userIDs)
 	return err
 }
 
 // GetAllSubscribeList Get a list of all subscriptions.
 func (u *userDatabase) GetAllSubscribeList(ctx context.Context, userID string) ([]string, error) {
-
-	//TODO 获取所有订阅
-	return nil, nil
+	list, err := u.mongoDB.GetAllSubscribeList(ctx, userID)
+	if err != nil {
+		return nil, err
+	}
+	return list, nil
 }
 
 // GetSubscribedList Get all subscribed lists
 func (u *userDatabase) GetSubscribedList(ctx context.Context, userID string) ([]string, error) {
-
-	//TODO 获取所有被订阅
-	return nil, nil
+	list, err := u.mongoDB.GetSubscribedList(ctx, userID)
+	if err != nil {
+		return nil, err
+	}
+	return list, nil
 }
 
 // GetUserStatus get user status
diff --git a/pkg/common/db/table/unrelation/user.go b/pkg/common/db/table/unrelation/user.go
index d264da467..8664df262 100644
--- a/pkg/common/db/table/unrelation/user.go
+++ b/pkg/common/db/table/unrelation/user.go
@@ -39,4 +39,8 @@ type UserModelInterface interface {
 	UnsubscriptionList(ctx context.Context, userID string, userIDList []string) error
 	// RemoveSubscribedListFromUser Among the unsubscribed users, delete the user from the subscribed list.
 	RemoveSubscribedListFromUser(ctx context.Context, userID string, userIDList []string) error
+	// GetAllSubscribeList Get all users subscribed by this user
+	GetAllSubscribeList(ctx context.Context, id string) (userIDList []string, err error)
+	// GetSubscribedList Get the user subscribed by those users
+	GetSubscribedList(ctx context.Context, id string) (userIDList []string, err error)
 }
diff --git a/pkg/common/db/unrelation/user.go b/pkg/common/db/unrelation/user.go
index feec8aa21..4f1bbd017 100644
--- a/pkg/common/db/unrelation/user.go
+++ b/pkg/common/db/unrelation/user.go
@@ -17,14 +17,14 @@ package unrelation
 import (
 	"context"
 	"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/unrelation"
+	"github.com/OpenIMSDK/tools/errs"
 	"github.com/OpenIMSDK/tools/utils"
 	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/mongo"
 	"go.mongodb.org/mongo-driver/mongo/options"
-	"log"
 )
 
-//  prefixes and suffixes.
+// prefixes and suffixes.
 const (
 	SubscriptionPrefix = "subscription_prefix"
 	SubscribedPrefix   = "subscribed_prefix"
@@ -48,22 +48,35 @@ type UserMongoDriver struct {
 // AddSubscriptionList Subscriber's handling of thresholds.
 func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string, userIDList []string) error {
 	// Check the number of lists in the key.
-	filter := bson.M{SubscriptionPrefix + userID: bson.M{"$size": 1}}
-	result, err := u.userCollection.Find(context.Background(), filter)
-	if err != nil {
-		return err
+	pipeline := mongo.Pipeline{
+		{{"$match", bson.D{{"user_id", SubscriptionPrefix + userID}}}},
+		{{"$project", bson.D{{"count", bson.D{{"$size", "$user_id_list"}}}}}},
 	}
-	var newUserIDList []string
-	for result.Next(context.Background()) {
-		err := result.Decode(&newUserIDList)
+	// perform aggregate operations
+	cursor, err := u.userCollection.Aggregate(ctx, pipeline)
+	if err != nil {
+		return errs.Wrap(err)
+	}
+	defer cursor.Close(ctx)
+	var cnt struct {
+		Count int `bson:"count"`
+	}
+	// iterate over aggregated results
+	for cursor.Next(ctx) {
+		err := cursor.Decode(&cnt)
 		if err != nil {
-			log.Fatal(err)
+			return errs.Wrap(err)
 		}
 	}
+	var newUserIDList []string
 	// If the threshold is exceeded, pop out the previous MaximumSubscription - len(userIDList) and insert it.
-	if len(newUserIDList)+len(userIDList) > MaximumSubscription {
+	if cnt.Count+len(userIDList) > MaximumSubscription {
+		newUserIDList, err = u.GetAllSubscribeList(ctx, userID)
+		if err != nil {
+			return err
+		}
 		newUserIDList = newUserIDList[MaximumSubscription-len(userIDList):]
-		_, err := u.userCollection.UpdateOne(
+		_, err = u.userCollection.UpdateOne(
 			ctx,
 			bson.M{"user_id": SubscriptionPrefix + userID},
 			bson.M{"$set": bson.M{"user_id_list": newUserIDList}},
@@ -71,16 +84,17 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string
 		if err != nil {
 			return err
 		}
-		//for i := 1; i <= MaximumSubscription-len(userIDList); i++ {
-		//	_, err := u.userCollection.UpdateOne(
-		//		ctx,
-		//		bson.M{"user_id": SubscriptionPrefix + userID},
-		//		bson.M{SubscriptionPrefix + userID: bson.M{"$pop": -1}},
-		//	)
-		//	if err != nil {
-		//		return err
-		//	}
-		//}
+		// Another way to subscribe to N before pop,Delete after testing
+		/*for i := 1; i <= MaximumSubscription-len(userIDList); i++ {
+			_, err := u.userCollection.UpdateOne(
+				ctx,
+				bson.M{"user_id": SubscriptionPrefix + userID},
+				bson.M{SubscriptionPrefix + userID: bson.M{"$pop": -1}},
+			)
+			if err != nil {
+				return err
+			}
+		}*/
 	}
 	upsert := true
 	opts := &options.UpdateOptions{
@@ -93,7 +107,7 @@ func (u *UserMongoDriver) AddSubscriptionList(ctx context.Context, userID string
 		opts,
 	)
 	if err != nil {
-		return err
+		return errs.Wrap(err)
 	}
 	for _, user := range userIDList {
 		_, err = u.userCollection.UpdateOne(
@@ -117,25 +131,50 @@ func (u *UserMongoDriver) UnsubscriptionList(ctx context.Context, userID string,
 		bson.M{"$pull": bson.M{"user_id_list": bson.M{"$in": userIDList}}},
 	)
 	if err != nil {
-		return err
+		return errs.Wrap(err)
 	}
 	err = u.RemoveSubscribedListFromUser(ctx, userID, userIDList)
 	if err != nil {
-		return err
+		return errs.Wrap(err)
 	}
 	return nil
 }
 
 // RemoveSubscribedListFromUser Among the unsubscribed users, delete the user from the subscribed list.
 func (u *UserMongoDriver) RemoveSubscribedListFromUser(ctx context.Context, userID string, userIDList []string) error {
-	var newUserIDList []string
-	for _, value := range userIDList {
-		newUserIDList = append(newUserIDList, SubscribedPrefix+value)
+	var err error
+	for _, userIDTemp := range userIDList {
+		_, err = u.userCollection.UpdateOne(
+			ctx,
+			bson.M{"user_id": SubscribedPrefix + userIDTemp},
+			bson.M{"$pull": bson.M{"user_id_list": userID}},
+		)
 	}
-	_, err := u.userCollection.UpdateOne(
-		ctx,
-		bson.M{"user_id": bson.M{"$in": newUserIDList}},
-		bson.M{"$pull": bson.M{"user_id_list": userID}},
-	)
-	return utils.Wrap(err, "")
+	return errs.Wrap(err)
+}
+
+// GetAllSubscribeList Get all users subscribed by this user
+func (u *UserMongoDriver) GetAllSubscribeList(ctx context.Context, userID string) (userIDList []string, err error) {
+	var user unrelation.UserModel
+	cursor := u.userCollection.FindOne(
+		ctx,
+		bson.M{"user_id": SubscriptionPrefix + userID})
+	err = cursor.Decode(&user)
+	if err != nil {
+		return nil, errs.Wrap(err)
+	}
+	return user.UserIDList, nil
+}
+
+// GetSubscribedList Get the user subscribed by those users
+func (u *UserMongoDriver) GetSubscribedList(ctx context.Context, userID string) (userIDList []string, err error) {
+	var user unrelation.UserModel
+	cursor := u.userCollection.FindOne(
+		ctx,
+		bson.M{"user_id": SubscribedPrefix + userID})
+	err = cursor.Decode(&user)
+	if err != nil {
+		return nil, errs.Wrap(err)
+	}
+	return user.UserIDList, nil
 }