diff --git a/internal/common/check/conversation.go b/internal/common/check/conversation.go index 9ebb1b299..5d6e25537 100644 --- a/internal/common/check/conversation.go +++ b/internal/common/check/conversation.go @@ -17,13 +17,13 @@ func NewConversationChecker(zk discoveryRegistry.SvcDiscoveryRegistry) *Conversa return &ConversationChecker{zk: zk} } -func (c *ConversationChecker) ModifyConversationField(ctx context.Context, req *pbConversation.ModifyConversationFieldReq) (resp *pbConversation.ModifyConversationFieldResp, err error) { +func (c *ConversationChecker) ModifyConversationField(ctx context.Context, req *pbConversation.ModifyConversationFieldReq) error { cc, err := c.getConn() if err != nil { - return nil, err + return err } - resp, err = conversation.NewConversationClient(cc).ModifyConversationField(ctx, req) - return + _, err = conversation.NewConversationClient(cc).ModifyConversationField(ctx, req) + return err } func (c *ConversationChecker) getConn() (*grpc.ClientConn, error) { diff --git a/internal/rpc/group/db_map.go b/internal/rpc/group/db_map.go index ad9ea4e75..29bfd0de4 100644 --- a/internal/rpc/group/db_map.go +++ b/internal/rpc/group/db_map.go @@ -12,7 +12,7 @@ func UpdateGroupInfoMap(group *sdkws.GroupInfoForSet) map[string]any { m["group_name"] = group.GroupName } if group.Notification != "" { - m["notification"] = group.Notification + m["Notification"] = group.Notification } if group.Introduction != "" { m["introduction"] = group.Introduction diff --git a/internal/rpc/group/group.go b/internal/rpc/group/group.go index 824748eaf..b116f833f 100644 --- a/internal/rpc/group/group.go +++ b/internal/rpc/group/group.go @@ -11,12 +11,13 @@ import ( "Open_IM/pkg/common/db/unrelation" "Open_IM/pkg/common/tokenverify" "Open_IM/pkg/common/tracelog" - discoveryRegistry "Open_IM/pkg/discoveryregistry" + pbConversation "Open_IM/pkg/proto/conversation" pbGroup "Open_IM/pkg/proto/group" "Open_IM/pkg/proto/sdkws" "Open_IM/pkg/utils" "context" "fmt" + "github.com/OpenIMSDK/openKeeper" "google.golang.org/grpc" "gorm.io/gorm" "math/big" @@ -26,20 +27,35 @@ import ( "time" ) -func Start(server *grpc.Server) error { +func Start(client *openKeeper.ZkClient, server *grpc.Server) error { + mysql, err := relation.NewGormDB() + if err != nil { + return err + } + if err := mysql.AutoMigrate(&relationTb.GroupModel{}, &relationTb.GroupMemberModel{}, &relationTb.GroupRequestModel{}); err != nil { + return err + } + redis, err := cache.NewRedis() + if err != nil { + return err + } + mongo, err := unrelation.NewMongo() + if err != nil { + return err + } pbGroup.RegisterGroupServer(server, &groupServer{ - GroupInterface: controller.NewGroupInterface(nil, cache.NewRedis().GetClient(), unrelation.NewMongo().GetClient()), - registerCenter: nil, - user: check.NewUserCheck(nil), + GroupInterface: controller.NewGroupInterface(mysql, redis.GetClient(), mongo.GetClient()), + UserCheck: check.NewUserCheck(client), + ConversationChecker: check.NewConversationChecker(client), }) return nil } type groupServer struct { - GroupInterface controller.GroupInterface - registerCenter discoveryRegistry.SvcDiscoveryRegistry - user *check.UserCheck - notification *notification.Check + GroupInterface controller.GroupInterface + UserCheck *check.UserCheck + Notification *notification.Check + ConversationChecker *check.ConversationChecker } func (s *groupServer) CheckGroupAdmin(ctx context.Context, groupID string) error { @@ -59,7 +75,7 @@ func (s *groupServer) GetUsernameMap(ctx context.Context, userIDs []string, comp if len(userIDs) == 0 { return map[string]string{}, nil } - users, err := s.user.GetPublicUserInfos(ctx, userIDs, complete) + users, err := s.UserCheck.GetPublicUserInfos(ctx, userIDs, complete) if err != nil { return nil, err } @@ -68,7 +84,19 @@ func (s *groupServer) GetUsernameMap(ctx context.Context, userIDs []string, comp }), nil } -func (s *groupServer) GroupNotification(ctx context.Context, groupID string) { +func (s *groupServer) GroupNotification(ctx context.Context, groupID string, userIDs []string) { + + s.ConversationChecker.ModifyConversationField(ctx, &pbConversation.ModifyConversationFieldReq{ + Conversation: &pbConversation.Conversation{ + OwnerUserID: tracelog.GetOpUserID(ctx), + ConversationID: utils.GetConversationIDBySessionType(groupID, constant.GroupChatType), + ConversationType: constant.GroupChatType, + GroupID: groupID, + }, + FieldType: constant.FieldGroupAtType, + UserIDList: userIDs, + }) + // todo 群公告修改通知 //var conversationReq pbConversation.ModifyConversationFieldReq //conversation := pbConversation.Conversation{ @@ -130,7 +158,7 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR if utils.Duplicate(userIDs) { return nil, constant.ErrArgs.Wrap("group member repeated") } - userMap, err := s.user.GetUsersInfoMap(ctx, userIDs, true) + userMap, err := s.UserCheck.GetUsersInfoMap(ctx, userIDs, true) if err != nil { return nil, err } @@ -183,11 +211,11 @@ func (s *groupServer) CreateGroup(ctx context.Context, req *pbGroup.CreateGroupR if req.GroupInfo.GroupType == constant.SuperGroup { go func() { for _, userID := range userIDs { - s.notification.SuperGroupNotification(ctx, userID, userID) + s.Notification.SuperGroupNotification(ctx, userID, userID) } }() } else { - s.notification.GroupCreatedNotification(ctx, group.GroupID, userIDs) + s.Notification.GroupCreatedNotification(ctx, group.GroupID, userIDs) } return resp, nil } @@ -256,7 +284,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite if ids := utils.Single(req.InvitedUserIDs, utils.Keys(memberMap)); len(ids) > 0 { return nil, constant.ErrArgs.Wrap("user in group " + strings.Join(ids, ",")) } - userMap, err := s.user.GetUsersInfoMap(ctx, req.InvitedUserIDs, true) + userMap, err := s.UserCheck.GetUsersInfoMap(ctx, req.InvitedUserIDs, true) if err != nil { return nil, err } @@ -281,7 +309,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite return nil, err } for _, request := range requests { - s.notification.JoinGroupApplicationNotification(ctx, &pbGroup.JoinGroupReq{ + s.Notification.JoinGroupApplicationNotification(ctx, &pbGroup.JoinGroupReq{ GroupID: request.GroupID, ReqMessage: request.ReqMsg, JoinSource: request.JoinSource, @@ -297,7 +325,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite return nil, err } for _, userID := range req.InvitedUserIDs { - s.notification.SuperGroupNotification(ctx, userID, userID) + s.Notification.SuperGroupNotification(ctx, userID, userID) } } else { opUserID := tracelog.GetOpUserID(ctx) @@ -318,7 +346,7 @@ func (s *groupServer) InviteUserToGroup(ctx context.Context, req *pbGroup.Invite if err := s.GroupInterface.CreateGroup(ctx, nil, groupMembers); err != nil { return nil, err } - s.notification.MemberInvitedNotification(ctx, req.GroupID, req.Reason, req.InvitedUserIDs) + s.Notification.MemberInvitedNotification(ctx, req.GroupID, req.Reason, req.InvitedUserIDs) } return resp, nil } @@ -395,7 +423,7 @@ func (s *groupServer) KickGroupMember(ctx context.Context, req *pbGroup.KickGrou } go func() { for _, userID := range req.KickedUserIDs { - s.notification.SuperGroupNotification(ctx, userID, userID) + s.Notification.SuperGroupNotification(ctx, userID, userID) } }() } else { @@ -435,7 +463,7 @@ func (s *groupServer) KickGroupMember(ctx context.Context, req *pbGroup.KickGrou if err := s.GroupInterface.DeleteGroupMember(ctx, group.GroupID, req.KickedUserIDs); err != nil { return nil, err } - s.notification.MemberKickedNotification(ctx, req, req.KickedUserIDs) + s.Notification.MemberKickedNotification(ctx, req, req.KickedUserIDs) } return resp, nil } @@ -487,7 +515,7 @@ func (s *groupServer) GetGroupApplicationList(ctx context.Context, req *pbGroup. } userIDs = utils.Distinct(userIDs) groupIDs = utils.Distinct(groupIDs) - userMap, err := s.user.GetPublicUserInfoMap(ctx, userIDs, true) + userMap, err := s.UserCheck.GetPublicUserInfoMap(ctx, userIDs, true) if err != nil { return nil, err } @@ -575,7 +603,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbGroup } else if !s.IsNotFound(err) { return nil, err } - user, err := s.user.GetPublicUserInfo(ctx, req.FromUserID) + user, err := s.UserCheck.GetPublicUserInfo(ctx, req.FromUserID) if err != nil { return nil, err } @@ -602,10 +630,10 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbGroup } if !join { if req.HandleResult == constant.GroupResponseAgree { - s.notification.GroupApplicationAcceptedNotification(ctx, req) - s.notification.MemberEnterNotification(ctx, req) + s.Notification.GroupApplicationAcceptedNotification(ctx, req) + s.Notification.MemberEnterNotification(ctx, req) } else if req.HandleResult == constant.GroupResponseRefuse { - s.notification.GroupApplicationRejectedNotification(ctx, req) + s.Notification.GroupApplicationRejectedNotification(ctx, req) } } return resp, nil @@ -613,7 +641,7 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbGroup func (s *groupServer) JoinGroup(ctx context.Context, req *pbGroup.JoinGroupReq) (*pbGroup.JoinGroupResp, error) { resp := &pbGroup.JoinGroupResp{} - if _, err := s.user.GetPublicUserInfo(ctx, tracelog.GetOpUserID(ctx)); err != nil { + if _, err := s.UserCheck.GetPublicUserInfo(ctx, tracelog.GetOpUserID(ctx)); err != nil { return nil, err } group, err := s.GroupInterface.TakeGroup(ctx, req.GroupID) @@ -643,7 +671,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbGroup.JoinGroupReq) if err := s.GroupInterface.CreateGroup(ctx, nil, []*relationTb.GroupMemberModel{groupMember}); err != nil { return nil, err } - s.notification.MemberEnterDirectlyNotification(ctx, req.GroupID, tracelog.GetOpUserID(ctx), tracelog.GetOperationID(ctx)) + s.Notification.MemberEnterDirectlyNotification(ctx, req.GroupID, tracelog.GetOpUserID(ctx), tracelog.GetOperationID(ctx)) return resp, nil } groupRequest := relationTb.GroupRequestModel{ @@ -656,7 +684,7 @@ func (s *groupServer) JoinGroup(ctx context.Context, req *pbGroup.JoinGroupReq) if err := s.GroupInterface.CreateGroupRequest(ctx, []*relationTb.GroupRequestModel{&groupRequest}); err != nil { return nil, err } - s.notification.JoinGroupApplicationNotification(ctx, req) + s.Notification.JoinGroupApplicationNotification(ctx, req) return resp, nil } @@ -670,13 +698,13 @@ func (s *groupServer) QuitGroup(ctx context.Context, req *pbGroup.QuitGroupReq) if err := s.GroupInterface.DeleteSuperGroupMember(ctx, req.GroupID, []string{tracelog.GetOpUserID(ctx)}); err != nil { return nil, err } - s.notification.SuperGroupNotification(ctx, tracelog.GetOpUserID(ctx), tracelog.GetOpUserID(ctx)) + s.Notification.SuperGroupNotification(ctx, tracelog.GetOpUserID(ctx), tracelog.GetOpUserID(ctx)) } else { _, err := s.GroupInterface.TakeGroupMember(ctx, req.GroupID, tracelog.GetOpUserID(ctx)) if err != nil { return nil, err } - s.notification.MemberQuitNotification(ctx, req) + s.Notification.MemberQuitNotification(ctx, req) } return resp, nil } @@ -699,6 +727,10 @@ func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbGroup.SetGroupInf if group.Status == constant.GroupStatusDismissed { return nil, utils.Wrap(constant.ErrDismissedAlready, "") } + userIDs, err := s.GroupInterface.FindGroupMemberUserID(ctx, group.GroupID) + if err != nil { + return nil, err + } data := UpdateGroupInfoMap(req.GroupInfoForSet) if len(data) > 0 { return resp, nil @@ -710,9 +742,21 @@ func (s *groupServer) SetGroupInfo(ctx context.Context, req *pbGroup.SetGroupInf if err != nil { return nil, err } - s.notification.GroupInfoSetNotification(ctx, req.GroupInfoForSet.GroupID, group.GroupName, group.Notification, group.Introduction, group.FaceURL, req.GroupInfoForSet.NeedVerification) + s.Notification.GroupInfoSetNotification(ctx, req.GroupInfoForSet.GroupID, group.GroupName, group.Notification, group.Introduction, group.FaceURL, req.GroupInfoForSet.NeedVerification) if req.GroupInfoForSet.Notification != "" { - s.GroupNotification(ctx, group.GroupID) + args := pbConversation.ModifyConversationFieldReq{ + Conversation: &pbConversation.Conversation{ + OwnerUserID: tracelog.GetOpUserID(ctx), + ConversationID: utils.GetConversationIDBySessionType(group.GroupID, constant.GroupChatType), + ConversationType: constant.GroupChatType, + GroupID: group.GroupID, + }, + FieldType: constant.FieldGroupAtType, + UserIDList: userIDs, + } + if err := s.ConversationChecker.ModifyConversationField(ctx, &args); err != nil { + tracelog.SetCtxWarn(ctx, "ModifyConversationField", err, args) + } } return resp, nil } @@ -760,7 +804,7 @@ func (s *groupServer) TransferGroupOwner(ctx context.Context, req *pbGroup.Trans if err := s.GroupInterface.TransferGroupOwner(ctx, req.GroupID, req.OldOwnerUserID, req.NewOwnerUserID, newOwner.RoleLevel); err != nil { return nil, err } - s.notification.GroupOwnerTransferredNotification(ctx, req) + s.Notification.GroupOwnerTransferredNotification(ctx, req) return resp, nil } @@ -827,7 +871,7 @@ func (s *groupServer) GetGroupMembersCMS(ctx context.Context, req *pbGroup.GetGr func (s *groupServer) GetUserReqApplicationList(ctx context.Context, req *pbGroup.GetUserReqApplicationListReq) (*pbGroup.GetUserReqApplicationListResp, error) { resp := &pbGroup.GetUserReqApplicationListResp{} - user, err := s.user.GetPublicUserInfo(ctx, req.UserID) + user, err := s.UserCheck.GetPublicUserInfo(ctx, req.UserID) if err != nil { return nil, err } @@ -892,7 +936,7 @@ func (s *groupServer) DismissGroup(ctx context.Context, req *pbGroup.DismissGrou return nil, err } } else { - s.notification.GroupDismissedNotification(ctx, req) + s.Notification.GroupDismissedNotification(ctx, req) } return resp, nil } @@ -916,7 +960,7 @@ func (s *groupServer) MuteGroupMember(ctx context.Context, req *pbGroup.MuteGrou if err := s.GroupInterface.UpdateGroupMember(ctx, member.GroupID, member.UserID, data); err != nil { return nil, err } - s.notification.GroupMemberMutedNotification(ctx, req.GroupID, req.UserID, req.MutedSeconds) + s.Notification.GroupMemberMutedNotification(ctx, req.GroupID, req.UserID, req.MutedSeconds) return resp, nil } @@ -939,7 +983,7 @@ func (s *groupServer) CancelMuteGroupMember(ctx context.Context, req *pbGroup.Ca if err := s.GroupInterface.UpdateGroupMember(ctx, member.GroupID, member.UserID, data); err != nil { return nil, err } - s.notification.GroupMemberCancelMutedNotification(ctx, req.GroupID, req.UserID) + s.Notification.GroupMemberCancelMutedNotification(ctx, req.GroupID, req.UserID) return resp, nil } @@ -951,7 +995,7 @@ func (s *groupServer) MuteGroup(ctx context.Context, req *pbGroup.MuteGroupReq) if err := s.GroupInterface.UpdateGroup(ctx, req.GroupID, UpdateGroupStatusMap(constant.GroupStatusMuted)); err != nil { return nil, err } - s.notification.GroupMutedNotification(ctx, req.GroupID) + s.Notification.GroupMutedNotification(ctx, req.GroupID) return resp, nil } @@ -963,7 +1007,7 @@ func (s *groupServer) CancelMuteGroup(ctx context.Context, req *pbGroup.CancelMu if err := s.GroupInterface.UpdateGroup(ctx, req.GroupID, UpdateGroupStatusMap(constant.GroupOk)); err != nil { return nil, err } - s.notification.GroupCancelMutedNotification(ctx, req.GroupID) + s.Notification.GroupCancelMutedNotification(ctx, req.GroupID) return resp, nil } @@ -1040,7 +1084,7 @@ func (s *groupServer) SetGroupMemberInfo(ctx context.Context, req *pbGroup.SetGr return nil, err } for _, member := range req.Members { - s.notification.GroupMemberInfoSetNotification(ctx, member.GroupID, member.UserID) + s.Notification.GroupMemberInfoSetNotification(ctx, member.GroupID, member.UserID) } return resp, nil } diff --git a/internal/rpc/msg/send_pull.go b/internal/rpc/msg/send_pull.go index ebfde0851..4493f9768 100644 --- a/internal/rpc/msg/send_pull.go +++ b/internal/rpc/msg/send_pull.go @@ -203,7 +203,7 @@ func (m *msgServer) sendMsgGroupChat(ctx context.Context, req *msg.SendMsgReq) ( conversation.GroupAtType = constant.AtMe } - _, err := m.Conversation.ModifyConversationField(ctx, &conversationReq) + err := m.Conversation.ModifyConversationField(ctx, &conversationReq) if err != nil { return } @@ -211,7 +211,7 @@ func (m *msgServer) sendMsgGroupChat(ctx context.Context, req *msg.SendMsgReq) ( if tag { conversationReq.UserIDList = utils.DifferenceString(atUserID, memberUserIDList) conversation.GroupAtType = constant.AtAll - _, err := m.Conversation.ModifyConversationField(ctx, &conversationReq) + err := m.Conversation.ModifyConversationField(ctx, &conversationReq) if err != nil { return } diff --git a/internal/startrpc/start.go b/internal/startrpc/start.go index ba4b6a5b6..0bdc2a51e 100644 --- a/internal/startrpc/start.go +++ b/internal/startrpc/start.go @@ -15,7 +15,7 @@ import ( "net" ) -func start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(server *grpc.Server) error, options []grpc.ServerOption) error { +func start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(client *openKeeper.ZkClient, server *grpc.Server) error, options []grpc.ServerOption) error { flagRpcPort := flag.Int("port", rpcPort, "get RpcGroupPort from cmd,default 16000 as port") flagPrometheusPort := flag.Int("prometheus_port", prometheusPort, "groupPrometheusPort default listen port") flag.Parse() @@ -60,10 +60,10 @@ func start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(s return err } } - return rpcFn(srv) + return rpcFn(zkClient, srv) } -func Start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(server *grpc.Server) error, options ...grpc.ServerOption) { +func Start(rpcPort int, rpcRegisterName string, prometheusPort int, rpcFn func(client *openKeeper.ZkClient, server *grpc.Server) error, options ...grpc.ServerOption) { err := start(rpcPort, rpcRegisterName, prometheusPort, rpcFn, options) fmt.Println("end", err) } diff --git a/pkg/common/db/cache/redis.go b/pkg/common/db/cache/redis.go index fbafdd83e..34c5ee94b 100644 --- a/pkg/common/db/cache/redis.go +++ b/pkg/common/db/cache/redis.go @@ -50,20 +50,14 @@ type Cache interface { // native redis operate -func NewRedis() *RedisClient { - o := &RedisClient{} - o.InitRedis() - return o -} +//func NewRedis() *RedisClient { +// o := &RedisClient{} +// o.InitRedis() +// return o +//} -type RedisClient struct { - rdb redis.UniversalClient -} - -func (r *RedisClient) InitRedis() { +func NewRedis() (*RedisClient, error) { var rdb redis.UniversalClient - var err error - ctx := context.Background() if config.Config.Redis.EnableCluster { rdb = redis.NewClusterClient(&redis.ClusterOptions{ Addrs: config.Config.Redis.DBAddress, @@ -71,11 +65,10 @@ func (r *RedisClient) InitRedis() { Password: config.Config.Redis.DBPassWord, // no password set PoolSize: 50, }) - _, err = rdb.Ping(ctx).Result() - if err != nil { - fmt.Println("redis cluster failed address ", config.Config.Redis.DBAddress) - panic(err.Error() + " redis cluster " + config.Config.Redis.DBUserName + config.Config.Redis.DBPassWord) - } + //if err := rdb.Ping(ctx).Err();err != nil { + // return nil, fmt.Errorf("redis ping %w", err) + //} + //return &RedisClient{rdb: rdb}, nil } else { rdb = redis.NewClient(&redis.Options{ Addr: config.Config.Redis.DBAddress[0], @@ -84,21 +77,63 @@ func (r *RedisClient) InitRedis() { DB: 0, // use default DB PoolSize: 100, // 连接池大小 }) - _, err = rdb.Ping(ctx).Result() - if err != nil { - panic(err.Error() + " redis " + config.Config.Redis.DBAddress[0] + config.Config.Redis.DBUserName + config.Config.Redis.DBPassWord) - } + //err := rdb.Ping(ctx).Err() + //if err != nil { + // panic(err.Error() + " redis " + config.Config.Redis.DBAddress[0] + config.Config.Redis.DBUserName + config.Config.Redis.DBPassWord) + //} } - r.rdb = rdb + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + err := rdb.Ping(ctx).Err() + if err != nil { + return nil, fmt.Errorf("redis ping %w", err) + } + return &RedisClient{rdb: rdb}, nil } +type RedisClient struct { + rdb redis.UniversalClient +} + +//func (r *RedisClient) InitRedis() { +// var rdb redis.UniversalClient +// var err error +// ctx := context.Background() +// if config.Config.Redis.EnableCluster { +// rdb = redis.NewClusterClient(&redis.ClusterOptions{ +// Addrs: config.Config.Redis.DBAddress, +// Username: config.Config.Redis.DBUserName, +// Password: config.Config.Redis.DBPassWord, // no password set +// PoolSize: 50, +// }) +// _, err = rdb.Ping(ctx).Result() +// if err != nil { +// fmt.Println("redis cluster failed address ", config.Config.Redis.DBAddress) +// panic(err.Error() + " redis cluster " + config.Config.Redis.DBUserName + config.Config.Redis.DBPassWord) +// } +// } else { +// rdb = redis.NewClient(&redis.Options{ +// Addr: config.Config.Redis.DBAddress[0], +// Username: config.Config.Redis.DBUserName, +// Password: config.Config.Redis.DBPassWord, // no password set +// DB: 0, // use default DB +// PoolSize: 100, // 连接池大小 +// }) +// _, err = rdb.Ping(ctx).Result() +// if err != nil { +// panic(err.Error() + " redis " + config.Config.Redis.DBAddress[0] + config.Config.Redis.DBUserName + config.Config.Redis.DBPassWord) +// } +// } +// r.rdb = rdb +//} + func (r *RedisClient) GetClient() redis.UniversalClient { return r.rdb } -func NewRedisClient(rdb redis.UniversalClient) *RedisClient { - return &RedisClient{rdb: rdb} -} +//func NewRedisClient(rdb redis.UniversalClient) *RedisClient { +// return &RedisClient{rdb: rdb} +//} // Perform seq auto-increment operation of user messages func (r *RedisClient) IncrUserSeq(uid string) (uint64, error) { diff --git a/pkg/common/db/relation/mysql_init.go b/pkg/common/db/relation/mysql_init.go new file mode 100644 index 000000000..d4461c294 --- /dev/null +++ b/pkg/common/db/relation/mysql_init.go @@ -0,0 +1,148 @@ +package relation + +import ( + "Open_IM/pkg/common/config" + "fmt" + "time" + + "gorm.io/driver/mysql" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +func NewGormDB() (*gorm.DB, error) { + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + config.Config.Mysql.DBUserName, config.Config.Mysql.DBPassword, config.Config.Mysql.DBAddress[0], "mysql") + db, err := gorm.Open(mysql.Open(dsn), nil) + if err != nil { + time.Sleep(time.Duration(30) * time.Second) + db, err = gorm.Open(mysql.Open(dsn), nil) + if err != nil { + panic(err.Error() + " open failed " + dsn) + } + } + sqlDB, err := db.DB() + if err != nil { + return nil, err + } + defer sqlDB.Close() + sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s default charset utf8 COLLATE utf8_general_ci;", config.Config.Mysql.DBDatabaseName) + err = db.Exec(sql).Error + if err != nil { + return nil, fmt.Errorf("init db %w", err) + } + dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", + config.Config.Mysql.DBUserName, config.Config.Mysql.DBPassword, config.Config.Mysql.DBAddress[0], config.Config.Mysql.DBDatabaseName) + newLogger := logger.New( + Writer{}, + logger.Config{ + SlowThreshold: time.Duration(config.Config.Mysql.SlowThreshold) * time.Millisecond, // Slow SQL threshold + LogLevel: logger.LogLevel(config.Config.Mysql.LogLevel), // Log level + IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger + Colorful: true, // Disable color + }, + ) + db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ + Logger: newLogger, + }) + if err != nil { + return nil, err + } + sqlDB, err = db.DB() + if err != nil { + return nil, err + } + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.DBMaxLifeTime)) + sqlDB.SetMaxOpenConns(config.Config.Mysql.DBMaxOpenConns) + sqlDB.SetMaxIdleConns(config.Config.Mysql.DBMaxIdleConns) + return db, nil +} + +type Mysql struct { + gormConn *gorm.DB +} + +func (m *Mysql) GormConn() *gorm.DB { + return m.gormConn +} + +//func (m *Mysql) SetGormConn(gormConn *gorm.DB) { +// m.gormConn = gormConn +//} +// +//func (m *Mysql) InitConn() *Mysql { +// dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", +// config.Config.Mysql.DBUserName, config.Config.Mysql.DBPassword, config.Config.Mysql.DBAddress[0], "mysql") +// var db *gorm.DB +// db, err := gorm.Open(mysql.Open(dsn), nil) +// if err != nil { +// time.Sleep(time.Duration(30) * time.Second) +// db, err = gorm.Open(mysql.Open(dsn), nil) +// if err != nil { +// panic(err.Error() + " open failed " + dsn) +// } +// } +// sql := fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s default charset utf8 COLLATE utf8_general_ci;", config.Config.Mysql.DBDatabaseName) +// err = db.Exec(sql).Error +// if err != nil { +// panic(err.Error() + " Exec failed:" + sql) +// } +// dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=true&loc=Local", +// config.Config.Mysql.DBUserName, config.Config.Mysql.DBPassword, config.Config.Mysql.DBAddress[0], config.Config.Mysql.DBDatabaseName) +// newLogger := logger.New( +// Writer{}, +// logger.Config{ +// SlowThreshold: time.Duration(config.Config.Mysql.SlowThreshold) * time.Millisecond, // Slow SQL threshold +// LogLevel: logger.LogLevel(config.Config.Mysql.LogLevel), // Log level +// IgnoreRecordNotFoundError: true, // Ignore ErrRecordNotFound error for logger +// Colorful: true, // Disable color +// }, +// ) +// db, err = gorm.Open(mysql.Open(dsn), &gorm.Config{ +// Logger: newLogger, +// }) +// if err != nil { +// panic(err.Error() + " Open failed " + dsn) +// } +// sqlDB, err := db.DB() +// if err != nil { +// panic(err.Error() + " DB.DB() failed ") +// } +// sqlDB.SetConnMaxLifetime(time.Second * time.Duration(config.Config.Mysql.DBMaxLifeTime)) +// sqlDB.SetMaxOpenConns(config.Config.Mysql.DBMaxOpenConns) +// sqlDB.SetMaxIdleConns(config.Config.Mysql.DBMaxIdleConns) +// if db == nil { +// panic("db is nil") +// } +// m.SetGormConn(db) +// return m +//} + +//models := []interface{}{&Friend{}, &FriendRequest{}, &Group{}, &GroupMember{}, &GroupRequest{}, +// &User{}, &Black{}, &ChatLog{}, &Conversation{}, &AppVersion{}} + +//func (m *Mysql) AutoMigrateModel(model interface{}) error { +// err := m.gormConn.AutoMigrate(model) +// if err != nil { +// return err +// } +// m.gormConn.Set("gorm:table_options", "CHARSET=utf8") +// m.gormConn.Set("gorm:table_options", "collation=utf8_unicode_ci") +// _ = m.gormConn.Migrator().CreateTable(model) +// return nil +//} + +type Writer struct{} + +func (w Writer) Printf(format string, args ...interface{}) { + fmt.Printf(format, args...) +} + +func getDBConn(db *gorm.DB, tx []any) *gorm.DB { + if len(tx) > 0 { + if txDB, ok := tx[0].(*gorm.DB); ok { + return txDB + } + } + return db +} diff --git a/pkg/common/db/unrelation/mongo.go b/pkg/common/db/unrelation/mongo.go index 7500214ae..9a4c9159a 100644 --- a/pkg/common/db/unrelation/mongo.go +++ b/pkg/common/db/unrelation/mongo.go @@ -13,10 +13,45 @@ import ( "time" ) -func NewMongo() *Mongo { - mgo := &Mongo{} - mgo.InitMongo() - return mgo +//func NewMongo() *Mongo { +// mgo := &Mongo{} +// mgo.InitMongo() +// return mgo +//} + +func NewMongo() (*Mongo, error) { + uri := "mongodb://sample.host:27017/?maxPoolSize=20&w=majority" + if config.Config.Mongo.DBUri != "" { + // example: mongodb://$user:$password@mongo1.mongo:27017,mongo2.mongo:27017,mongo3.mongo:27017/$DBDatabase/?replicaSet=rs0&readPreference=secondary&authSource=admin&maxPoolSize=$DBMaxPoolSize + uri = config.Config.Mongo.DBUri + } else { + //mongodb://mongodb1.example.com:27317,mongodb2.example.com:27017/?replicaSet=mySet&authSource=authDB + mongodbHosts := "" + for i, v := range config.Config.Mongo.DBAddress { + if i == len(config.Config.Mongo.DBAddress)-1 { + mongodbHosts += v + } else { + mongodbHosts += v + "," + } + } + if config.Config.Mongo.DBPassword != "" && config.Config.Mongo.DBUserName != "" { + uri = fmt.Sprintf("mongodb://%s:%s@%s/%s?maxPoolSize=%d&authSource=admin", + config.Config.Mongo.DBUserName, config.Config.Mongo.DBPassword, mongodbHosts, + config.Config.Mongo.DBDatabase, config.Config.Mongo.DBMaxPoolSize) + } else { + uri = fmt.Sprintf("mongodb://%s/%s/?maxPoolSize=%d&authSource=admin", + mongodbHosts, config.Config.Mongo.DBDatabase, + config.Config.Mongo.DBMaxPoolSize) + } + } + fmt.Println("mongo:", uri) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*60) + defer cancel() + mongoClient, err := mongo.Connect(ctx, options.Client().ApplyURI(uri)) + if err != nil { + return nil, err + } + return &Mongo{db: mongoClient}, nil } type Mongo struct {