From e0a422fd1668282d6018575a72dd8fd6fb4ce716 Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Mon, 20 Mar 2023 14:36:42 +0800 Subject: [PATCH 1/2] code error --- pkg/errs/code.go | 2 +- pkg/errs/coderr.go | 17 ++++++++++++++--- pkg/errs/predefine.go | 2 +- pkg/errs/relation.go | 43 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 pkg/errs/relation.go diff --git a/pkg/errs/code.go b/pkg/errs/code.go index 166cbd3b9..36c0e2e81 100644 --- a/pkg/errs/code.go +++ b/pkg/errs/code.go @@ -48,7 +48,7 @@ const ( // 群组错误码 GroupIDNotFoundError = 1201 //GroupID不存在 - GroupIDIDExisted = 1202 //GroupID已存在 + GroupIDExisted = 1202 //GroupID已存在 OnlyOneOwnerError = 1203 //只能有一个群主 InGroupAlreadyError = 1204 //已在群组中 NotInGroupYetError = 1205 //不在群组中 diff --git a/pkg/errs/coderr.go b/pkg/errs/coderr.go index 1ce329905..445ee0ddd 100644 --- a/pkg/errs/coderr.go +++ b/pkg/errs/coderr.go @@ -11,7 +11,8 @@ type CodeError interface { Msg() string Detail() string WithDetail(detail string) CodeError - Is(err error) bool + // Is 判断是否是某个错误, loose为false时, 只有错误码相同就认为是同一个错误, 默认为true + Is(err error, loose ...bool) bool Wrap(msg ...string) error error } @@ -59,13 +60,23 @@ func (e *codeError) Wrap(w ...string) error { return errors.Wrap(e, strings.Join(w, ", ")) } -func (e *codeError) Is(err error) bool { +func (e *codeError) Is(err error, loose ...bool) bool { if err == nil { return false } + var allowSubclasses bool + if len(loose) == 0 { + allowSubclasses = true + } else { + allowSubclasses = loose[0] + } codeErr, ok := Unwrap(err).(CodeError) if ok { - return codeErr.Code() == e.code + if allowSubclasses { + return Relation.Is(e.code, codeErr.Code()) + } else { + return codeErr.Code() == e.code + } } return false } diff --git a/pkg/errs/predefine.go b/pkg/errs/predefine.go index e18cdf4eb..6ba2f76f8 100644 --- a/pkg/errs/predefine.go +++ b/pkg/errs/predefine.go @@ -12,7 +12,7 @@ var ( ErrUserIDNotFound = NewCodeError(UserIDNotFoundError, "UserIDNotFoundError") ErrGroupIDNotFound = NewCodeError(GroupIDNotFoundError, "GroupIDNotFoundError") - ErrGroupIDExisted = NewCodeError(GroupIDIDExisted, "GroupIDExisted") + ErrGroupIDExisted = NewCodeError(GroupIDExisted, "GroupIDExisted") ErrUserIDExisted = NewCodeError(UserIDExisted, "UserIDExisted") ErrRecordNotFound = NewCodeError(RecordNotFoundError, "RecordNotFoundError") diff --git a/pkg/errs/relation.go b/pkg/errs/relation.go new file mode 100644 index 000000000..5d87ab59e --- /dev/null +++ b/pkg/errs/relation.go @@ -0,0 +1,43 @@ +package errs + +var Relation = &relation{m: make(map[int]map[int]struct{})} + +func init() { + Relation.Add(RecordNotFoundError, UserIDNotFoundError) + Relation.Add(RecordNotFoundError, GroupIDNotFoundError) + Relation.Add(DuplicateKeyError, UserIDExisted) + Relation.Add(DuplicateKeyError, GroupIDExisted) +} + +type relation struct { + m map[int]map[int]struct{} +} + +func (r *relation) Add(codes ...int) { + if len(codes) < 2 { + panic("codes length must be greater than 2") + } + for i := 1; i < len(codes); i++ { + parent := codes[i-1] + s, ok := r.m[parent] + if !ok { + s = make(map[int]struct{}) + r.m[parent] = s + } + for _, code := range codes[i:] { + s[code] = struct{}{} + } + } +} + +func (r *relation) Is(parent, child int) bool { + if parent == child { + return true + } + s, ok := r.m[parent] + if !ok { + return false + } + _, ok = s[child] + return ok +} From 762d6516a584aaeb5040513be3894b08f94b810c Mon Sep 17 00:00:00 2001 From: withchao <993506633@qq.com> Date: Mon, 20 Mar 2023 15:11:00 +0800 Subject: [PATCH 2/2] db --- pkg/common/db/{relation => ormutil}/utils.go | 12 ++++++------ pkg/common/db/relation/black_model.go | 3 ++- pkg/common/db/relation/group_member_model.go | 11 ++++++----- pkg/common/db/relation/group_model.go | 3 ++- pkg/common/db/relation/group_request_model.go | 3 ++- 5 files changed, 18 insertions(+), 14 deletions(-) rename pkg/common/db/{relation => ormutil}/utils.go (78%) diff --git a/pkg/common/db/relation/utils.go b/pkg/common/db/ormutil/utils.go similarity index 78% rename from pkg/common/db/relation/utils.go rename to pkg/common/db/ormutil/utils.go index fcb275b16..c64938f6c 100644 --- a/pkg/common/db/relation/utils.go +++ b/pkg/common/db/ormutil/utils.go @@ -1,11 +1,11 @@ -package relation +package ormutil import ( "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "gorm.io/gorm" ) -func gormPage[E any](db *gorm.DB, pageNumber, showNumber int32) (uint32, []*E, error) { +func GormPage[E any](db *gorm.DB, pageNumber, showNumber int32) (uint32, []*E, error) { var count int64 if err := db.Count(&count).Error; err != nil { return 0, nil, utils.Wrap(err, "") @@ -17,7 +17,7 @@ func gormPage[E any](db *gorm.DB, pageNumber, showNumber int32) (uint32, []*E, e return uint32(count), es, nil } -func gormSearch[E any](db *gorm.DB, fields []string, value string, pageNumber, showNumber int32) (uint32, []*E, error) { +func GormSearch[E any](db *gorm.DB, fields []string, value string, pageNumber, showNumber int32) (uint32, []*E, error) { if len(fields) > 0 && value != "" { value = "%" + value + "%" if len(fields) == 1 { @@ -30,17 +30,17 @@ func gormSearch[E any](db *gorm.DB, fields []string, value string, pageNumber, s db = db.Where(t) } } - return gormPage[E](db, pageNumber, showNumber) + return GormPage[E](db, pageNumber, showNumber) } -func gormIn[E any](db **gorm.DB, field string, es []E) { +func GormIn[E any](db **gorm.DB, field string, es []E) { if len(es) == 0 { return } *db = (*db).Where(field+" in (?)", es) } -func mapCount(db *gorm.DB, field string) (map[string]uint32, error) { +func MapCount(db *gorm.DB, field string) (map[string]uint32, error) { var items []struct { ID string `gorm:"column:id"` Count uint32 `gorm:"column:count"` diff --git a/pkg/common/db/relation/black_model.go b/pkg/common/db/relation/black_model.go index e7ffa9a49..77efa5502 100644 --- a/pkg/common/db/relation/black_model.go +++ b/pkg/common/db/relation/black_model.go @@ -2,6 +2,7 @@ package relation import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/ormutil" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/relation" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" @@ -50,7 +51,7 @@ func (b *BlackGorm) FindOwnerBlacks(ctx context.Context, ownerUserID string, pag if err != nil { return nil, 0, utils.Wrap(err, "") } - totalUint32, blacks, err := gormPage[relation.BlackModel](b.db(ctx), pageNumber, showNumber) + totalUint32, blacks, err := ormutil.GormPage[relation.BlackModel](b.db(ctx), pageNumber, showNumber) total = int64(totalUint32) return } diff --git a/pkg/common/db/relation/group_member_model.go b/pkg/common/db/relation/group_member_model.go index 42266c550..d63416e89 100644 --- a/pkg/common/db/relation/group_member_model.go +++ b/pkg/common/db/relation/group_member_model.go @@ -3,6 +3,7 @@ package relation import ( "context" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/constant" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/ormutil" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/relation" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "gorm.io/gorm" @@ -71,14 +72,14 @@ func (g *GroupMemberGorm) TakeOwner(ctx context.Context, groupID string) (groupM func (g *GroupMemberGorm) SearchMember(ctx context.Context, keyword string, groupIDs []string, userIDs []string, roleLevels []int32, pageNumber, showNumber int32) (total uint32, groupList []*relation.GroupMemberModel, err error) { db := g.DB - gormIn(&db, "group_id", groupIDs) - gormIn(&db, "user_id", userIDs) - gormIn(&db, "role_level", roleLevels) - return gormSearch[relation.GroupMemberModel](db, []string{"nickname"}, keyword, pageNumber, showNumber) + ormutil.GormIn(&db, "group_id", groupIDs) + ormutil.GormIn(&db, "user_id", userIDs) + ormutil.GormIn(&db, "role_level", roleLevels) + return ormutil.GormSearch[relation.GroupMemberModel](db, []string{"nickname"}, keyword, pageNumber, showNumber) } func (g *GroupMemberGorm) MapGroupMemberNum(ctx context.Context, groupIDs []string) (count map[string]uint32, err error) { - return mapCount(g.DB.Where("group_id in (?)", groupIDs), "group_id") + return ormutil.MapCount(g.DB.Where("group_id in (?)", groupIDs), "group_id") } func (g *GroupMemberGorm) FindJoinUserID(ctx context.Context, groupIDs []string) (groupUsers map[string][]string, err error) { diff --git a/pkg/common/db/relation/group_model.go b/pkg/common/db/relation/group_model.go index e4d5d9759..39feb5d7a 100644 --- a/pkg/common/db/relation/group_model.go +++ b/pkg/common/db/relation/group_model.go @@ -2,6 +2,7 @@ package relation import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/ormutil" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/relation" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" "gorm.io/gorm" @@ -43,7 +44,7 @@ func (g *GroupGorm) Take(ctx context.Context, groupID string) (group *relation.G } func (g *GroupGorm) Search(ctx context.Context, keyword string, pageNumber, showNumber int32) (total uint32, groups []*relation.GroupModel, err error) { - return gormSearch[relation.GroupModel](g.DB, []string{"name"}, keyword, pageNumber, showNumber) + return ormutil.GormSearch[relation.GroupModel](g.DB, []string{"name"}, keyword, pageNumber, showNumber) } func (g *GroupGorm) GetGroupIDsByGroupType(ctx context.Context, groupType int) (groupIDs []string, err error) { diff --git a/pkg/common/db/relation/group_request_model.go b/pkg/common/db/relation/group_request_model.go index b6dfa3b23..e7e15539b 100644 --- a/pkg/common/db/relation/group_request_model.go +++ b/pkg/common/db/relation/group_request_model.go @@ -2,6 +2,7 @@ package relation import ( "context" + "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/ormutil" "github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/table/relation" "github.com/OpenIMSDK/Open-IM-Server/pkg/utils" @@ -39,5 +40,5 @@ func (g *GroupRequestGorm) Take(ctx context.Context, groupID string, userID stri } func (g *GroupRequestGorm) Page(ctx context.Context, userID string, pageNumber, showNumber int32) (total uint32, groups []*relation.GroupRequestModel, err error) { - return gormSearch[relation.GroupRequestModel](g.DB.Where("user_id = ?", userID), nil, "", pageNumber, showNumber) + return ormutil.GormSearch[relation.GroupRequestModel](g.DB.Where("user_id = ?", userID), nil, "", pageNumber, showNumber) }