feat: Integrate Comprehensive E2E Testing for GoChat (#1906)

* feat: create e2e test readme

Signed-off-by: Xinwei Xiong (cubxxw) <3293172751nss@gmail.com>

* feat: fix markdown file

* feat: add openim make lint

* feat: add git chglog pull request

* feat: add git chglog pull request

* fix: fix openim api err code

* fix: fix openim api err code

* fix: fix openim api err code

* feat: Improve CICD

* feat: Combining GitHub and Google Workspace for Effective Project Management'

* feat: fix openim tools error code

* feat: fix openim tools error code

* feat: add openim error handle

* feat: add openim error handle

* feat: optimize tim white prom code return err

* feat: fix openim tools error code

* style: format openim server code style

* feat: add openim optimize commit code

* feat: add openim optimize commit code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: add openim auto format code

* feat: format openim code

* feat: Some of the notes were translated

* feat: Some of the notes were translated

* feat: update openim server code

* feat: optimize openim reset code

* feat: optimize openim reset code

---------

Signed-off-by: Xinwei Xiong (cubxxw) <3293172751nss@gmail.com>
This commit is contained in:
Xinwei Xiong 2024-03-04 12:12:14 +08:00 committed by GitHub
parent 1ef26b29a7
commit 853ac47e42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
131 changed files with 1133 additions and 881 deletions

View File

@ -92,6 +92,7 @@ jobs:
- name: Exec OpenIM API test - name: Exec OpenIM API test
run: | run: |
sudo make test-api
mkdir -p ./tmp mkdir -p ./tmp
touch ./tmp/test.md touch ./tmp/test.md
echo "# OpenIM Test" >> ./tmp/test.md echo "# OpenIM Test" >> ./tmp/test.md
@ -104,6 +105,7 @@ jobs:
- name: Exec OpenIM E2E Test - name: Exec OpenIM E2E Test
run: | run: |
sudo make test-e2e
echo "" >> ./tmp/test.md echo "" >> ./tmp/test.md
echo "## OpenIM E2E Test" >> ./tmp/test.md echo "## OpenIM E2E Test" >> ./tmp/test.md
echo "<details><summary>Command Output for OpenIM E2E Test</summary>" >> ./tmp/test.md echo "<details><summary>Command Output for OpenIM E2E Test</summary>" >> ./tmp/test.md

View File

@ -74,6 +74,17 @@ jobs:
echo $latest_tag > pkg/common/config/version echo $latest_tag > pkg/common/config/version
continue-on-error: true continue-on-error: true
- name: Gen CHANGELOG file
run: |
current_tag=$(git describe --tags --abbrev=0)
version=$(echo "$current_tag" | sed -E 's/^v?([0-9]+)\.([0-9]+)\..*$/\1.\2/')
echo "OpenIM Version: $version"
make tools.install.git-chglog
cd CHANGELOG
git-chglog --tag-filter-pattern "v${version}.*" -o CHANGELOG-${version}.md
cd ..
continue-on-error: true
- name: Run unit test and get test coverage - name: Run unit test and get test coverage
run: | run: |
make cover make cover

View File

@ -8,6 +8,21 @@
## [Unreleased] ## [Unreleased]
<a name="v3.5.1-alpha.2"></a>
## [v3.5.1-alpha.2] - 2024-01-26
<a name="v3.5.0+15.d356f7a"></a>
## [v3.5.0+15.d356f7a] - 2024-01-26
<a name="v3.5.1-rc.1"></a>
## [v3.5.1-rc.1] - 2024-01-23
<a name="v3.5.0+2.e0bd54f-3-g52f9fc209"></a>
## [v3.5.0+2.e0bd54f-3-g52f9fc209] - 2024-01-12
<a name="v3.5.0+2.e0bd54f-1-g4ce6a0fa6"></a>
## [v3.5.0+2.e0bd54f-1-g4ce6a0fa6] - 2024-01-12
<a name="v3.5.1-alpha.1"></a> <a name="v3.5.1-alpha.1"></a>
## [v3.5.1-alpha.1] - 2024-01-09 ## [v3.5.1-alpha.1] - 2024-01-09
@ -59,7 +74,12 @@
- Merge branch 'tuoyun' - Merge branch 'tuoyun'
[Unreleased]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-alpha.1...HEAD [Unreleased]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-alpha.2...HEAD
[v3.5.1-alpha.2]: https://github.com/openimsdk/open-im-server/compare/v3.5.0+15.d356f7a...v3.5.1-alpha.2
[v3.5.0+15.d356f7a]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-rc.1...v3.5.0+15.d356f7a
[v3.5.1-rc.1]: https://github.com/openimsdk/open-im-server/compare/v3.5.0+2.e0bd54f-3-g52f9fc209...v3.5.1-rc.1
[v3.5.0+2.e0bd54f-3-g52f9fc209]: https://github.com/openimsdk/open-im-server/compare/v3.5.0+2.e0bd54f-1-g4ce6a0fa6...v3.5.0+2.e0bd54f-3-g52f9fc209
[v3.5.0+2.e0bd54f-1-g4ce6a0fa6]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-alpha.1...v3.5.0+2.e0bd54f-1-g4ce6a0fa6
[v3.5.1-alpha.1]: https://github.com/openimsdk/open-im-server/compare/v3.5.0...v3.5.1-alpha.1 [v3.5.1-alpha.1]: https://github.com/openimsdk/open-im-server/compare/v3.5.0...v3.5.1-alpha.1
[v3.5.0]: https://github.com/openimsdk/open-im-server/compare/v3.5.1...v3.5.0 [v3.5.0]: https://github.com/openimsdk/open-im-server/compare/v3.5.1...v3.5.0
[v3.5.1]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-bate.1...v3.5.1 [v3.5.1]: https://github.com/openimsdk/open-im-server/compare/v3.5.1-bate.1...v3.5.1

View File

@ -184,7 +184,7 @@ test-e2e:
imports: imports:
@$(MAKE) go.imports @$(MAKE) go.imports
## clean: Remove all files that are created by building. ✨ ## clean: Delete all files created by the build, as well as all log files. ✨
.PHONY: clean .PHONY: clean
clean: clean:
@$(MAKE) go.clean @$(MAKE) go.clean

View File

@ -67,20 +67,22 @@ func run(port int, proPort int) error {
// Determine whether zk is passed according to whether it is a clustered deployment // Determine whether zk is passed according to whether it is a clustered deployment
client, err = kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) client, err = kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
if err != nil { if err != nil {
return errs.Wrap(err, "register discovery err") return err
} }
if err = client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { if err = client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil {
return errs.Wrap(err, "create rpc root nodes error") return err
} }
if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.Config.EncodeConfig()); err != nil { if err = client.RegisterConf2Registry(constant.OpenIMCommonConfigKey, config.Config.EncodeConfig()); err != nil {
return err return err
} }
var ( var (
netDone = make(chan struct{}, 1) netDone = make(chan struct{}, 1)
netErr error netErr error
) )
router := api.NewGinRouter(client, rdb) router := api.NewGinRouter(client, rdb)
if config.Config.Prometheus.Enable { if config.Config.Prometheus.Enable {
go func() { go func() {
@ -91,7 +93,6 @@ func run(port int, proPort int) error {
netDone <- struct{}{} netDone <- struct{}{}
} }
}() }()
} }
var address string var address string
@ -108,7 +109,6 @@ func run(port int, proPort int) error {
if err != nil && err != http.ErrServerClosed { if err != nil && err != http.ErrServerClosed {
netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr)) netErr = errs.Wrap(err, fmt.Sprintf("api start err: %s", server.Addr))
netDone <- struct{}{} netDone <- struct{}{}
} }
}() }()
@ -122,7 +122,7 @@ func run(port int, proPort int) error {
util.SIGTERMExit() util.SIGTERMExit()
err := server.Shutdown(ctx) err := server.Shutdown(ctx)
if err != nil { if err != nil {
return errs.Wrap(err, "shutdown err") return errs.Wrap(err, "api shutdown err")
} }
case <-netDone: case <-netDone:
close(netDone) close(netDone)

View File

@ -26,7 +26,7 @@ func main() {
pushCmd.AddPortFlag() pushCmd.AddPortFlag()
pushCmd.AddPrometheusPortFlag() pushCmd.AddPrometheusPortFlag()
if err := pushCmd.Exec(); err != nil { if err := pushCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := pushCmd.StartSvr(config.Config.RpcRegisterName.OpenImPushName, push.Start); err != nil { if err := pushCmd.StartSvr(config.Config.RpcRegisterName.OpenImPushName, push.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
authCmd.AddPortFlag() authCmd.AddPortFlag()
authCmd.AddPrometheusPortFlag() authCmd.AddPrometheusPortFlag()
if err := authCmd.Exec(); err != nil { if err := authCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := authCmd.StartSvr(config.Config.RpcRegisterName.OpenImAuthName, auth.Start); err != nil { if err := authCmd.StartSvr(config.Config.RpcRegisterName.OpenImAuthName, auth.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImConversationName, conversation.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImConversationName, conversation.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImFriendName, friend.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImFriendName, friend.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImGroupName, group.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImGroupName, group.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImMsgName, msg.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImMsgName, msg.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImThirdName, third.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImThirdName, third.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -26,7 +26,7 @@ func main() {
rpcCmd.AddPortFlag() rpcCmd.AddPortFlag()
rpcCmd.AddPrometheusPortFlag() rpcCmd.AddPrometheusPortFlag()
if err := rpcCmd.Exec(); err != nil { if err := rpcCmd.Exec(); err != nil {
panic(err.Error()) util.ExitWithError(err)
} }
if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImUserName, user.Start); err != nil { if err := rpcCmd.StartSvr(config.Config.RpcRegisterName.OpenImUserName, user.Start); err != nil {
util.ExitWithError(err) util.ExitWithError(err)

View File

@ -3,7 +3,15 @@
- [Code conventions](#code-conventions) - [Code conventions](#code-conventions)
- [POSIX shell](#posix-shell) - [POSIX shell](#posix-shell)
- [Go](#go) - [Go](#go)
- [Directory and file conventions](#directory-and-file-conventions) - [OpenIM Naming Conventions Guide](#openim-naming-conventions-guide)
- [1. General File Naming](#1-general-file-naming)
- [2. Special File Types](#2-special-file-types)
- [a. Script and Markdown Files](#a-script-and-markdown-files)
- [b. Uppercase Markdown Documentation](#b-uppercase-markdown-documentation)
- [3. Directory Naming](#3-directory-naming)
- [4. Configuration Files](#4-configuration-files)
- [Best Practices](#best-practices)
- [Directory and File Conventions](#directory-and-file-conventions)
- [Testing conventions](#testing-conventions) - [Testing conventions](#testing-conventions)
## POSIX shell ## POSIX shell
@ -67,12 +75,13 @@ Files within the OpenIM project should adhere to the following rules:
+ Stick to lowercase naming where possible for consistency and to prevent issues with case-sensitive systems. + Stick to lowercase naming where possible for consistency and to prevent issues with case-sensitive systems.
+ Include version numbers or dates in file names if the file is subject to updates, following the format: `project-plan-v1.2.md` or `backup-2023-03-15.sql`. + Include version numbers or dates in file names if the file is subject to updates, following the format: `project-plan-v1.2.md` or `backup-2023-03-15.sql`.
## Directory and file conventions ## Directory and File Conventions
- Avoid generic utility packages. Instead of naming a package "util", choose a name that clearly describes its purpose. For instance, functions related to waiting operations are contained within the `wait` package, which includes methods like `Poll`, fully named as `wait.Poll`.
- All filenames, script files, configuration files, and directories should be in lowercase and use dashes (`-`) as separators.
- For Go language files, filenames should be in lowercase and use underscores (`_`).
- Package names should match their directory names to ensure consistency. For example, within the `openim-api` directory, the Go file should be named `openim-api.go`, following the convention of using dashes for directory names and aligning package names with directory names.
- Avoid general utility packages. Packages called "util" are suspect. Instead, derive a name that describes your desired function. For example, the utility functions dealing with waiting for operations are in the `wait` package and include functionality like `Poll`. The full name is `wait.Poll`.
- All filenames should be lowercase.
- All source files and directories should use underscores, not dashes.
- Package directories should generally avoid using separators as much as possible. When package names are multiple words, they usually should be in nested subdirectories.
## Testing conventions ## Testing conventions

View File

@ -20,19 +20,16 @@ import (
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
) )
// RequiredIf validates if the specified field is required based on the session type.
func RequiredIf(fl validator.FieldLevel) bool { func RequiredIf(fl validator.FieldLevel) bool {
sessionType := fl.Parent().FieldByName("SessionType").Int() sessionType := fl.Parent().FieldByName("SessionType").Int()
switch sessionType { switch sessionType {
case constant.SingleChatType, constant.NotificationChatType: case constant.SingleChatType, constant.NotificationChatType:
if fl.FieldName() == "RecvID" { return fl.FieldName() != "RecvID" || fl.Field().String() != ""
return fl.Field().String() != ""
}
case constant.GroupChatType, constant.SuperGroupChatType: case constant.GroupChatType, constant.SuperGroupChatType:
if fl.FieldName() == "GroupID" { return fl.FieldName() != "GroupID" || fl.Field().String() != ""
return fl.Field().String() != ""
}
default: default:
return true return true
} }
return true
} }

View File

@ -210,7 +210,6 @@ func (m *MessageApi) SendMessage(c *gin.Context) {
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg) sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil { if err != nil {
// Log and respond with an error if preparation fails. // Log and respond with an error if preparation fails.
log.ZError(c, "decodeData failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return return
} }
@ -226,7 +225,6 @@ func (m *MessageApi) SendMessage(c *gin.Context) {
if err != nil { if err != nil {
// Set the status to failed and respond with an error if sending fails. // Set the status to failed and respond with an error if sending fails.
status = constant.MsgSendFailed status = constant.MsgSendFailed
log.ZError(c, "send message err", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return return
} }
@ -240,7 +238,8 @@ func (m *MessageApi) SendMessage(c *gin.Context) {
}) })
if err != nil { if err != nil {
// Log the error if updating the status fails. // Log the error if updating the status fails.
log.ZError(c, "SetSendMsgStatus failed", err) apiresp.GinError(c, err)
return
} }
// Respond with a success message and the response payload. // Respond with a success message and the response payload.
@ -299,7 +298,6 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
resp apistruct.BatchSendMsgResp resp apistruct.BatchSendMsgResp
) )
if err := c.BindJSON(&req); err != nil { if err := c.BindJSON(&req); err != nil {
log.ZError(c, "BatchSendMsg BindJSON failed", err)
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap())
return return
} }
@ -310,14 +308,12 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
} }
var recvIDs []string var recvIDs []string
var err error
if req.IsSendAll { if req.IsSendAll {
pageNumber := 1 pageNumber := 1
showNumber := 500 showNumber := 500
for { for {
recvIDsPart, err := m.userRpcClient.GetAllUserIDs(c, int32(pageNumber), int32(showNumber)) recvIDsPart, err := m.userRpcClient.GetAllUserIDs(c, int32(pageNumber), int32(showNumber))
if err != nil { if err != nil {
log.ZError(c, "GetAllUserIDs failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return return
} }
@ -333,7 +329,6 @@ func (m *MessageApi) BatchSendMsg(c *gin.Context) {
log.ZDebug(c, "BatchSendMsg nums", "nums ", len(recvIDs)) log.ZDebug(c, "BatchSendMsg nums", "nums ", len(recvIDs))
sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg) sendMsgReq, err := m.getSendMsgReq(c, req.SendMsg)
if err != nil { if err != nil {
log.ZError(c, "decodeData failed", err)
apiresp.GinError(c, err) apiresp.GinError(c, err)
return return
} }

View File

@ -44,7 +44,7 @@ import (
) )
func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient) *gin.Engine { func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.UniversalClient) *gin.Engine {
discov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) // 默认RPC中间件 discov.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) // Default RPC middleware
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
r := gin.New() r := gin.New()
if v, ok := binding.Validator.Engine().(*validator.Validate); ok { if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
@ -225,6 +225,7 @@ func NewGinRouter(discov discoveryregistry.SvcDiscoveryRegistry, rdb redis.Unive
return r return r
} }
// GinParseToken is a middleware that parses the token in the request header and verifies it.
func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc { func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
dataBase := controller.NewAuthDatabase( dataBase := controller.NewAuthDatabase(
cache.NewMsgCacheModel(rdb), cache.NewMsgCacheModel(rdb),
@ -250,13 +251,11 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
} }
m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID) m, err := dataBase.GetTokensWithoutError(c, claims.UserID, claims.PlatformID)
if err != nil { if err != nil {
log.ZWarn(c, "cache get token error", errs.ErrTokenNotExist.Wrap())
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
c.Abort() c.Abort()
return return
} }
if len(m) == 0 { if len(m) == 0 {
log.ZWarn(c, "cache do not exist token error", errs.ErrTokenNotExist.Wrap())
apiresp.GinError(c, errs.ErrTokenNotExist.Wrap()) apiresp.GinError(c, errs.ErrTokenNotExist.Wrap())
c.Abort() c.Abort()
return return
@ -265,12 +264,10 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
switch v { switch v {
case constant.NormalToken: case constant.NormalToken:
case constant.KickedToken: case constant.KickedToken:
log.ZWarn(c, "cache kicked token error", errs.ErrTokenKicked.Wrap())
apiresp.GinError(c, errs.ErrTokenKicked.Wrap()) apiresp.GinError(c, errs.ErrTokenKicked.Wrap())
c.Abort() c.Abort()
return return
default: default:
log.ZWarn(c, "cache unknown token error", errs.ErrTokenUnknown.Wrap())
apiresp.GinError(c, errs.ErrTokenUnknown.Wrap()) apiresp.GinError(c, errs.ErrTokenUnknown.Wrap())
c.Abort() c.Abort()
return return
@ -286,3 +283,10 @@ func GinParseToken(rdb redis.UniversalClient) gin.HandlerFunc {
} }
} }
} }
// // handleGinError logs and returns an error response through Gin context.
// func handleGinError(c *gin.Context, logMessage string, errType errs.CodeError, detail string) {
// wrappedErr := errType.Wrap(detail)
// apiresp.GinError(c, wrappedErr)
// c.Abort()
// }

View File

@ -68,7 +68,7 @@ func (u *UserApi) GetUsers(c *gin.Context) {
func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) { func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) {
var req msggateway.GetUsersOnlineStatusReq var req msggateway.GetUsersOnlineStatusReq
if err := c.BindJSON(&req); err != nil { if err := c.BindJSON(&req); err != nil {
apiresp.GinError(c, errs.ErrArgs.WithDetail(err.Error()).Wrap()) apiresp.GinError(c, err)
return return
} }
conns, err := u.Discov.GetConns(c, config.Config.RpcRegisterName.OpenImMessageGatewayName) conns, err := u.Discov.GetConns(c, config.Config.RpcRegisterName.OpenImMessageGatewayName)
@ -86,7 +86,7 @@ func (u *UserApi) GetUsersOnlineStatus(c *gin.Context) {
msgClient := msggateway.NewMsgGatewayClient(v) msgClient := msggateway.NewMsgGatewayClient(v)
reply, err := msgClient.GetUsersOnlineStatus(c, &req) reply, err := msgClient.GetUsersOnlineStatus(c, &req)
if err != nil { if err != nil {
log.ZWarn(c, "GetUsersOnlineStatus rpc err", err) log.ZDebug(c, "GetUsersOnlineStatus rpc error", err)
parseError := apiresp.ParseError(err) parseError := apiresp.ParseError(err)
if parseError.ErrCode == errs.NoPermissionError { if parseError.ErrCode == errs.NoPermissionError {

View File

@ -91,13 +91,7 @@ type Client struct {
// } // }
// ResetClient updates the client's state with new connection and context information. // ResetClient updates the client's state with new connection and context information.
func (c *Client) ResetClient( func (c *Client) ResetClient(ctx *UserConnContext, conn LongConn, isBackground, isCompress bool, longConnServer LongConnServer, token string) {
ctx *UserConnContext,
conn LongConn,
isBackground, isCompress bool,
longConnServer LongConnServer,
token string,
) {
c.w = new(sync.Mutex) c.w = new(sync.Mutex)
c.conn = conn c.conn = conn
c.PlatformID = utils.StringToInt(ctx.GetPlatformID()) c.PlatformID = utils.StringToInt(ctx.GetPlatformID())
@ -112,9 +106,11 @@ func (c *Client) ResetClient(
c.token = token c.token = token
} }
// pingHandler handles ping messages and sends pong responses.
func (c *Client) pingHandler(_ string) error { func (c *Client) pingHandler(_ string) error {
_ = c.conn.SetReadDeadline(pongWait) if err := c.conn.SetReadDeadline(pongWait); err != nil {
return err
}
return c.writePongMsg() return c.writePongMsg()
} }
@ -141,7 +137,8 @@ func (c *Client) readMessage() {
} }
log.ZDebug(c.ctx, "readMessage", "messageType", messageType) log.ZDebug(c.ctx, "readMessage", "messageType", messageType)
if c.closed.Load() { // 连接刚置位已经关闭,但是协程还没退出的场景 if c.closed.Load() {
// The scenario where the connection has just been closed, but the coroutine has not exited
c.closedErr = ErrConnClosed c.closedErr = ErrConnClosed
return return
} }
@ -185,11 +182,11 @@ func (c *Client) handleMessage(message []byte) error {
err := c.longConnServer.Decode(message, binaryReq) err := c.longConnServer.Decode(message, binaryReq)
if err != nil { if err != nil {
return errs.Wrap(err) return err
} }
if err := c.longConnServer.Validate(binaryReq); err != nil { if err := c.longConnServer.Validate(binaryReq); err != nil {
return errs.Wrap(err) return err
} }
if binaryReq.SendID != c.UserID { if binaryReq.SendID != c.UserID {
@ -239,7 +236,7 @@ func (c *Client) setAppBackgroundStatus(ctx context.Context, req *Req) ([]byte,
} }
c.IsBackground = isBackground c.IsBackground = isBackground
// todo callback // TODO: callback
return resp, nil return resp, nil
} }
@ -273,7 +270,7 @@ func (c *Client) replyMessage(ctx context.Context, binaryReq *Req, err error, re
} }
if binaryReq.ReqIdentifier == WsLogoutMsg { if binaryReq.ReqIdentifier == WsLogoutMsg {
return errors.New("user logout") return errs.Wrap(errors.New("user logout"))
} }
return nil return nil
} }
@ -316,17 +313,21 @@ func (c *Client) writeBinaryMsg(resp Resp) error {
encodedBuf, err := c.longConnServer.Encode(resp) encodedBuf, err := c.longConnServer.Encode(resp)
if err != nil { if err != nil {
return errs.Wrap(err) return err
} }
c.w.Lock() c.w.Lock()
defer c.w.Unlock() defer c.w.Unlock()
_ = c.conn.SetWriteDeadline(writeWait) err = c.conn.SetWriteDeadline(writeWait)
if err != nil {
return err
}
if c.IsCompress { if c.IsCompress {
resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf) resultBuf, compressErr := c.longConnServer.CompressWithPool(encodedBuf)
if compressErr != nil { if compressErr != nil {
return errs.Wrap(compressErr) return compressErr
} }
return c.conn.WriteMessage(MessageBinary, resultBuf) return c.conn.WriteMessage(MessageBinary, resultBuf)
} }
@ -344,7 +345,7 @@ func (c *Client) writePongMsg() error {
err := c.conn.SetWriteDeadline(writeWait) err := c.conn.SetWriteDeadline(writeWait)
if err != nil { if err != nil {
return errs.Wrap(err) return err
} }
return c.conn.WriteMessage(PongMessage, nil) return c.conn.WriteMessage(PongMessage, nil)

View File

@ -17,7 +17,6 @@ package msggateway
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"errors"
"io" "io"
"sync" "sync"
@ -46,12 +45,15 @@ func NewGzipCompressor() *GzipCompressor {
func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) { func (g *GzipCompressor) Compress(rawData []byte) ([]byte, error) {
gzipBuffer := bytes.Buffer{} gzipBuffer := bytes.Buffer{}
gz := gzip.NewWriter(&gzipBuffer) gz := gzip.NewWriter(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil { if _, err := gz.Write(rawData); err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, "GzipCompressor.Compress: writing to gzip writer failed")
} }
if err := gz.Close(); err != nil { if err := gz.Close(); err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, "GzipCompressor.Compress: closing gzip writer failed")
} }
return gzipBuffer.Bytes(), nil return gzipBuffer.Bytes(), nil
} }
@ -63,10 +65,10 @@ func (g *GzipCompressor) CompressWithPool(rawData []byte) ([]byte, error) {
gz.Reset(&gzipBuffer) gz.Reset(&gzipBuffer)
if _, err := gz.Write(rawData); err != nil { if _, err := gz.Write(rawData); err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, "GzipCompressor.CompressWithPool: error writing data")
} }
if err := gz.Close(); err != nil { if err := gz.Close(); err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, "GzipCompressor.CompressWithPool: error closing gzip writer")
} }
return gzipBuffer.Bytes(), nil return gzipBuffer.Bytes(), nil
} }
@ -75,32 +77,36 @@ func (g *GzipCompressor) DeCompress(compressedData []byte) ([]byte, error) {
buff := bytes.NewBuffer(compressedData) buff := bytes.NewBuffer(compressedData)
reader, err := gzip.NewReader(buff) reader, err := gzip.NewReader(buff)
if err != nil { if err != nil {
return nil, errs.Wrap(err, "NewReader failed") return nil, errs.Wrap(err, "GzipCompressor.DeCompress: NewReader creation failed")
} }
compressedData, err = io.ReadAll(reader) decompressedData, err := io.ReadAll(reader)
if err != nil { if err != nil {
return nil, errs.Wrap(err, "ReadAll failed") return nil, errs.Wrap(err, "GzipCompressor.DeCompress: reading from gzip reader failed")
} }
_ = reader.Close() if err = reader.Close(); err != nil {
return compressedData, nil // Even if closing the reader fails, we've successfully read the data,
// so we return the decompressed data and an error indicating the close failure.
return decompressedData, errs.Wrap(err, "GzipCompressor.DeCompress: closing gzip reader failed")
}
return decompressedData, nil
} }
func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) { func (g *GzipCompressor) DecompressWithPool(compressedData []byte) ([]byte, error) {
reader := gzipReaderPool.Get().(*gzip.Reader) reader := gzipReaderPool.Get().(*gzip.Reader)
if reader == nil {
return nil, errs.Wrap(errors.New("NewReader failed"))
}
defer gzipReaderPool.Put(reader) defer gzipReaderPool.Put(reader)
err := reader.Reset(bytes.NewReader(compressedData)) err := reader.Reset(bytes.NewReader(compressedData))
if err != nil { if err != nil {
return nil, errs.Wrap(err, "NewReader failed") return nil, errs.Wrap(err, "GzipCompressor.DecompressWithPool: resetting gzip reader failed")
} }
compressedData, err = io.ReadAll(reader) decompressedData, err := io.ReadAll(reader)
if err != nil { if err != nil {
return nil, errs.Wrap(err, "ReadAll failed") return nil, errs.Wrap(err, "GzipCompressor.DecompressWithPool: reading from pooled gzip reader failed")
} }
_ = reader.Close() if err = reader.Close(); err != nil {
return compressedData, nil // Similar to DeCompress, return the data and error for close failure.
return decompressedData, errs.Wrap(err, "GzipCompressor.DecompressWithPool: closing pooled gzip reader failed")
}
return decompressedData, nil
} }

View File

@ -37,10 +37,16 @@ func TestCompressDecompress(t *testing.T) {
// compress // compress
dest, err := compressor.CompressWithPool(src) dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// decompress // decompress
res, err := compressor.DecompressWithPool(dest) res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// check // check
@ -60,10 +66,16 @@ func TestCompressDecompressWithConcurrency(t *testing.T) {
// compress // compress
dest, err := compressor.CompressWithPool(src) dest, err := compressor.CompressWithPool(src)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// decompress // decompress
res, err := compressor.DecompressWithPool(dest) res, err := compressor.DecompressWithPool(dest)
if err != nil {
t.Log(err)
}
assert.Equal(t, nil, err) assert.Equal(t, nil, err)
// check // check
@ -99,6 +111,7 @@ func BenchmarkDecompress(b *testing.B) {
compressor := NewGzipCompressor() compressor := NewGzipCompressor()
comdata, err := compressor.Compress(src) comdata, err := compressor.Compress(src)
assert.Equal(b, nil, err) assert.Equal(b, nil, err)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

View File

@ -37,7 +37,7 @@ func (g *GobEncoder) Encode(data any) ([]byte, error) {
enc := gob.NewEncoder(&buff) enc := gob.NewEncoder(&buff)
err := enc.Encode(data) err := enc.Encode(data)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "GobEncoder.Encode failed")
} }
return buff.Bytes(), nil return buff.Bytes(), nil
} }
@ -47,7 +47,7 @@ func (g *GobEncoder) Decode(encodeData []byte, decodeData any) error {
dec := gob.NewDecoder(buff) dec := gob.NewDecoder(buff)
err := dec.Decode(decodeData) err := dec.Decode(decodeData)
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err, "GobEncoder.Decode failed")
} }
return nil return nil
} }

View File

@ -23,14 +23,7 @@ import (
// RunWsAndServer run ws server. // RunWsAndServer run ws server.
func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error { func RunWsAndServer(rpcPort, wsPort, prometheusPort int) error {
fmt.Println( fmt.Println("start rpc/msg_gateway server, port: ", rpcPort, wsPort, prometheusPort, ", OpenIM version: ", config.Version)
"start rpc/msg_gateway server, port: ",
rpcPort,
wsPort,
prometheusPort,
", OpenIM version: ",
config.Version,
)
longServer, err := NewWsServer( longServer, err := NewWsServer(
WithPort(wsPort), WithPort(wsPort),
WithMaxConnNum(int64(config.Config.LongConnSvr.WebsocketMaxConnNum)), WithMaxConnNum(int64(config.Config.LongConnSvr.WebsocketMaxConnNum)),

View File

@ -15,9 +15,11 @@
package msggateway package msggateway
import ( import (
"errors"
"net/http" "net/http"
"time" "time"
"github.com/OpenIMSDK/tools/errs"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@ -72,7 +74,8 @@ func (d *GWebSocket) GenerateLongConn(w http.ResponseWriter, r *http.Request) er
conn, err := upgrader.Upgrade(w, r, nil) conn, err := upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
return err // The upgrader.Upgrade method usually returns enough error messages to diagnose problems that may occur during the upgrade
return errs.Wrap(err, "GenerateLongConn: WebSocket upgrade failed")
} }
d.conn = conn d.conn = conn
return nil return nil
@ -96,7 +99,16 @@ func (d *GWebSocket) SetReadDeadline(timeout time.Duration) error {
} }
func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error { func (d *GWebSocket) SetWriteDeadline(timeout time.Duration) error {
return d.conn.SetWriteDeadline(time.Now().Add(timeout)) // TODO add error
if timeout <= 0 {
return errs.Wrap(errors.New("timeout must be greater than 0"))
}
// TODO SetWriteDeadline Future add error handling
if err := d.conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil {
return errs.Wrap(err, "GWebSocket.SetWriteDeadline failed")
}
return nil
} }
func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) { func (d *GWebSocket) Dial(urlStr string, requestHeader http.Header) (*http.Response, error) {

View File

@ -20,6 +20,7 @@ import (
"github.com/OpenIMSDK/protocol/push" "github.com/OpenIMSDK/protocol/push"
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
@ -119,10 +120,10 @@ func NewGrpcHandler(validate *validator.Validate, client discoveryregistry.SvcDi
func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) { func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) {
req := sdkws.GetMaxSeqReq{} req := sdkws.GetMaxSeqReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err return nil, errs.Wrap(err, "GetSeq: error unmarshaling request")
} }
if err := g.validate.Struct(&req); err != nil { if err := g.validate.Struct(&req); err != nil {
return nil, err return nil, errs.Wrap(err, "GetSeq: validation failed")
} }
resp, err := g.msgRpcClient.GetMaxSeq(context, &req) resp, err := g.msgRpcClient.GetMaxSeq(context, &req)
if err != nil { if err != nil {
@ -130,28 +131,37 @@ func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error)
} }
c, err := proto.Marshal(resp) c, err := proto.Marshal(resp)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "GetSeq: error marshaling response")
} }
return c, nil return c, nil
} }
func (g GrpcHandler) SendMessage(context context.Context, data *Req) ([]byte, error) { // SendMessage handles the sending of messages through gRPC. It unmarshals the request data,
msgData := sdkws.MsgData{} // validates the message, and then sends it using the message RPC client.
func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) {
// Unmarshal the message data from the request.
var msgData sdkws.MsgData
if err := proto.Unmarshal(data.Data, &msgData); err != nil { if err := proto.Unmarshal(data.Data, &msgData); err != nil {
return nil, err return nil, errs.Wrap(err, "error unmarshalling message data")
} }
// Validate the message data structure.
if err := g.validate.Struct(&msgData); err != nil { if err := g.validate.Struct(&msgData); err != nil {
return nil, err return nil, errs.Wrap(err, "message data validation failed")
} }
req := msg.SendMsgReq{MsgData: &msgData} req := msg.SendMsgReq{MsgData: &msgData}
resp, err := g.msgRpcClient.SendMsg(context, &req)
resp, err := g.msgRpcClient.SendMsg(ctx, &req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
c, err := proto.Marshal(resp) c, err := proto.Marshal(resp)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "error marshaling response")
} }
return c, nil return c, nil
} }
@ -162,7 +172,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by
} }
c, err := proto.Marshal(resp) c, err := proto.Marshal(resp)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "error marshaling response")
} }
return c, nil return c, nil
} }
@ -170,7 +180,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by
func (g GrpcHandler) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) { func (g GrpcHandler) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) {
req := sdkws.PullMessageBySeqsReq{} req := sdkws.PullMessageBySeqsReq{}
if err := proto.Unmarshal(data.Data, &req); err != nil { if err := proto.Unmarshal(data.Data, &req); err != nil {
return nil, err return nil, errs.Wrap(err, "error unmarshaling request")
} }
if err := g.validate.Struct(data); err != nil { if err := g.validate.Struct(data); err != nil {
return nil, err return nil, err

View File

@ -88,6 +88,7 @@ type WsServer struct {
Encoder Encoder
MessageHandler MessageHandler
} }
type kickHandler struct { type kickHandler struct {
clientOK bool clientOK bool
oldClients []*Client oldClients []*Client
@ -129,7 +130,9 @@ func (ws *WsServer) UnRegister(c *Client) {
} }
func (ws *WsServer) Validate(s any) error { func (ws *WsServer) Validate(s any) error {
//?question? if s == nil {
return errs.Wrap(errors.New("input cannot be nil"))
}
return nil return nil
} }
@ -276,7 +279,7 @@ func (ws *WsServer) registerClient(client *Client) {
log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID) log.ZDebug(client.ctx, "user exist", "userID", client.UserID, "platformID", client.PlatformID)
if clientOK { if clientOK {
ws.clients.Set(client.UserID, client) ws.clients.Set(client.UserID, client)
// 已经有同平台的连接存在 // There is already a connection to the platform
log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients)) log.ZInfo(client.ctx, "repeat login", "userID", client.UserID, "platformID", client.PlatformID, "old remote addr", getRemoteAdders(oldClients))
ws.onlineUserConnNum.Add(1) ws.onlineUserConnNum.Add(1)
} else { } else {

View File

@ -19,15 +19,15 @@ import "time"
type ( type (
Option func(opt *configs) Option func(opt *configs)
configs struct { configs struct {
// 长连接监听端口 // Long connection listening port
port int port int
// 长连接允许最大链接数 // Maximum number of connections allowed for long connection
maxConnNum int64 maxConnNum int64
// 连接握手超时时间 // Connection handshake timeout
handshakeTimeout time.Duration handshakeTimeout time.Duration
// 允许消息最大长度 // Maximum length allowed for messages
messageMaxMsgLength int messageMaxMsgLength int
// websocket write buffer, default: 4096, 4kb. // Websocket write buffer, default: 4096, 4kb.
writeBufferSize int writeBufferSize int
} }
) )

View File

@ -45,8 +45,13 @@ import (
) )
type MsgTransfer struct { type MsgTransfer struct {
historyCH *OnlineHistoryRedisConsumerHandler // 这个消费者聚合消息, 订阅的topicws2ms_chat, 修改通知发往msg_to_modify topic, 消息存入redis后Incr Redis, 再发消息到ms2pschat topic推送 发消息到msg_to_mongo topic持久化 // This consumer aggregated messages, subscribed to the topic:ws2ms_chat,
historyMongoCH *OnlineHistoryMongoConsumerHandler // mongoDB批量插入, 成功后删除redis中消息以及处理删除通知消息删除的 订阅的topic: msg_to_mongo // the modification notification is sent to msg_to_modify topic, the message is stored in redis, Incr Redis,
// and then the message is sent to ms2pschat topic for push, and the message is sent to msg_to_mongo topic for persistence
historyCH *OnlineHistoryRedisConsumerHandler
// mongoDB batch insert, delete messages in redis after success,
// and handle the deletion notification message deleted subscriptions topic: msg_to_mongo
historyMongoCH *OnlineHistoryMongoConsumerHandler
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
} }
@ -65,6 +70,7 @@ func StartTransfer(prometheusPort int) error {
if err = mongo.CreateMsgIndex(); err != nil { if err = mongo.CreateMsgIndex(); err != nil {
return err return err
} }
client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
if err != nil { if err != nil {
return err return err
@ -73,6 +79,7 @@ func StartTransfer(prometheusPort int) error {
if err := client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil { if err := client.CreateRpcRootNodes(config.Config.GetServiceNames()); err != nil {
return err return err
} }
client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin"))) client.AddOption(mw.GrpcClient(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"LoadBalancingPolicy": "%s"}`, "round_robin")))
msgModel := cache.NewMsgCacheModel(rdb) msgModel := cache.NewMsgCacheModel(rdb)
msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase()) msgDocModel := unrelation.NewMsgMongoDriver(mongo.GetDatabase())
@ -106,7 +113,7 @@ func NewMsgTransfer(msgDatabase controller.CommonMsgDatabase, conversationRpcCli
} }
func (m *MsgTransfer) Start(prometheusPort int) error { func (m *MsgTransfer) Start(prometheusPort int) error {
fmt.Println("start msg transfer", "prometheusPort:", prometheusPort) fmt.Println("Start msg transfer", "prometheusPort:", prometheusPort)
if prometheusPort <= 0 { if prometheusPort <= 0 {
return errs.Wrap(errors.New("prometheusPort not correct")) return errs.Wrap(errors.New("prometheusPort not correct"))
} }

View File

@ -155,21 +155,13 @@ func (och *OnlineHistoryRedisConsumerHandler) Run(channelID int) {
notStorageNotificationList, notStorageNotificationList,
) )
if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil { if err := och.msgDatabase.MsgToModifyMQ(ctx, msgChannelValue.uniqueKey, conversationIDNotification, modifyMsgList); err != nil {
log.ZError( log.ZError(ctx, "msg to modify mq error", err, "uniqueKey", msgChannelValue.uniqueKey, "modifyMsgList", modifyMsgList)
ctx,
"msg to modify mq error",
err,
"uniqueKey",
msgChannelValue.uniqueKey,
"modifyMsgList",
modifyMsgList,
)
} }
} }
} }
} }
// 获取消息/通知 存储的消息列表, 不存储并且推送的消息列表,. // Get messages/notifications stored message list, not stored and pushed message list.
func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList( func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList(
totalMsgs []*ContextMsg, totalMsgs []*ContextMsg,
) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) { ) (storageMsgList, notStorageMsgList, storageNotificatoinList, notStorageNotificationList, modifyMsgList []*sdkws.MsgData) {
@ -190,7 +182,7 @@ func (och *OnlineHistoryRedisConsumerHandler) getPushStorageMsgList(
// clone msg from notificationMsg // clone msg from notificationMsg
if options.IsSendMsg() { if options.IsSendMsg() {
msg := proto.Clone(v.message).(*sdkws.MsgData) msg := proto.Clone(v.message).(*sdkws.MsgData)
// 消息 // message
if v.message.Options != nil { if v.message.Options != nil {
msg.Options = msgprocessor.NewMsgOptions() msg.Options = msgprocessor.NewMsgOptions()
} }

View File

@ -31,12 +31,7 @@ func url() string {
return config.Config.Callback.CallbackUrl return config.Config.Callback.CallbackUrl
} }
func callbackOfflinePush( func callbackOfflinePush(ctx context.Context, userIDs []string, msg *sdkws.MsgData, offlinePushUserIDs *[]string) error {
ctx context.Context,
userIDs []string,
msg *sdkws.MsgData,
offlinePushUserIDs *[]string,
) error {
if !config.Config.Callback.CallbackOfflinePush.Enable || msg.ContentType == constant.Typing { if !config.Config.Callback.CallbackOfflinePush.Enable || msg.ContentType == constant.Typing {
return nil return nil
} }
@ -59,10 +54,12 @@ func callbackOfflinePush(
AtUserIDs: msg.AtUserIDList, AtUserIDs: msg.AtUserIDList,
Content: GetContent(msg), Content: GetContent(msg),
} }
resp := &callbackstruct.CallbackBeforePushResp{} resp := &callbackstruct.CallbackBeforePushResp{}
if err := http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackOfflinePush); err != nil { if err := http.CallBackPostReturn(ctx, url(), req, resp, config.Config.Callback.CallbackOfflinePush); err != nil {
return err return err
} }
if len(resp.UserIDs) != 0 { if len(resp.UserIDs) != 0 {
*offlinePushUserIDs = resp.UserIDs *offlinePushUserIDs = resp.UserIDs
} }

View File

@ -39,20 +39,22 @@ type Fcm struct {
cache cache.MsgModel cache cache.MsgModel
} }
// NewClient initializes a new FCM client using the Firebase Admin SDK.
// It requires the FCM service account credentials file located within the project's configuration directory.
func NewClient(cache cache.MsgModel) *Fcm { func NewClient(cache cache.MsgModel) *Fcm {
projectRoot := config.GetProjectRoot() projectRoot, _ := config.GetProjectRoot()
credentialsFilePath := filepath.Join(projectRoot, "config", config.Config.Push.Fcm.ServiceAccount) credentialsFilePath := filepath.Join(projectRoot, "config", config.Config.Push.Fcm.ServiceAccount)
opt := option.WithCredentialsFile(credentialsFilePath) opt := option.WithCredentialsFile(credentialsFilePath)
fcmApp, err := firebase.NewApp(context.Background(), nil, opt) fcmApp, err := firebase.NewApp(context.Background(), nil, opt)
if err != nil { if err != nil {
return nil return nil
} }
ctx := context.Background() ctx := context.Background()
fcmMsgClient, err := fcmApp.Messaging(ctx) fcmMsgClient, err := fcmApp.Messaging(ctx)
if err != nil { if err != nil {
return nil return nil
} }
return &Fcm{fcmMsgCli: fcmMsgClient, cache: cache} return &Fcm{fcmMsgCli: fcmMsgClient, cache: cache}
} }

View File

@ -229,7 +229,8 @@ func (p *Pusher) Push2SuperGroup(ctx context.Context, groupID string, msg *sdkws
}(groupID, kickedUsers) }(groupID, kickedUsers)
pushToUserIDs = append(pushToUserIDs, kickedUsers...) pushToUserIDs = append(pushToUserIDs, kickedUsers...)
case constant.GroupDismissedNotification: case constant.GroupDismissedNotification:
if msgprocessor.IsNotification(msgprocessor.GetConversationIDByMsg(msg)) { // 消息先到,通知后到 // Messages arrive first, notifications arrive later
if msgprocessor.IsNotification(msgprocessor.GetConversationIDByMsg(msg)) {
var tips sdkws.GroupDismissedTips var tips sdkws.GroupDismissedTips
if p.UnmarshalNotificationElem(msg.Content, &tips) != nil { if p.UnmarshalNotificationElem(msg.Content, &tips) != nil {
return err return err

View File

@ -310,7 +310,7 @@ func (c *conversationServer) SetConversations(ctx context.Context,
unequal++ unequal++
} }
} }
if err := c.conversationDatabase.SetUsersConversationFiledTx(ctx, req.UserIDs, &conversation, m); err != nil { if err := c.conversationDatabase.SetUsersConversationFieldTx(ctx, req.UserIDs, &conversation, m); err != nil {
return nil, err return nil, err
} }
if unequal > 0 { if unequal > 0 {
@ -321,7 +321,7 @@ func (c *conversationServer) SetConversations(ctx context.Context,
return &pbconversation.SetConversationsResp{}, nil return &pbconversation.SetConversationsResp{}, nil
} }
// 获取超级大群开启免打扰的用户ID. // Get user IDs with "Do Not Disturb" enabled in super large groups.
func (c *conversationServer) GetRecvMsgNotNotifyUserIDs(ctx context.Context, req *pbconversation.GetRecvMsgNotNotifyUserIDsReq) (*pbconversation.GetRecvMsgNotNotifyUserIDsResp, error) { func (c *conversationServer) GetRecvMsgNotNotifyUserIDs(ctx context.Context, req *pbconversation.GetRecvMsgNotNotifyUserIDsReq) (*pbconversation.GetRecvMsgNotNotifyUserIDsResp, error) {
//userIDs, err := c.conversationDatabase.FindRecvMsgNotNotifyUserIDs(ctx, req.GroupID) //userIDs, err := c.conversationDatabase.FindRecvMsgNotNotifyUserIDs(ctx, req.GroupID)
//if err != nil { //if err != nil {
@ -378,7 +378,7 @@ func (c *conversationServer) CreateGroupChatConversations(ctx context.Context, r
} }
func (c *conversationServer) SetConversationMaxSeq(ctx context.Context, req *pbconversation.SetConversationMaxSeqReq) (*pbconversation.SetConversationMaxSeqResp, error) { func (c *conversationServer) SetConversationMaxSeq(ctx context.Context, req *pbconversation.SetConversationMaxSeqReq) (*pbconversation.SetConversationMaxSeqResp, error) {
if err := c.conversationDatabase.UpdateUsersConversationFiled(ctx, req.OwnerUserID, req.ConversationID, if err := c.conversationDatabase.UpdateUsersConversationField(ctx, req.OwnerUserID, req.ConversationID,
map[string]any{"max_seq": req.MaxSeq}); err != nil { map[string]any{"max_seq": req.MaxSeq}); err != nil {
return nil, err return nil, err
} }

View File

@ -278,6 +278,7 @@ func (s *friendServer) GetDesignatedFriends(ctx context.Context, req *pbfriend.G
return resp, nil return resp, nil
} }
// Get the list of friend requests sent out proactively.
func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context, func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context,
req *pbfriend.GetDesignatedFriendsApplyReq) (resp *pbfriend.GetDesignatedFriendsApplyResp, err error) { req *pbfriend.GetDesignatedFriendsApplyReq) (resp *pbfriend.GetDesignatedFriendsApplyResp, err error) {
friendRequests, err := s.friendDatabase.FindBothFriendRequests(ctx, req.FromUserID, req.ToUserID) friendRequests, err := s.friendDatabase.FindBothFriendRequests(ctx, req.FromUserID, req.ToUserID)
@ -292,7 +293,7 @@ func (s *friendServer) GetDesignatedFriendsApply(ctx context.Context,
return resp, nil return resp, nil
} }
// ok 获取接收到的好友申请(即别人主动申请的). // Get received friend requests (i.e., those initiated by others).
func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *pbfriend.GetPaginationFriendsApplyToReq) (resp *pbfriend.GetPaginationFriendsApplyToResp, err error) { func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *pbfriend.GetPaginationFriendsApplyToReq) (resp *pbfriend.GetPaginationFriendsApplyToResp, err error) {
defer log.ZInfo(ctx, utils.GetFuncName()+" Return") defer log.ZInfo(ctx, utils.GetFuncName()+" Return")
if err := s.userRpcClient.Access(ctx, req.UserID); err != nil { if err := s.userRpcClient.Access(ctx, req.UserID); err != nil {
@ -311,7 +312,6 @@ func (s *friendServer) GetPaginationFriendsApplyTo(ctx context.Context, req *pbf
return resp, nil return resp, nil
} }
// ok 获取主动发出去的好友申请列表.
func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *pbfriend.GetPaginationFriendsApplyFromReq) (resp *pbfriend.GetPaginationFriendsApplyFromResp, err error) { func (s *friendServer) GetPaginationFriendsApplyFrom(ctx context.Context, req *pbfriend.GetPaginationFriendsApplyFromReq) (resp *pbfriend.GetPaginationFriendsApplyFromResp, err error) {
defer log.ZInfo(ctx, utils.GetFuncName()+" Return") defer log.ZInfo(ctx, utils.GetFuncName()+" Return")
resp = &pbfriend.GetPaginationFriendsApplyFromResp{} resp = &pbfriend.GetPaginationFriendsApplyFromResp{}

View File

@ -765,8 +765,8 @@ func (s *groupServer) GroupApplicationResponse(ctx context.Context, req *pbgroup
return nil, errs.ErrGroupRequestHandled.Wrap("group request already processed") return nil, errs.ErrGroupRequestHandled.Wrap("group request already processed")
} }
var inGroup bool var inGroup bool
if _, err := s.db.TakeGroupMember(ctx, req.GroupID, req.FromUserID); err == nil { if _, takeErr := s.db.TakeGroupMember(ctx, req.GroupID, req.FromUserID); takeErr == nil {
inGroup = true // 已经在群里了 inGroup = true // Already in group
} else if !s.IsNotFound(err) { } else if !s.IsNotFound(err) {
return nil, err return nil, err
} }

View File

@ -67,7 +67,7 @@ func Start(client discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) e
if err != nil { if err != nil {
return err return err
} }
// 根据配置文件策略选择 oss 方式 // Select based on the configuration file strategy
enable := config.Config.Object.Enable enable := config.Config.Object.Enable
var o s3.Interface var o s3.Interface
switch config.Config.Object.Enable { switch config.Config.Object.Enable {

View File

@ -58,9 +58,9 @@ import (
// continue // continue
// } // }
// if len(seqs) > 0 { // if len(seqs) > 0 {
// if err := c.conversationDatabase.UpdateUsersConversationFiled(ctx, []string{conversation.OwnerUserID}, conversation.ConversationID, map[string]interface{}{"latest_msg_destruct_time": now}); err // if err := c.conversationDatabase.UpdateUsersConversationField(ctx, []string{conversation.OwnerUserID}, conversation.ConversationID, map[string]interface{}{"latest_msg_destruct_time": now}); err
// != nil { // != nil {
// log.ZError(ctx, "updateUsersConversationFiled failed", err, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID) // log.ZError(ctx, "updateUsersConversationField failed", err, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID)
// continue // continue
// } // }
// if err := c.msgNotificationSender.UserDeleteMsgsNotification(ctx, conversation.OwnerUserID, conversation.ConversationID, seqs); err != nil { // if err := c.msgNotificationSender.UserDeleteMsgsNotification(ctx, conversation.OwnerUserID, conversation.ConversationID, seqs); err != nil {
@ -139,8 +139,8 @@ func (c *MsgTool) ConversationsDestructMsgs() {
continue continue
} }
if len(seqs) > 0 { if len(seqs) > 0 {
if err := c.conversationDatabase.UpdateUsersConversationFiled(ctx, []string{conversation.OwnerUserID}, conversation.ConversationID, map[string]any{"latest_msg_destruct_time": now}); err != nil { if err := c.conversationDatabase.UpdateUsersConversationField(ctx, []string{conversation.OwnerUserID}, conversation.ConversationID, map[string]any{"latest_msg_destruct_time": now}); err != nil {
log.ZError(ctx, "updateUsersConversationFiled failed", err, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID) log.ZError(ctx, "updateUsersConversationField failed", err, "conversationID", conversation.ConversationID, "ownerUserID", conversation.OwnerUserID)
continue continue
} }
if err := c.msgNotificationSender.UserDeleteMsgsNotification(ctx, conversation.OwnerUserID, conversation.ConversationID, seqs); err != nil { if err := c.msgNotificationSender.UserDeleteMsgsNotification(ctx, conversation.OwnerUserID, conversation.ConversationID, seqs); err != nil {

View File

@ -32,7 +32,7 @@ import (
) )
func StartTask() error { func StartTask() error {
fmt.Println("cron task start, config", config.Config.ChatRecordsClearTime) fmt.Println("Cron task start, config:", config.Config.ChatRecordsClearTime)
msgTool, err := InitMsgTool() msgTool, err := InitMsgTool()
if err != nil { if err != nil {
@ -48,16 +48,16 @@ func StartTask() error {
// register cron tasks // register cron tasks
var crontab = cron.New() var crontab = cron.New()
fmt.Println("start chatRecordsClearTime cron task", "cron config", config.Config.ChatRecordsClearTime) fmt.Printf("Start chatRecordsClearTime cron task, cron config: %s\n", config.Config.ChatRecordsClearTime)
_, err = crontab.AddFunc(config.Config.ChatRecordsClearTime, cronWrapFunc(rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq)) _, err = crontab.AddFunc(config.Config.ChatRecordsClearTime, cronWrapFunc(rdb, "cron_clear_msg_and_fix_seq", msgTool.AllConversationClearMsgAndFixSeq))
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err)
} }
fmt.Println("start msgDestruct cron task", "cron config", config.Config.MsgDestructTime) fmt.Printf("Start msgDestruct cron task, cron config: %s\n", config.Config.MsgDestructTime)
_, err = crontab.AddFunc(config.Config.MsgDestructTime, cronWrapFunc(rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs)) _, err = crontab.AddFunc(config.Config.MsgDestructTime, cronWrapFunc(rdb, "cron_conversations_destruct_msgs", msgTool.ConversationsDestructMsgs))
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err, "cron_conversations_destruct_msgs")
} }
// start crontab // start crontab

View File

@ -197,7 +197,8 @@ func (c *MsgTool) checkMaxSeqWithMongo(ctx context.Context, conversationID strin
return err return err
} }
if math.Abs(float64(maxSeqMongo-maxSeqCache)) > 10 { if math.Abs(float64(maxSeqMongo-maxSeqCache)) > 10 {
log.ZError(ctx, "cache max seq and mongo max seq is diff > 10", nil, "maxSeqMongo", maxSeqMongo, "minSeqMongo", minSeqMongo, "maxSeqCache", maxSeqCache, "conversationID", conversationID) err = fmt.Errorf("cache max seq and mongo max seq is diff > 10, maxSeqMongo:%d,minSeqMongo:%d,maxSeqCache:%d,conversationID:%s", maxSeqMongo, minSeqMongo, maxSeqCache, conversationID)
return errs.Wrap(err)
} }
return nil return nil
} }
@ -219,7 +220,6 @@ func (c *MsgTool) checkMaxSeq(ctx context.Context, conversationID string) error
func (c *MsgTool) FixAllSeq(ctx context.Context) error { func (c *MsgTool) FixAllSeq(ctx context.Context) error {
conversationIDs, err := c.conversationDatabase.GetAllConversationIDs(ctx) conversationIDs, err := c.conversationDatabase.GetAllConversationIDs(ctx)
if err != nil { if err != nil {
log.ZError(ctx, "GetAllConversationIDs failed", err)
return err return err
} }
for _, conversationID := range conversationIDs { for _, conversationID := range conversationIDs {

View File

@ -67,6 +67,7 @@ type LocationElem struct {
Longitude float64 `mapstructure:"longitude" validate:"required"` Longitude float64 `mapstructure:"longitude" validate:"required"`
Latitude float64 `mapstructure:"latitude" validate:"required"` Latitude float64 `mapstructure:"latitude" validate:"required"`
} }
type CustomElem struct { type CustomElem struct {
Data string `mapstructure:"data" validate:"required"` Data string `mapstructure:"data" validate:"required"`
Description string `mapstructure:"description"` Description string `mapstructure:"description"`

View File

@ -44,7 +44,7 @@ func CheckAccessV3(ctx context.Context, ownerUserID string) (err error) {
if opUserID == ownerUserID { if opUserID == ownerUserID {
return nil return nil
} }
return errs.ErrNoPermission.Wrap(utils.GetSelfFuncName()) return errs.Wrap(errs.ErrNoPermission, "CheckAccessV3: no permission for user "+opUserID)
} }
func IsAppManagerUid(ctx context.Context) bool { func IsAppManagerUid(ctx context.Context) bool {
@ -61,6 +61,7 @@ func CheckAdmin(ctx context.Context) error {
} }
return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx))) return errs.ErrNoPermission.Wrap(fmt.Sprintf("user %s is not admin userID", mcontext.GetOpUserID(ctx)))
} }
func CheckIMAdmin(ctx context.Context) error { func CheckIMAdmin(ctx context.Context) error {
if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) { if utils.IsContain(mcontext.GetOpUserID(ctx), config.Config.IMAdmin.UserID) {
return nil return nil

View File

@ -15,6 +15,9 @@
package cmd package cmd
import ( import (
"errors"
"fmt"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -32,17 +35,35 @@ func NewApiCmd() *ApiCmd {
return ret return ret
} }
// AddApi configures the API command to run with specified ports for the API and Prometheus monitoring.
// It ensures error handling for port retrieval and only proceeds if both port numbers are successfully obtained.
func (a *ApiCmd) AddApi(f func(port int, promPort int) error) { func (a *ApiCmd) AddApi(f func(port int, promPort int) error) {
a.Command.RunE = func(cmd *cobra.Command, args []string) error { a.Command.RunE = func(cmd *cobra.Command, args []string) error {
return f(a.getPortFlag(cmd), a.getPrometheusPortFlag(cmd)) port, err := a.getPortFlag(cmd)
if err != nil {
return err
}
promPort, err := a.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return f(port, promPort)
} }
} }
func (a *ApiCmd) GetPortFromConfig(portType string) int { func (a *ApiCmd) GetPortFromConfig(portType string) (int, error) {
if portType == constant.FlagPort { if portType == constant.FlagPort {
return config2.Config.Api.OpenImApiPort[0] if len(config2.Config.Api.OpenImApiPort) > 0 {
} else if portType == constant.FlagPrometheusPort { return config2.Config.Api.OpenImApiPort[0], nil
return config2.Config.Prometheus.ApiPrometheusPort[0]
} }
return 0 return 0, errors.New("API port configuration is empty or missing")
} else if portType == constant.FlagPrometheusPort {
if len(config2.Config.Prometheus.ApiPrometheusPort) > 0 {
return config2.Config.Prometheus.ApiPrometheusPort[0], nil
}
return 0, errors.New("Prometheus port configuration is empty or missing")
}
return 0, fmt.Errorf("unknown port type: %s", portType)
} }

View File

@ -36,3 +36,7 @@ func (c *CronTaskCmd) Exec(f func() error) error {
c.addRunE(f) c.addRunE(f)
return c.Execute() return c.Execute()
} }
func (c *CronTaskCmd) GetPortFromConfig(portType string) (int, error) {
return 0, nil
}

View File

@ -15,13 +15,15 @@
package cmd package cmd
import ( import (
"log" "errors"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/OpenIMSDK/protocol/constant"
"github.com/openimsdk/open-im-server/v3/internal/msggateway" "github.com/openimsdk/open-im-server/v3/internal/msggateway"
"github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config" v3config "github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
@ -39,20 +41,32 @@ func (m *MsgGatewayCmd) AddWsPortFlag() {
m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port") m.Command.Flags().IntP(constant.FlagWsPort, "w", 0, "ws server listen port")
} }
func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) int { func (m *MsgGatewayCmd) getWsPortFlag(cmd *cobra.Command) (int, error) {
port, err := cmd.Flags().GetInt(constant.FlagWsPort) port, err := cmd.Flags().GetInt(constant.FlagWsPort)
if err != nil { if err != nil {
log.Println("Error getting ws port flag:", err) return 0, errs.Wrap(err, "error getting ws port flag")
} }
if port == 0 { if port == 0 {
port = m.PortFromConfig(constant.FlagWsPort) port, _ = m.PortFromConfig(constant.FlagWsPort)
} }
return port return port, nil
} }
func (m *MsgGatewayCmd) addRunE() { func (m *MsgGatewayCmd) addRunE() {
m.Command.RunE = func(cmd *cobra.Command, args []string) error { m.Command.RunE = func(cmd *cobra.Command, args []string) error {
return msggateway.RunWsAndServer(m.getPortFlag(cmd), m.getWsPortFlag(cmd), m.getPrometheusPortFlag(cmd)) wsPort, err := m.getWsPortFlag(cmd)
if err != nil {
return errs.Wrap(err, "failed to get WS port flag")
}
port, err := m.getPortFlag(cmd)
if err != nil {
return err
}
prometheusPort, err := m.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return msggateway.RunWsAndServer(port, wsPort, prometheusPort)
} }
} }
@ -61,18 +75,33 @@ func (m *MsgGatewayCmd) Exec() error {
return m.Execute() return m.Execute()
} }
func (m *MsgGatewayCmd) GetPortFromConfig(portType string) int { func (m *MsgGatewayCmd) GetPortFromConfig(portType string) (int, error) {
var port int
var exists bool
switch portType { switch portType {
case constant.FlagWsPort: case constant.FlagWsPort:
return v3config.Config.LongConnSvr.OpenImWsPort[0] if len(v3config.Config.LongConnSvr.OpenImWsPort) > 0 {
port = v3config.Config.LongConnSvr.OpenImWsPort[0]
exists = true
}
case constant.FlagPort: case constant.FlagPort:
return v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0] if len(v3config.Config.LongConnSvr.OpenImMessageGatewayPort) > 0 {
port = v3config.Config.LongConnSvr.OpenImMessageGatewayPort[0]
exists = true
}
case constant.FlagPrometheusPort: case constant.FlagPrometheusPort:
return v3config.Config.Prometheus.MessageGatewayPrometheusPort[0] if len(v3config.Config.Prometheus.MessageGatewayPrometheusPort) > 0 {
port = v3config.Config.Prometheus.MessageGatewayPrometheusPort[0]
default: exists = true
return 0
} }
}
if !exists {
return 0, errs.Wrap(errors.New("port type '%s' not found in configuration"), portType)
}
return port, nil
} }

View File

@ -44,7 +44,7 @@ func TestMsgGatewayCmd_GetPortFromConfig(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.portType, func(t *testing.T) { t.Run(tt.portType, func(t *testing.T) {
got := msgGatewayCmd.GetPortFromConfig(tt.portType) got, _ := msgGatewayCmd.GetPortFromConfig(tt.portType)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }

View File

@ -37,7 +37,11 @@ func NewMsgTransferCmd() *MsgTransferCmd {
func (m *MsgTransferCmd) addRunE() { func (m *MsgTransferCmd) addRunE() {
m.Command.RunE = func(cmd *cobra.Command, args []string) error { m.Command.RunE = func(cmd *cobra.Command, args []string) error {
return msgtransfer.StartTransfer(m.getPrometheusPortFlag(cmd)) prometheusPort, err := m.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
return msgtransfer.StartTransfer(prometheusPort)
} }
} }
@ -46,14 +50,18 @@ func (m *MsgTransferCmd) Exec() error {
return m.Execute() return m.Execute()
} }
func (m *MsgTransferCmd) GetPortFromConfig(portType string) int { func (m *MsgTransferCmd) GetPortFromConfig(portType string) (int, error) {
if portType == constant.FlagPort { if portType == constant.FlagPort {
return 0 return 0, nil
} else if portType == constant.FlagPrometheusPort { } else if portType == constant.FlagPrometheusPort {
n := m.getTransferProgressFlagValue() n := m.getTransferProgressFlagValue()
return config2.Config.Prometheus.MessageTransferPrometheusPort[n]
if n < len(config2.Config.Prometheus.MessageTransferPrometheusPort) {
return config2.Config.Prometheus.MessageTransferPrometheusPort[n], nil
} }
return 0 return 0, fmt.Errorf("index out of range for MessageTransferPrometheusPort with index %d", n)
}
return 0, fmt.Errorf("unknown port type: %s", portType)
} }
func (m *MsgTransferCmd) AddTransferProgressFlag() { func (m *MsgTransferCmd) AddTransferProgressFlag() {

View File

@ -18,6 +18,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/openimsdk/open-im-server/v3/internal/tools" "github.com/openimsdk/open-im-server/v3/internal/tools"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
type MsgUtilsCmd struct { type MsgUtilsCmd struct {
@ -137,7 +138,7 @@ func (s *SeqCmd) GetSeqCmd() *cobra.Command {
s.Command.Run = func(cmdLines *cobra.Command, args []string) { s.Command.Run = func(cmdLines *cobra.Command, args []string) {
_, err := tools.InitMsgTool() _, err := tools.InitMsgTool()
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
userID := s.getUserIDFlag(cmdLines) userID := s.getUserIDFlag(cmdLines)
superGroupID := s.getSuperGroupIDFlag(cmdLines) superGroupID := s.getSuperGroupIDFlag(cmdLines)

View File

@ -19,6 +19,8 @@ import (
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/errs"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
@ -28,8 +30,9 @@ import (
) )
type RootCmdPt interface { type RootCmdPt interface {
GetPortFromConfig(portType string) int GetPortFromConfig(portType string) (int, error)
} }
type RootCmd struct { type RootCmd struct {
Command cobra.Command Command cobra.Command
Name string Name string
@ -77,7 +80,7 @@ func (rc *RootCmd) persistentPreRun(cmd *cobra.Command, opts ...func(*CmdOpts))
cmdOpts := rc.applyOptions(opts...) cmdOpts := rc.applyOptions(opts...)
if err := rc.initializeLogger(cmdOpts); err != nil { if err := rc.initializeLogger(cmdOpts); err != nil {
return fmt.Errorf("failed to initialize from config: %w", err) return errs.Wrap(err, "failed to initialize logger")
} }
return nil return nil
@ -130,31 +133,41 @@ func (r *RootCmd) AddPortFlag() {
r.Command.Flags().IntP(constant.FlagPort, "p", 0, "server listen port") r.Command.Flags().IntP(constant.FlagPort, "p", 0, "server listen port")
} }
func (r *RootCmd) getPortFlag(cmd *cobra.Command) int { func (r *RootCmd) getPortFlag(cmd *cobra.Command) (int, error) {
port, err := cmd.Flags().GetInt(constant.FlagPort) port, err := cmd.Flags().GetInt(constant.FlagPort)
if err != nil { if err != nil {
fmt.Println("Error getting ws port flag:", err) // Wrapping the error with additional context
return 0, errs.Wrap(err, "error getting port flag")
} }
if port == 0 { if port == 0 {
port = r.PortFromConfig(constant.FlagPort) port, _ = r.PortFromConfig(constant.FlagPort)
// port, err := r.PortFromConfig(constant.FlagPort)
// if err != nil {
// // Optionally wrap the error if it's an internal error needing context
// return 0, errs.Wrap(err, "error getting port from config")
// }
} }
return port return port, nil
} }
func (r *RootCmd) GetPortFlag() int { // // GetPortFlag returns the port flag.
return r.port func (r *RootCmd) GetPortFlag() (int, error) {
return r.port, nil
} }
func (r *RootCmd) AddPrometheusPortFlag() { func (r *RootCmd) AddPrometheusPortFlag() {
r.Command.Flags().IntP(constant.FlagPrometheusPort, "", 0, "server prometheus listen port") r.Command.Flags().IntP(constant.FlagPrometheusPort, "", 0, "server prometheus listen port")
} }
func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) int { func (r *RootCmd) getPrometheusPortFlag(cmd *cobra.Command) (int, error) {
port, _ := cmd.Flags().GetInt(constant.FlagPrometheusPort) port, err := cmd.Flags().GetInt(constant.FlagPrometheusPort)
if port == 0 { if err != nil || port == 0 {
port = r.PortFromConfig(constant.FlagPrometheusPort) port, err = r.PortFromConfig(constant.FlagPrometheusPort)
if err != nil {
return 0, err
} }
return port }
return port, nil
} }
func (r *RootCmd) GetPrometheusPortFlag() int { func (r *RootCmd) GetPrometheusPortFlag() int {
@ -175,10 +188,11 @@ func (r *RootCmd) AddCommand(cmds ...*cobra.Command) {
r.Command.AddCommand(cmds...) r.Command.AddCommand(cmds...)
} }
func (r *RootCmd) GetPortFromConfig(portType string) int { func (r *RootCmd) PortFromConfig(portType string) (int, error) {
return 0 // Retrieve the port and cache it
} port, err := r.cmdItf.GetPortFromConfig(portType)
if err != nil {
func (r *RootCmd) PortFromConfig(portType string) int { return 0, err
return r.cmdItf.GetPortFromConfig(portType) }
return port, nil
} }

View File

@ -16,11 +16,14 @@ package cmd
import ( import (
"errors" "errors"
"fmt"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc" "google.golang.org/grpc"
"github.com/OpenIMSDK/tools/errs"
config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config" config2 "github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
@ -39,78 +42,78 @@ func NewRpcCmd(name string) *RpcCmd {
} }
func (a *RpcCmd) Exec() error { func (a *RpcCmd) Exec() error {
a.Command.Run = func(cmd *cobra.Command, args []string) { a.Command.RunE = func(cmd *cobra.Command, args []string) error {
a.port = a.getPortFlag(cmd) portFlag, err := a.getPortFlag(cmd)
a.prometheusPort = a.getPrometheusPortFlag(cmd) if err != nil {
return err
}
a.port = portFlag
prometheusPort, err := a.getPrometheusPortFlag(cmd)
if err != nil {
return err
}
a.prometheusPort = prometheusPort
return nil
} }
return a.Execute() return a.Execute()
} }
func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error { func (a *RpcCmd) StartSvr(name string, rpcFn func(discov discoveryregistry.SvcDiscoveryRegistry, server *grpc.Server) error) error {
if a.GetPortFlag() == 0 { portFlag, err := a.GetPortFlag()
return errors.New("port is required") if err != nil {
return err
} else {
a.port = portFlag
} }
return startrpc.Start(a.GetPortFlag(), name, a.GetPrometheusPortFlag(), rpcFn)
return startrpc.Start(portFlag, name, a.GetPrometheusPortFlag(), rpcFn)
} }
func (a *RpcCmd) GetPortFromConfig(portType string) int { func (a *RpcCmd) GetPortFromConfig(portType string) (int, error) {
switch a.Name { portConfigMap := map[string]map[string]int{
case RpcPushServer: RpcPushServer: {
if portType == constant.FlagPort { constant.FlagPort: config2.Config.RpcPort.OpenImPushPort[0],
return config2.Config.RpcPort.OpenImPushPort[0] constant.FlagPrometheusPort: config2.Config.Prometheus.PushPrometheusPort[0],
},
RpcAuthServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImAuthPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.AuthPrometheusPort[0],
},
RpcConversationServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImConversationPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.ConversationPrometheusPort[0],
},
RpcFriendServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImFriendPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.FriendPrometheusPort[0],
},
RpcGroupServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImGroupPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.GroupPrometheusPort[0],
},
RpcMsgServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImMessagePort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.MessagePrometheusPort[0],
},
RpcThirdServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImThirdPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.ThirdPrometheusPort[0],
},
RpcUserServer: {
constant.FlagPort: config2.Config.RpcPort.OpenImUserPort[0],
constant.FlagPrometheusPort: config2.Config.Prometheus.UserPrometheusPort[0],
},
} }
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.PushPrometheusPort[0] if portMap, ok := portConfigMap[a.Name]; ok {
} if port, ok := portMap[portType]; ok {
case RpcAuthServer: return port, nil
if portType == constant.FlagPort { } else {
return config2.Config.RpcPort.OpenImAuthPort[0] return 0, errs.Wrap(errors.New("port type not found"), fmt.Sprintf("Failed to get port for %s", a.Name))
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.AuthPrometheusPort[0]
}
case RpcConversationServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImConversationPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.ConversationPrometheusPort[0]
}
case RpcFriendServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImFriendPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.FriendPrometheusPort[0]
}
case RpcGroupServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImGroupPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.GroupPrometheusPort[0]
}
case RpcMsgServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImMessagePort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.MessagePrometheusPort[0]
}
case RpcThirdServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImThirdPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.ThirdPrometheusPort[0]
}
case RpcUserServer:
if portType == constant.FlagPort {
return config2.Config.RpcPort.OpenImUserPort[0]
}
if portType == constant.FlagPrometheusPort {
return config2.Config.Prometheus.UserPrometheusPort[0]
} }
} }
return 0
return 0, errs.Wrap(fmt.Errorf("server name '%s' not found", a.Name), "Failed to get port configuration")
} }

View File

@ -21,6 +21,7 @@ import (
"path/filepath" "path/filepath"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/tools/errs"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
@ -36,32 +37,38 @@ const (
DefaultFolderPath = "../config/" DefaultFolderPath = "../config/"
) )
// return absolude path join ../config/, this is k8s container config path. // GetDefaultConfigPath returns the absolute path to the default configuration directory
func GetDefaultConfigPath() string { // relative to the executable's location. It is intended for use in Kubernetes container configurations.
// Errors are returned to the caller to allow for flexible error handling.
func GetDefaultConfigPath() (string, error) {
executablePath, err := os.Executable() executablePath, err := os.Executable()
if err != nil { if err != nil {
fmt.Println("GetDefaultConfigPath error:", err.Error()) return "", errs.Wrap(err, "failed to get executable path")
return ""
} }
// Calculate the config path as a directory relative to the executable's location
configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/")) configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/"))
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) return "", errs.Wrap(err, "failed to get output directory")
os.Exit(1)
} }
return configPath return configPath, nil
} }
// getProjectRoot returns the absolute path of the project root directory. // GetProjectRoot returns the absolute path of the project root directory by navigating up from the directory
func GetProjectRoot() string { // containing the executable. It provides a detailed error if the path cannot be determined.
executablePath, _ := os.Executable() func GetProjectRoot() (string, error) {
executablePath, err := os.Executable()
if err != nil {
return "", errs.Wrap(err, "failed to retrieve executable path")
}
// Attempt to compute the project root by navigating up from the executable's directory
projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../..")) projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../.."))
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) return "", err
os.Exit(1)
} }
return projectRoot
return projectRoot, nil
} }
func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options { func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
@ -83,42 +90,66 @@ func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options {
return opts return opts
} }
// initConfig loads configuration from a specified path into the provided config structure.
// If the specified config file does not exist, it attempts to load from the project's default "config" directory.
// It logs informative messages regarding the configuration path being used.
func initConfig(config any, configName, configFolderPath string) error { func initConfig(config any, configName, configFolderPath string) error {
configFolderPath = filepath.Join(configFolderPath, configName) configFilePath := filepath.Join(configFolderPath, configName)
_, err := os.Stat(configFolderPath) _, err := os.Stat(configFilePath)
if err != nil { if err != nil {
if !os.IsNotExist(err) { if !os.IsNotExist(err) {
fmt.Println("stat config path error:", err.Error()) return errs.Wrap(err, fmt.Sprintf("failed to check existence of config file at path: %s", configFilePath))
return fmt.Errorf("stat config path error: %w", err)
} }
configFolderPath = filepath.Join(GetProjectRoot(), "config", configName) var projectRoot string
fmt.Println("flag's path,enviment's path,default path all is not exist,using project path:", configFolderPath) projectRoot, err = GetProjectRoot()
}
data, err := os.ReadFile(configFolderPath)
if err != nil { if err != nil {
return fmt.Errorf("read file error: %w", err) return err
} }
if err = yaml.Unmarshal(data, config); err != nil { configFilePath = filepath.Join(projectRoot, "config", configName)
return fmt.Errorf("unmarshal yaml error: %w", err) fmt.Printf("Configuration file not found at specified path. Falling back to project path: %s\n", configFilePath)
} }
fmt.Println("The path of the configuration file to start the process:", configFolderPath)
data, err := os.ReadFile(configFilePath)
if err != nil {
// Wrap and return the error if reading the configuration file fails.
return errs.Wrap(err, fmt.Sprintf("failed to read configuration file at path: %s", configFilePath))
}
if err = yaml.Unmarshal(data, config); err != nil {
// Wrap and return the error if unmarshalling the YAML configuration fails.
return errs.Wrap(err, "failed to unmarshal YAML configuration")
}
fmt.Printf("Configuration file loaded successfully from path: %s\n", configFilePath)
return nil return nil
} }
// InitConfig initializes the application configuration by loading it from a specified folder path.
// If the folder path is not provided, it attempts to use the OPENIMCONFIG environment variable,
// and as a fallback, it uses the default configuration path. It loads both the main configuration
// and notification configuration, wrapping errors for better context.
func InitConfig(configFolderPath string) error { func InitConfig(configFolderPath string) error {
// Use the provided config folder path, or fallback to environment variable or default path
if configFolderPath == "" { if configFolderPath == "" {
envConfigPath := os.Getenv("OPENIMCONFIG") configFolderPath = os.Getenv("OPENIMCONFIG")
if envConfigPath != "" { if configFolderPath == "" {
configFolderPath = envConfigPath var err error
} else { configFolderPath, err = GetDefaultConfigPath()
configFolderPath = GetDefaultConfigPath() if err != nil {
return err
}
} }
} }
// Initialize the main configuration
if err := initConfig(&Config, FileName, configFolderPath); err != nil { if err := initConfig(&Config, FileName, configFolderPath); err != nil {
return err return err
} }
return initConfig(&Config.Notification, NotificationFileName, configFolderPath) // Initialize the notification configuration
if err := initConfig(&Config.Notification, NotificationFileName, configFolderPath); err != nil {
return err
}
return nil
} }

View File

@ -31,7 +31,7 @@ func TestGetDefaultConfigPath(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := GetDefaultConfigPath(); got != tt.want { if got, _ := GetDefaultConfigPath(); got != tt.want {
t.Errorf("GetDefaultConfigPath() = %v, want %v", got, tt.want) t.Errorf("GetDefaultConfigPath() = %v, want %v", got, tt.want)
} }
}) })
@ -47,7 +47,7 @@ func TestGetProjectRoot(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := GetProjectRoot(); got != tt.want { if got, _ := GetProjectRoot(); got != tt.want {
t.Errorf("GetProjectRoot() = %v, want %v", got, tt.want) t.Errorf("GetProjectRoot() = %v, want %v", got, tt.want)
} }
}) })

View File

@ -23,11 +23,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation" "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/relation"
) )
func BlackDB2Pb( func BlackDB2Pb(ctx context.Context, blackDBs []*relation.BlackModel, f func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error)) (blackPbs []*sdk.BlackInfo, err error) {
ctx context.Context,
blackDBs []*relation.BlackModel,
f func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) (blackPbs []*sdk.BlackInfo, err error) {
if len(blackDBs) == 0 { if len(blackDBs) == 0 {
return nil, nil return nil, nil
} }

View File

@ -53,11 +53,7 @@ func FriendDB2Pb(ctx context.Context, friendDB *relation.FriendModel,
}, nil }, nil
} }
func FriendsDB2Pb( func FriendsDB2Pb(ctx context.Context, friendsDB []*relation.FriendModel, getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error)) (friendsPb []*sdkws.FriendInfo, err error) {
ctx context.Context,
friendsDB []*relation.FriendModel,
getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) (friendsPb []*sdkws.FriendInfo, err error) {
if len(friendsDB) == 0 { if len(friendsDB) == 0 {
return nil, nil return nil, nil
} }
@ -89,8 +85,7 @@ func FriendsDB2Pb(
} }
func FriendRequestDB2Pb( func FriendRequestDB2Pb(ctx context.Context,
ctx context.Context,
friendRequests []*relation.FriendRequestModel, friendRequests []*relation.FriendRequestModel,
getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error), getUsers func(ctx context.Context, userIDs []string) (map[string]*sdkws.UserInfo, error),
) ([]*sdkws.FriendRequest, error) { ) ([]*sdkws.FriendRequest, error) {

View File

@ -46,11 +46,7 @@ type BlackCacheRedis struct {
blackDB relationtb.BlackModelInterface blackDB relationtb.BlackModelInterface
} }
func NewBlackCacheRedis( func NewBlackCacheRedis(rdb redis.UniversalClient, blackDB relationtb.BlackModelInterface, options rockscache.Options) BlackCache {
rdb redis.UniversalClient,
blackDB relationtb.BlackModelInterface,
options rockscache.Options,
) BlackCache {
rcClient := rockscache.NewClient(rdb, options) rcClient := rockscache.NewClient(rdb, options)
return &BlackCacheRedis{ return &BlackCacheRedis{

View File

@ -49,7 +49,7 @@ func NewRedis() (redis.UniversalClient, error) {
overrideConfigFromEnv() overrideConfigFromEnv()
if len(config.Config.Redis.Address) == 0 { if len(config.Config.Redis.Address) == 0 {
return nil, errs.Wrap(errors.New("redis address is empty")) return nil, errs.Wrap(errors.New("redis address is empty"), "Redis configuration error")
} }
specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound) specialerror.AddReplace(redis.Nil, errs.ErrRecordNotFound)
var rdb redis.UniversalClient var rdb redis.UniversalClient
@ -65,7 +65,7 @@ func NewRedis() (redis.UniversalClient, error) {
rdb = redis.NewClient(&redis.Options{ rdb = redis.NewClient(&redis.Options{
Addr: config.Config.Redis.Address[0], Addr: config.Config.Redis.Address[0],
Username: config.Config.Redis.Username, Username: config.Config.Redis.Username,
Password: config.Config.Redis.Password, Password: config.Config.Redis.Password, // no password set
DB: 0, // use default DB DB: 0, // use default DB
PoolSize: 100, // connection pool size PoolSize: 100, // connection pool size
MaxRetries: maxRetry, MaxRetries: maxRetry,
@ -77,9 +77,9 @@ func NewRedis() (redis.UniversalClient, error) {
defer cancel() defer cancel()
err = rdb.Ping(ctx).Err() err = rdb.Ping(ctx).Err()
if err != nil { if err != nil {
uriFormat := "address:%s, username:%s, password:%s, clusterMode:%t, enablePipeline:%t" uriFormat := "address:%v, username:%s, clusterMode:%t, enablePipeline:%t"
errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.Password, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline) errMsg := fmt.Sprintf(uriFormat, config.Config.Redis.Address, config.Config.Redis.Username, config.Config.Redis.ClusterMode, config.Config.Redis.EnablePipeline)
return nil, errs.Wrap(err, errMsg) return nil, errs.Wrap(err, "Redis connection failed: %s", errMsg)
} }
redisClient = rdb redisClient = rdb
return rdb, err return rdb, err
@ -98,9 +98,11 @@ func overrideConfigFromEnv() {
config.Config.Redis.Address = strings.Split(envAddr, ",") config.Config.Redis.Address = strings.Split(envAddr, ",")
} }
} }
if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" { if envUser := os.Getenv("REDIS_USERNAME"); envUser != "" {
config.Config.Redis.Username = envUser config.Config.Redis.Username = envUser
} }
if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" { if envPass := os.Getenv("REDIS_PASSWORD"); envPass != "" {
config.Config.Redis.Password = envPass config.Config.Redis.Password = envPass
} }

View File

@ -134,7 +134,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
} }
bs, err := json.Marshal(t) bs, err := json.Marshal(t)
if err != nil { if err != nil {
return "", errs.Wrap(err) return "", errs.Wrap(err, "marshal failed")
} }
write = true write = true
@ -152,8 +152,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
err = json.Unmarshal([]byte(v), &t) err = json.Unmarshal([]byte(v), &t)
if err != nil { if err != nil {
log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire) log.ZError(ctx, "cache json.Unmarshal failed", err, "key", key, "value", v, "expire", expire)
return t, errs.Wrap(err, "unmarshal failed")
return t, errs.Wrap(err)
} }
return t, nil return t, nil
@ -197,14 +196,7 @@ func getCache[T any](ctx context.Context, rcClient *rockscache.Client, key strin
// return tArrays, nil // return tArrays, nil
//} //}
func batchGetCache2[T any, K comparable]( func batchGetCache2[T any, K comparable](ctx context.Context, rcClient *rockscache.Client, expire time.Duration, keys []K, keyFn func(key K) string, fns func(ctx context.Context, key K) (T, error)) ([]T, error) {
ctx context.Context,
rcClient *rockscache.Client,
expire time.Duration,
keys []K,
keyFn func(key K) string,
fns func(ctx context.Context, key K) (T, error),
) ([]T, error) {
if len(keys) == 0 { if len(keys) == 0 {
return nil, nil return nil, nil
} }

View File

@ -24,12 +24,11 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/msgprocessor" "github.com/openimsdk/open-im-server/v3/pkg/msgprocessor"
"github.com/OpenIMSDK/tools/errs"
"github.com/gogo/protobuf/jsonpb" "github.com/gogo/protobuf/jsonpb"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log" "github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/tools/utils" "github.com/OpenIMSDK/tools/utils"
@ -433,7 +432,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
for _, msg := range msgs { for _, msg := range msgs {
s, err := msgprocessor.Pb2String(msg) s, err := msgprocessor.Pb2String(msg)
if err != nil { if err != nil {
return 0, errs.Wrap(err, "pb.marshal") return 0, err
} }
key := c.getMessageCacheKey(conversationID, msg.Seq) key := c.getMessageCacheKey(conversationID, msg.Seq)
@ -442,7 +441,7 @@ func (c *msgCache) PipeSetMessageToCache(ctx context.Context, conversationID str
results, err := pipe.Exec(ctx) results, err := pipe.Exec(ctx)
if err != nil { if err != nil {
return 0, errs.Wrap(err, "pipe.set") return 0, errs.Wrap(err)
} }
for _, res := range results { for _, res := range results {

View File

@ -66,11 +66,7 @@ type UserCacheRedis struct {
rcClient *rockscache.Client rcClient *rockscache.Client
} }
func NewUserCacheRedis( func NewUserCacheRedis(rdb redis.UniversalClient, userDB relationtb.UserModelInterface, options rockscache.Options) UserCache {
rdb redis.UniversalClient,
userDB relationtb.UserModelInterface,
options rockscache.Options,
) UserCache {
rcClient := rockscache.NewClient(rdb, options) rcClient := rockscache.NewClient(rdb, options)
return &UserCacheRedis{ return &UserCacheRedis{
@ -282,8 +278,8 @@ func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string,
var onlineStatus user.OnlineStatus var onlineStatus user.OnlineStatus
if !isNil { if !isNil {
err2 := json.Unmarshal([]byte(result), &onlineStatus) err2 := json.Unmarshal([]byte(result), &onlineStatus)
if err != nil { if err2 != nil {
return errs.Wrap(err2) return errs.Wrap(err, "json.Unmarshal failed")
} }
onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID))
} else { } else {
@ -293,7 +289,7 @@ func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string,
onlineStatus.UserID = userID onlineStatus.UserID = userID
newjsonData, err := json.Marshal(&onlineStatus) newjsonData, err := json.Marshal(&onlineStatus)
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err, "json.Marshal failed")
} }
_, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result() _, err = u.rdb.HSet(ctx, key, userID, string(newjsonData)).Result()
if err != nil { if err != nil {

View File

@ -30,9 +30,9 @@ import (
) )
type AuthDatabase interface { type AuthDatabase interface {
// 结果为空 不返回错误 // If the result is empty, no error is returned.
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
// 创建token // Create token
CreateToken(ctx context.Context, userID string, platformID int) (string, error) CreateToken(ctx context.Context, userID string, platformID int) (string, error)
} }
@ -47,16 +47,12 @@ func NewAuthDatabase(cache cache.MsgModel, accessSecret string, accessExpire int
return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire} return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire}
} }
// 结果为空 不返回错误. // If the result is empty.
func (a *authDatabase) GetTokensWithoutError( func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
ctx context.Context,
userID string,
platformID int,
) (map[string]int, error) {
return a.cache.GetTokensWithoutError(ctx, userID, platformID) return a.cache.GetTokensWithoutError(ctx, userID, platformID)
} }
// 创建token. // Create Token.
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) { func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID) tokens, err := a.cache.GetTokensWithoutError(ctx, userID, platformID)
if err != nil { if err != nil {
@ -80,7 +76,7 @@ func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformI
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(a.accessSecret)) tokenString, err := token.SignedString([]byte(a.accessSecret))
if err != nil { if err != nil {
return "", errs.Wrap(err) return "", errs.Wrap(err, "token.SignedString")
} }
return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken) return tokenString, a.cache.AddTokenFlag(ctx, userID, platformID, tokenString, constant.NormalToken)
} }

View File

@ -27,14 +27,14 @@ import (
) )
type BlackDatabase interface { type BlackDatabase interface {
// Create 增加黑名单 // Create add BlackList
Create(ctx context.Context, blacks []*relation.BlackModel) (err error) Create(ctx context.Context, blacks []*relation.BlackModel) (err error)
// Delete 删除黑名单 // Delete delete BlackList
Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error)
// FindOwnerBlacks 获取黑名单列表 // FindOwnerBlacks get BlackList list
FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error)
FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error)
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true) // CheckIn Check whether user2 is in the black list of user1 (inUser1Blacks==true) Check whether user1 is in the black list of user2 (inUser2Blacks==true)
CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error)
} }
@ -47,7 +47,7 @@ func NewBlackDatabase(black relation.BlackModelInterface, cache cache.BlackCache
return &blackDatabase{black, cache} return &blackDatabase{black, cache}
} }
// Create 增加黑名单. // Create Add Blacklist.
func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) { func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Create(ctx, blacks); err != nil { if err := b.black.Create(ctx, blacks); err != nil {
return err return err
@ -55,7 +55,7 @@ func (b *blackDatabase) Create(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks) return b.deleteBlackIDsCache(ctx, blacks)
} }
// Delete 删除黑名单. // Delete Delete Blacklist.
func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) { func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackModel) (err error) {
if err := b.black.Delete(ctx, blacks); err != nil { if err := b.black.Delete(ctx, blacks); err != nil {
return err return err
@ -63,6 +63,7 @@ func (b *blackDatabase) Delete(ctx context.Context, blacks []*relation.BlackMode
return b.deleteBlackIDsCache(ctx, blacks) return b.deleteBlackIDsCache(ctx, blacks)
} }
// FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relation.BlackModel) (err error) { func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relation.BlackModel) (err error) {
cache := b.cache.NewCache() cache := b.cache.NewCache()
for _, black := range blacks { for _, black := range blacks {
@ -71,16 +72,13 @@ func (b *blackDatabase) deleteBlackIDsCache(ctx context.Context, blacks []*relat
return cache.ExecDel(ctx) return cache.ExecDel(ctx)
} }
// FindOwnerBlacks 获取黑名单列表. // FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error) { func (b *blackDatabase) FindOwnerBlacks(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlacks(ctx, ownerUserID, pagination) return b.black.FindOwnerBlacks(ctx, ownerUserID, pagination)
} }
// CheckIn 检查user2是否在user1的黑名单列表中(inUser1Blacks==true) 检查user1是否在user2的黑名单列表中(inUser2Blacks==true). // FindOwnerBlacks Get Blacklist List.
func (b *blackDatabase) CheckIn( func (b *blackDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Blacks bool, inUser2Blacks bool, err error) {
ctx context.Context,
userID1, userID2 string,
) (inUser1Blacks bool, inUser2Blacks bool, err error) {
userID1BlackIDs, err := b.cache.GetBlackIDs(ctx, userID1) userID1BlackIDs, err := b.cache.GetBlackIDs(ctx, userID1)
if err != nil { if err != nil {
return return
@ -93,10 +91,12 @@ func (b *blackDatabase) CheckIn(
return utils.IsContain(userID2, userID1BlackIDs), utils.IsContain(userID1, userID2BlackIDs), nil return utils.IsContain(userID2, userID1BlackIDs), utils.IsContain(userID1, userID2BlackIDs), nil
} }
// FindBlackIDs Get Blacklist List.
func (b *blackDatabase) FindBlackIDs(ctx context.Context, ownerUserID string) (blackIDs []string, err error) { func (b *blackDatabase) FindBlackIDs(ctx context.Context, ownerUserID string) (blackIDs []string, err error) {
return b.cache.GetBlackIDs(ctx, ownerUserID) return b.cache.GetBlackIDs(ctx, ownerUserID)
} }
// FindBlackInfos Get Blacklist List.
func (b *blackDatabase) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error) { func (b *blackDatabase) FindBlackInfos(ctx context.Context, ownerUserID string, userIDs []string) (blacks []*relation.BlackModel, err error) {
return b.black.FindOwnerBlackInfos(ctx, ownerUserID, userIDs) return b.black.FindOwnerBlackInfos(ctx, ownerUserID, userIDs)
} }

View File

@ -32,32 +32,40 @@ import (
) )
type ConversationDatabase interface { type ConversationDatabase interface {
// UpdateUserConversationFiled 更新用户该会话的属性信息 // UpdateUsersConversationField updates the properties of a conversation for specified users.
UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error
// CreateConversation 创建一批新的会话 // CreateConversation creates a batch of new conversations.
CreateConversation(ctx context.Context, conversations []*relationtb.ConversationModel) error CreateConversation(ctx context.Context, conversations []*relationtb.ConversationModel) error
// SyncPeerUserPrivateConversation 同步对端私聊会话内部保证事务操作 // SyncPeerUserPrivateConversationTx ensures transactional operation while syncing private conversations between peers.
SyncPeerUserPrivateConversationTx(ctx context.Context, conversation []*relationtb.ConversationModel) error SyncPeerUserPrivateConversationTx(ctx context.Context, conversation []*relationtb.ConversationModel) error
// FindConversations 根据会话ID获取某个用户的多个会话 // FindConversations retrieves multiple conversations of a user by conversation IDs.
FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error) FindConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// FindRecvMsgNotNotifyUserIDs 获取超级大群开启免打扰的用户ID // GetUserAllConversation fetches all conversations of a user on the server.
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
// GetUserAllConversation 获取一个用户在服务器上所有的会话
GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.ConversationModel, error) GetUserAllConversation(ctx context.Context, ownerUserID string) ([]*relationtb.ConversationModel, error)
// SetUserConversations 设置用户多个会话属性,如果会话不存在则创建,否则更新,内部保证原子性 // SetUserConversations sets multiple conversation properties for a user, creates new conversations if they do not exist, or updates them otherwise. This operation is atomic.
SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error SetUserConversations(ctx context.Context, ownerUserID string, conversations []*relationtb.ConversationModel) error
// SetUsersConversationFiledTx 设置多个用户会话关于某个字段的更新操作,如果会话不存在则创建,否则更新,内部保证事务操作 // SetUsersConversationFieldTx updates a specific field for multiple users' conversations, creating new conversations if they do not exist, or updates them otherwise. This operation is transactional.
SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) error SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) error
// CreateGroupChatConversation creates a group chat conversation for the specified group ID and user IDs.
CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error CreateGroupChatConversation(ctx context.Context, groupID string, userIDs []string) error
// GetConversationIDs retrieves conversation IDs for a given user.
GetConversationIDs(ctx context.Context, userID string) ([]string, error) GetConversationIDs(ctx context.Context, userID string) ([]string, error)
// GetUserConversationIDsHash gets the hash of conversation IDs for a given user.
GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error) GetUserConversationIDsHash(ctx context.Context, ownerUserID string) (hash uint64, err error)
// GetAllConversationIDs fetches all conversation IDs.
GetAllConversationIDs(ctx context.Context) ([]string, error) GetAllConversationIDs(ctx context.Context) ([]string, error)
// GetAllConversationIDsNumber returns the number of all conversation IDs.
GetAllConversationIDsNumber(ctx context.Context) (int64, error) GetAllConversationIDsNumber(ctx context.Context) (int64, error)
// PageConversationIDs paginates through conversation IDs based on the specified pagination settings.
PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error) PageConversationIDs(ctx context.Context, pagination pagination.Pagination) (conversationIDs []string, err error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error) // GetConversationsByConversationID retrieves conversations by their IDs.
GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.ConversationModel, error) GetConversationsByConversationID(ctx context.Context, conversationIDs []string) ([]*relationtb.ConversationModel, error)
// GetConversationIDsNeedDestruct fetches conversations that need to be destructed based on specific criteria.
GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.ConversationModel, error) GetConversationIDsNeedDestruct(ctx context.Context) ([]*relationtb.ConversationModel, error)
// GetConversationNotReceiveMessageUserIDs gets user IDs for users in a conversation who have not received messages.
GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error) GetConversationNotReceiveMessageUserIDs(ctx context.Context, conversationID string) ([]string, error)
//GetUserAllHasReadSeqs(ctx context.Context, ownerUserID string) (map[string]int64, error)
//FindRecvMsgNotNotifyUserIDs(ctx context.Context, groupID string) ([]string, error)
} }
func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase { func NewConversationDatabase(conversation relationtb.ConversationModelInterface, cache cache.ConversationCache, tx tx.CtxTx) ConversationDatabase {
@ -74,7 +82,7 @@ type conversationDatabase struct {
tx tx.CtxTx tx tx.CtxTx
} }
func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, filedMap map[string]any) (err error) { func (c *conversationDatabase) SetUsersConversationFieldTx(ctx context.Context, userIDs []string, conversation *relationtb.ConversationModel, fieldMap map[string]any) (err error) {
return c.tx.Transaction(ctx, func(ctx context.Context) error { return c.tx.Transaction(ctx, func(ctx context.Context) error {
cache := c.cache.NewCache() cache := c.cache.NewCache()
if conversation.GroupID != "" { if conversation.GroupID != "" {
@ -85,22 +93,22 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
return err return err
} }
if len(haveUserIDs) > 0 { if len(haveUserIDs) > 0 {
_, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, filedMap) _, err = c.conversationDB.UpdateByMap(ctx, haveUserIDs, conversation.ConversationID, fieldMap)
if err != nil { if err != nil {
return err return err
} }
cache = cache.DelUsersConversation(conversation.ConversationID, haveUserIDs...) cache = cache.DelUsersConversation(conversation.ConversationID, haveUserIDs...)
if _, ok := filedMap["has_read_seq"]; ok { if _, ok := fieldMap["has_read_seq"]; ok {
for _, userID := range haveUserIDs { for _, userID := range haveUserIDs {
cache = cache.DelUserAllHasReadSeqs(userID, conversation.ConversationID) cache = cache.DelUserAllHasReadSeqs(userID, conversation.ConversationID)
} }
} }
if _, ok := filedMap["recv_msg_opt"]; ok { if _, ok := fieldMap["recv_msg_opt"]; ok {
cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID) cache = cache.DelConversationNotReceiveMessageUserIDs(conversation.ConversationID)
} }
} }
NotUserIDs := utils.DifferenceString(haveUserIDs, userIDs) NotUserIDs := utils.DifferenceString(haveUserIDs, userIDs)
log.ZDebug(ctx, "SetUsersConversationFiledTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs) log.ZDebug(ctx, "SetUsersConversationFieldTx", "NotUserIDs", NotUserIDs, "haveUserIDs", haveUserIDs, "userIDs", userIDs)
var conversations []*relationtb.ConversationModel var conversations []*relationtb.ConversationModel
now := time.Now() now := time.Now()
for _, v := range NotUserIDs { for _, v := range NotUserIDs {
@ -123,7 +131,7 @@ func (c *conversationDatabase) SetUsersConversationFiledTx(ctx context.Context,
}) })
} }
func (c *conversationDatabase) UpdateUsersConversationFiled(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error { func (c *conversationDatabase) UpdateUsersConversationField(ctx context.Context, userIDs []string, conversationID string, args map[string]any) error {
_, err := c.conversationDB.UpdateByMap(ctx, userIDs, conversationID, args) _, err := c.conversationDB.UpdateByMap(ctx, userIDs, conversationID, args)
if err != nil { if err != nil {
return err return err

View File

@ -16,6 +16,7 @@ package controller
import ( import (
"context" "context"
"fmt"
"time" "time"
"github.com/OpenIMSDK/tools/pagination" "github.com/OpenIMSDK/tools/pagination"
@ -89,20 +90,30 @@ func NewFriendDatabase(friend relation.FriendModelInterface, friendRequest relat
return &friendDatabase{friend: friend, friendRequest: friendRequest, cache: cache, tx: tx} return &friendDatabase{friend: friend, friendRequest: friendRequest, cache: cache, tx: tx}
} }
// ok 检查user2是否在user1的好友列表中(inUser1Friends==true) 检查user1是否在user2的好友列表中(inUser2Friends==true). // CheckIn verifies if user2 is in user1's friend list (inUser1Friends returns true) and
// if user1 is in user2's friend list (inUser2Friends returns true).
func (f *friendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) { func (f *friendDatabase) CheckIn(ctx context.Context, userID1, userID2 string) (inUser1Friends bool, inUser2Friends bool, err error) {
// Retrieve friend IDs of userID1 from the cache
userID1FriendIDs, err := f.cache.GetFriendIDs(ctx, userID1) userID1FriendIDs, err := f.cache.GetFriendIDs(ctx, userID1)
if err != nil { if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID1, err)
return return
} }
// Retrieve friend IDs of userID2 from the cache
userID2FriendIDs, err := f.cache.GetFriendIDs(ctx, userID2) userID2FriendIDs, err := f.cache.GetFriendIDs(ctx, userID2)
if err != nil { if err != nil {
err = fmt.Errorf("error retrieving friend IDs for user %s: %w", userID2, err)
return return
} }
return utils.IsContain(userID2, userID1FriendIDs), utils.IsContain(userID1, userID2FriendIDs), nil
// Check if userID2 is in userID1's friend list and vice versa
inUser1Friends = utils.IsContain(userID2, userID1FriendIDs)
inUser2Friends = utils.IsContain(userID1, userID2FriendIDs)
return inUser1Friends, inUser2Friends, nil
} }
// 增加或者更新好友申请 如果之前有记录则更新,没有记录则新增. // AddFriendRequest adds or updates a friend request.
func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) { func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUserID string, reqMsg string, ex string) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error { return f.tx.Transaction(ctx, func(ctx context.Context) error {
_, err := f.friendRequest.Take(ctx, fromUserID, toUserID) _, err := f.friendRequest.Take(ctx, fromUserID, toUserID)
@ -126,11 +137,11 @@ func (f *friendDatabase) AddFriendRequest(ctx context.Context, fromUserID, toUse
}) })
} }
// (1)先判断是否在好友表 (在不在都不返回错误) (2)对于不在好友列表的 插入即可. // (1) First determine whether it is in the friends list (in or out does not return an error) (2) for not in the friends list can be inserted.
func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error) { func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string, friendUserIDs []string, addSource int32) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error { return f.tx.Transaction(ctx, func(ctx context.Context) error {
cache := f.cache.NewCache() cache := f.cache.NewCache()
// 先find 找出重复的 去掉重复的 // User find friends
fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs) fs1, err := f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil { if err != nil {
return err return err
@ -170,26 +181,37 @@ func (f *friendDatabase) BecomeFriends(ctx context.Context, ownerUserID string,
}) })
} }
// 拒绝好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)修改申请记录 已拒绝. // RefuseFriendRequest rejects a friend request. It first checks for an existing, unprocessed request.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) { // If no such request exists, it returns an error. Otherwise, it marks the request as refused.
func (f *friendDatabase) RefuseFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) error {
// Attempt to retrieve the friend request from the database.
fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID) fr, err := f.friendRequest.Take(ctx, friendRequest.FromUserID, friendRequest.ToUserID)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to retrieve friend request from %s to %s: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
} }
// Check if the friend request has already been handled.
if fr.HandleResult != 0 { if fr.HandleResult != 0 {
return errs.ErrArgs.Wrap("the friend request has been processed") return fmt.Errorf("friend request from %s to %s has already been processed", friendRequest.FromUserID, friendRequest.ToUserID)
} }
log.ZDebug(ctx, "refuse friend request", "friendRequest db", fr, "friendRequest arg", friendRequest)
// Log the action of refusing the friend request for debugging and auditing purposes.
log.ZDebug(ctx, "Refusing friend request", map[string]interface{}{
"DB_FriendRequest": fr,
"Arg_FriendRequest": friendRequest,
})
// Mark the friend request as refused and update the handle time.
friendRequest.HandleResult = constant.FriendResponseRefuse friendRequest.HandleResult = constant.FriendResponseRefuse
friendRequest.HandleTime = time.Now() friendRequest.HandleTime = time.Now()
err = f.friendRequest.Update(ctx, friendRequest) if err := f.friendRequest.Update(ctx, friendRequest); err != nil {
if err != nil { return fmt.Errorf("failed to update friend request from %s to %s as refused: %w", friendRequest.FromUserID, friendRequest.ToUserID, err)
return err
} }
return nil return nil
} }
// AgreeFriendRequest 同意好友申请 (1)检查是否有申请记录且为未处理状态 (没有记录返回错误) (2)检查是否好友(不返回错误) (3) 建立双向好友关系(存在的忽略). // AgreeFriendRequest accepts a friend request. It first checks for an existing, unprocessed request.
func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) { func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *relation.FriendRequestModel) (err error) {
return f.tx.Transaction(ctx, func(ctx context.Context) error { return f.tx.Transaction(ctx, func(ctx context.Context) error {
defer log.ZDebug(ctx, "return line") defer log.ZDebug(ctx, "return line")
@ -227,10 +249,10 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
return err return err
} }
existsMap := utils.SliceSet(utils.Slice(exists, func(friend *relation.FriendModel) [2]string { existsMap := utils.SliceSet(utils.Slice(exists, func(friend *relation.FriendModel) [2]string {
return [...]string{friend.OwnerUserID, friend.FriendUserID} // 自己 - 好友 return [...]string{friend.OwnerUserID, friend.FriendUserID} // My - Friend
})) }))
var adds []*relation.FriendModel var adds []*relation.FriendModel
if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // 自己 - 好友 if _, ok := existsMap[[...]string{friendRequest.ToUserID, friendRequest.FromUserID}]; !ok { // My - Friend
adds = append( adds = append(
adds, adds,
&relation.FriendModel{ &relation.FriendModel{
@ -241,7 +263,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
}, },
) )
} }
if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // 好友 - 自己 if _, ok := existsMap[[...]string{friendRequest.FromUserID, friendRequest.ToUserID}]; !ok { // My - Friend
adds = append( adds = append(
adds, adds,
&relation.FriendModel{ &relation.FriendModel{
@ -261,7 +283,7 @@ func (f *friendDatabase) AgreeFriendRequest(ctx context.Context, friendRequest *
}) })
} }
// 删除好友 外部判断是否好友关系. // Delete removes a friend relationship. It is assumed that the external caller has verified the friendship status.
func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) { func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendUserIDs []string) (err error) {
if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil { if err := f.friend.Delete(ctx, ownerUserID, friendUserIDs); err != nil {
return err return err
@ -269,7 +291,7 @@ func (f *friendDatabase) Delete(ctx context.Context, ownerUserID string, friendU
return f.cache.DelFriendIDs(append(friendUserIDs, ownerUserID)...).ExecDel(ctx) return f.cache.DelFriendIDs(append(friendUserIDs, ownerUserID)...).ExecDel(ctx)
} }
// 更新好友备注 零值也支持. // UpdateRemark updates the remark for a friend. Zero value for remark is also supported.
func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) { func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUserID, remark string) (err error) {
if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil { if err := f.friend.UpdateRemark(ctx, ownerUserID, friendUserID, remark); err != nil {
return err return err
@ -277,27 +299,27 @@ func (f *friendDatabase) UpdateRemark(ctx context.Context, ownerUserID, friendUs
return f.cache.DelFriend(ownerUserID, friendUserID).ExecDel(ctx) return f.cache.DelFriend(ownerUserID, friendUserID).ExecDel(ctx)
} }
// 获取ownerUserID的好友列表 无结果不返回错误. // PageOwnerFriends retrieves the list of friends for the ownerUserID. It does not return an error if the result is empty.
func (f *friendDatabase) PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) { func (f *friendDatabase) PageOwnerFriends(ctx context.Context, ownerUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindOwnerFriends(ctx, ownerUserID, pagination) return f.friend.FindOwnerFriends(ctx, ownerUserID, pagination)
} }
// friendUserID在哪些人的好友列表中. // PageInWhoseFriends identifies in whose friend lists the friendUserID appears.
func (f *friendDatabase) PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) { func (f *friendDatabase) PageInWhoseFriends(ctx context.Context, friendUserID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendModel, err error) {
return f.friend.FindInWhoseFriends(ctx, friendUserID, pagination) return f.friend.FindInWhoseFriends(ctx, friendUserID, pagination)
} }
// 获取我发出去的好友申请 无结果不返回错误. // PageFriendRequestFromMe retrieves friend requests sent by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestFromMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) { func (f *friendDatabase) PageFriendRequestFromMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindFromUserID(ctx, userID, pagination) return f.friendRequest.FindFromUserID(ctx, userID, pagination)
} }
// 获取我收到的的好友申请 无结果不返回错误. // PageFriendRequestToMe retrieves friend requests received by me. It does not return an error if the result is empty.
func (f *friendDatabase) PageFriendRequestToMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) { func (f *friendDatabase) PageFriendRequestToMe(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, friends []*relation.FriendRequestModel, err error) {
return f.friendRequest.FindToUserID(ctx, userID, pagination) return f.friendRequest.FindToUserID(ctx, userID, pagination)
} }
// 获取某人指定好友的信息 如果有好友不存在,也返回错误. // FindFriendsWithError retrieves specified friends' information for ownerUserID. Returns an error if any friend does not exist.
func (f *friendDatabase) FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*relation.FriendModel, err error) { func (f *friendDatabase) FindFriendsWithError(ctx context.Context, ownerUserID string, friendUserIDs []string) (friends []*relation.FriendModel, err error) {
friends, err = f.friend.FindFriends(ctx, ownerUserID, friendUserIDs) friends, err = f.friend.FindFriends(ctx, ownerUserID, friendUserIDs)
if err != nil { if err != nil {

View File

@ -31,47 +31,79 @@ import (
) )
type GroupDatabase interface { type GroupDatabase interface {
// Group // CreateGroup creates new groups along with their members.
CreateGroup(ctx context.Context, groups []*relationtb.GroupModel, groupMembers []*relationtb.GroupMemberModel) error CreateGroup(ctx context.Context, groups []*relationtb.GroupModel, groupMembers []*relationtb.GroupMemberModel) error
// TakeGroup retrieves a single group by its ID.
TakeGroup(ctx context.Context, groupID string) (group *relationtb.GroupModel, err error) TakeGroup(ctx context.Context, groupID string) (group *relationtb.GroupModel, err error)
// FindGroup retrieves multiple groups by their IDs.
FindGroup(ctx context.Context, groupIDs []string) (groups []*relationtb.GroupModel, err error) FindGroup(ctx context.Context, groupIDs []string) (groups []*relationtb.GroupModel, err error)
// SearchGroup searches for groups based on a keyword and pagination settings, returns total count and groups.
SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*relationtb.GroupModel, error) SearchGroup(ctx context.Context, keyword string, pagination pagination.Pagination) (int64, []*relationtb.GroupModel, error)
// UpdateGroup updates the properties of a group identified by its ID.
UpdateGroup(ctx context.Context, groupID string, data map[string]any) error UpdateGroup(ctx context.Context, groupID string, data map[string]any) error
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error // 解散群,并删除群成员 // DismissGroup disbands a group and optionally removes its members based on the deleteMember flag.
DismissGroup(ctx context.Context, groupID string, deleteMember bool) error
// TakeGroupMember retrieves a specific group member by group ID and user ID.
TakeGroupMember(ctx context.Context, groupID string, userID string) (groupMember *relationtb.GroupMemberModel, err error) TakeGroupMember(ctx context.Context, groupID string, userID string) (groupMember *relationtb.GroupMemberModel, err error)
// TakeGroupOwner retrieves the owner of a group by group ID.
TakeGroupOwner(ctx context.Context, groupID string) (*relationtb.GroupMemberModel, error) TakeGroupOwner(ctx context.Context, groupID string) (*relationtb.GroupMemberModel, error)
FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error) // * // FindGroupMembers retrieves members of a group filtered by user IDs.
FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error) // * FindGroupMembers(ctx context.Context, groupID string, userIDs []string) (groupMembers []*relationtb.GroupMemberModel, err error)
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error) // * // FindGroupMemberUser retrieves groups that a user is a member of, filtered by group IDs.
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error) // * FindGroupMemberUser(ctx context.Context, groupIDs []string, userID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberRoleLevels retrieves group members filtered by their role levels within a group.
FindGroupMemberRoleLevels(ctx context.Context, groupID string, roleLevels []int32) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupMemberAll retrieves all members of a group.
FindGroupMemberAll(ctx context.Context, groupID string) (groupMembers []*relationtb.GroupMemberModel, err error)
// FindGroupsOwner retrieves the owners for multiple groups.
FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*relationtb.GroupMemberModel, error) FindGroupsOwner(ctx context.Context, groupIDs []string) ([]*relationtb.GroupMemberModel, error)
// FindGroupMemberUserID retrieves the user IDs of all members in a group.
FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error) FindGroupMemberUserID(ctx context.Context, groupID string) ([]string, error)
// FindGroupMemberNum retrieves the number of members in a group.
FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error) FindGroupMemberNum(ctx context.Context, groupID string) (uint32, error)
// FindUserManagedGroupID retrieves group IDs managed by a user.
FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error) FindUserManagedGroupID(ctx context.Context, userID string) (groupIDs []string, err error)
// PageGroupRequest paginates through group requests for specified groups.
PageGroupRequest(ctx context.Context, groupIDs []string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error) PageGroupRequest(ctx context.Context, groupIDs []string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// GetGroupRoleLevelMemberIDs retrieves user IDs of group members with a specific role level.
GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error) GetGroupRoleLevelMemberIDs(ctx context.Context, groupID string, roleLevel int32) ([]string, error)
// PageGetJoinGroup paginates through groups that a user has joined.
PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error) PageGetJoinGroup(ctx context.Context, userID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// PageGetGroupMember paginates through members of a group.
PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error) PageGetGroupMember(ctx context.Context, groupID string, pagination pagination.Pagination) (total int64, totalGroupMembers []*relationtb.GroupMemberModel, err error)
// SearchGroupMember searches for group members based on a keyword, group ID, and pagination settings.
SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*relationtb.GroupMemberModel, error) SearchGroupMember(ctx context.Context, keyword string, groupID string, pagination pagination.Pagination) (int64, []*relationtb.GroupMemberModel, error)
// HandlerGroupRequest processes a group join request with a specified result.
HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *relationtb.GroupMemberModel) error HandlerGroupRequest(ctx context.Context, groupID string, userID string, handledMsg string, handleResult int32, member *relationtb.GroupMemberModel) error
// DeleteGroupMember removes specified users from a group.
DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error DeleteGroupMember(ctx context.Context, groupID string, userIDs []string) error
// MapGroupMemberUserID maps group IDs to their members' simplified user IDs.
MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*relationtb.GroupSimpleUserID, error) MapGroupMemberUserID(ctx context.Context, groupIDs []string) (map[string]*relationtb.GroupSimpleUserID, error)
// MapGroupMemberNum maps group IDs to their member count.
MapGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error) MapGroupMemberNum(ctx context.Context, groupIDs []string) (map[string]uint32, error)
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error // 转让群 // TransferGroupOwner transfers the ownership of a group to another user.
TransferGroupOwner(ctx context.Context, groupID string, oldOwnerUserID, newOwnerUserID string, roleLevel int32) error
// UpdateGroupMember updates properties of a group member.
UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error UpdateGroupMember(ctx context.Context, groupID string, userID string, data map[string]any) error
// UpdateGroupMembers batch updates properties of group members.
UpdateGroupMembers(ctx context.Context, data []*relationtb.BatchUpdateGroupMember) error UpdateGroupMembers(ctx context.Context, data []*relationtb.BatchUpdateGroupMember) error
// GroupRequest
// CreateGroupRequest creates new group join requests.
CreateGroupRequest(ctx context.Context, requests []*relationtb.GroupRequestModel) error CreateGroupRequest(ctx context.Context, requests []*relationtb.GroupRequestModel) error
// TakeGroupRequest retrieves a specific group join request.
TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relationtb.GroupRequestModel, error) TakeGroupRequest(ctx context.Context, groupID string, userID string) (*relationtb.GroupRequestModel, error)
// FindGroupRequests retrieves multiple group join requests.
FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*relationtb.GroupRequestModel, error) FindGroupRequests(ctx context.Context, groupID string, userIDs []string) ([]*relationtb.GroupRequestModel, error)
// PageGroupRequestUser paginates through group join requests made by a user.
PageGroupRequestUser(ctx context.Context, userID string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error) PageGroupRequestUser(ctx context.Context, userID string, pagination pagination.Pagination) (int64, []*relationtb.GroupRequestModel, error)
// 获取群总数 // CountTotal counts the total number of groups as of a certain date.
CountTotal(ctx context.Context, before *time.Time) (count int64, err error) CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量 // CountRangeEverydayTotal counts the daily group creation total within a specified date range.
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
// DeleteGroupMemberHash deletes the hash entries for group members in specified groups.
DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error DeleteGroupMemberHash(ctx context.Context, groupIDs []string) error
} }

View File

@ -48,33 +48,32 @@ const (
updateKeyRevoke updateKeyRevoke
) )
// CommonMsgDatabase defines the interface for message database operations.
type CommonMsgDatabase interface { type CommonMsgDatabase interface {
// 批量插入消息 // BatchInsertChat2DB inserts a batch of messages into the database for a specific conversation.
BatchInsertChat2DB(ctx context.Context, conversationID string, msgs []*sdkws.MsgData, currentMaxSeq int64) error BatchInsertChat2DB(ctx context.Context, conversationID string, msgs []*sdkws.MsgData, currentMaxSeq int64) error
// 撤回消息 // RevokeMsg revokes a message in a conversation.
RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *unrelationtb.RevokeModel) error RevokeMsg(ctx context.Context, conversationID string, seq int64, revoke *unrelationtb.RevokeModel) error
// mark as read // MarkSingleChatMsgsAsRead marks messages as read for a single chat by sequence numbers.
MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error MarkSingleChatMsgsAsRead(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 刪除redis中消息缓存 // DeleteMessagesFromCache deletes message caches from Redis by sequence numbers.
DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error DeleteMessagesFromCache(ctx context.Context, conversationID string, seqs []int64) error
// DelUserDeleteMsgsList deletes user's message deletion list.
DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64) DelUserDeleteMsgsList(ctx context.Context, conversationID string, seqs []int64)
// incrSeq然后批量插入缓存 // BatchInsertChat2Cache increments the sequence number and then batch inserts messages into the cache.
BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, err error) BatchInsertChat2Cache(ctx context.Context, conversationID string, msgs []*sdkws.MsgData) (seq int64, isNewConversation bool, err error)
// GetMsgBySeqsRange retrieves messages from MongoDB by a range of sequence numbers.
// 通过seqList获取mongo中写扩散消息
GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error) GetMsgBySeqsRange(ctx context.Context, userID string, conversationID string, begin, end, num, userMaxSeq int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 通过seqList获取大群在 mongo里面的消息 // GetMsgBySeqs retrieves messages for large groups from MongoDB by sequence numbers.
GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error) GetMsgBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) (minSeq int64, maxSeq int64, seqMsg []*sdkws.MsgData, err error)
// 删除会话消息重置最小seq remainTime为消息保留的时间单位秒,超时消息删除, 传0删除所有消息(此方法不删除redis cache) // DeleteConversationMsgsAndSetMinSeq deletes conversation messages and resets the minimum sequence number. If `remainTime` is 0, all messages are deleted (this method does not delete Redis cache).
DeleteConversationMsgsAndSetMinSeq(ctx context.Context, conversationID string, remainTime int64) error DeleteConversationMsgsAndSetMinSeq(ctx context.Context, conversationID string, remainTime int64) error
// 用户标记删除过期消息返回标记删除的seq列表 // UserMsgsDestruct marks messages for deletion based on destruct time and returns a list of sequence numbers for marked messages.
UserMsgsDestruct(ctx context.Context, userID string, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error) UserMsgsDestruct(ctx context.Context, userID string, conversationID string, destructTime int64, lastMsgDestructTime time.Time) (seqs []int64, err error)
// DeleteUserMsgsBySeqs allows a user to delete messages based on sequence numbers.
// 用户根据seq删除消息
DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error DeleteUserMsgsBySeqs(ctx context.Context, userID string, conversationID string, seqs []int64) error
// 物理删除消息置空 // DeleteMsgsPhysicalBySeqs physically deletes messages by emptying them based on sequence numbers.
DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, seqs []int64) error DeleteMsgsPhysicalBySeqs(ctx context.Context, conversationID string, seqs []int64) error
SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error SetMaxSeq(ctx context.Context, conversationID string, maxSeq int64) error
GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error)
GetMaxSeq(ctx context.Context, conversationID string) (int64, error) GetMaxSeq(ctx context.Context, conversationID string) (int64, error)
@ -200,7 +199,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
} }
num := db.msg.GetSingleGocMsgNum() num := db.msg.GetSingleGocMsgNum()
// num = 100 // num = 100
for i, field := range fields { // 检查类型 for i, field := range fields { // Check the type of the field
var ok bool var ok bool
switch key { switch key {
case updateKeyMsg: case updateKeyMsg:
@ -218,7 +217,7 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
return errs.ErrInternalServer.Wrap("field type is invalid") return errs.ErrInternalServer.Wrap("field type is invalid")
} }
} }
// 返回值为true表示数据库存在该文档false表示数据库不存在该文档 // Returns true if the document exists in the database, false if the document does not exist in the database
updateMsgModel := func(seq int64, i int) (bool, error) { updateMsgModel := func(seq int64, i int) (bool, error) {
var ( var (
res *mongo.UpdateResult res *mongo.UpdateResult
@ -240,21 +239,21 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
} }
tryUpdate := true tryUpdate := true
for i := 0; i < len(fields); i++ { for i := 0; i < len(fields); i++ {
seq := firstSeq + int64(i) // 当前seq seq := firstSeq + int64(i) // Current sequence number
if tryUpdate { if tryUpdate {
matched, err := updateMsgModel(seq, i) matched, err := updateMsgModel(seq, i)
if err != nil { if err != nil {
return err return err
} }
if matched { if matched {
continue // 匹配到了,继续下一个(不一定修改) continue // The current data has been updated, skip the current data
} }
} }
doc := unrelationtb.MsgDocModel{ doc := unrelationtb.MsgDocModel{
DocID: db.msg.GetDocID(conversationID, seq), DocID: db.msg.GetDocID(conversationID, seq),
Msg: make([]*unrelationtb.MsgInfoModel, num), Msg: make([]*unrelationtb.MsgInfoModel, num),
} }
var insert int // 插入的数量 var insert int // Inserted data number
for j := i; j < len(fields); j++ { for j := i; j < len(fields); j++ {
seq = firstSeq + int64(j) seq = firstSeq + int64(j)
if db.msg.GetDocID(conversationID, seq) != doc.DocID { if db.msg.GetDocID(conversationID, seq) != doc.DocID {
@ -283,14 +282,14 @@ func (db *commonMsgDatabase) BatchInsertBlock(ctx context.Context, conversationI
} }
if err := db.msgDocDatabase.Create(ctx, &doc); err != nil { if err := db.msgDocDatabase.Create(ctx, &doc); err != nil {
if mongo.IsDuplicateKeyError(err) { if mongo.IsDuplicateKeyError(err) {
i-- // 存在并发,重试当前数据 i-- // already inserted
tryUpdate = true // 以修改模式 tryUpdate = true // next block use update mode
continue continue
} }
return err return err
} }
tryUpdate = false // 当前以插入成功,下一块优先插入模式 tryUpdate = false // The current block is inserted successfully, and the next block is inserted preferentially
i += insert - 1 // 跳过已插入的数据 i += insert - 1 // Skip the inserted data
} }
return nil return nil
} }
@ -754,7 +753,7 @@ func (db *commonMsgDatabase) UserMsgsDestruct(ctx context.Context, userID string
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index) log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
} }
} }
// 获取报错或者获取不到了物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归 // If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
break break
} }
index++ index++
@ -809,7 +808,7 @@ func (d *delMsgRecursionStruct) getSetMinSeq() int64 {
// index 0....19(del) 20...69 // index 0....19(del) 20...69
// seq 70 // seq 70
// set minSeq 21 // set minSeq 21
// recursion 删除list并且返回设置的最小seq. // recursion deletes the list and returns the set minimum seq.
func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversationID string, index int64, delStruct *delMsgRecursionStruct, remainTime int64) (int64, error) { func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversationID string, index int64, delStruct *delMsgRecursionStruct, remainTime int64) (int64, error) {
// find from oldest list // find from oldest list
msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1) msgDocModel, err := db.msgDocDatabase.GetMsgDocModelByIndex(ctx, conversationID, index, 1)
@ -821,7 +820,7 @@ func (db *commonMsgDatabase) deleteMsgRecursion(ctx context.Context, conversatio
log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index) log.ZError(ctx, "deleteMsgRecursion GetUserMsgListByIndex failed", err, "conversationID", conversationID, "index", index)
} }
} }
// 获取报错或者获取不到了物理删除并且返回seq delMongoMsgsPhysical(delStruct.delDocIDList), 结束递归 // If an error is reported, or the error cannot be obtained, it is physically deleted and seq delMongoMsgsPhysical(delStruct.delDocIDList) is returned to end the recursion
err = db.msgDocDatabase.DeleteDocs(ctx, delStruct.delDocIDs) err = db.msgDocDatabase.DeleteDocs(ctx, delStruct.delDocIDs)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@ -63,13 +63,7 @@ func NewThirdDatabase(cache cache.MsgModel, logdb relation.LogInterface) ThirdDa
return &thirdDatabase{cache: cache, logdb: logdb} return &thirdDatabase{cache: cache, logdb: logdb}
} }
func (t *thirdDatabase) FcmUpdateToken( func (t *thirdDatabase) FcmUpdateToken(ctx context.Context, account string, platformID int, fcmToken string, expireTime int64) error {
ctx context.Context,
account string,
platformID int,
fcmToken string,
expireTime int64,
) error {
return t.cache.SetFcmToken(ctx, account, platformID, fcmToken, expireTime) return t.cache.SetFcmToken(ctx, account, platformID, fcmToken, expireTime)
} }

View File

@ -58,10 +58,10 @@ func NewAWS() (s3.Interface, error) {
credential := credentials.NewStaticCredentials( credential := credentials.NewStaticCredentials(
conf.AccessKeyID, // accessKey conf.AccessKeyID, // accessKey
conf.AccessKeySecret, // secretKey conf.AccessKeySecret, // secretKey
"") // sts的临时凭证 "") // stoken
sess, err := session.NewSession(&aws.Config{ sess, err := session.NewSession(&aws.Config{
Region: aws.String(conf.Region), // 桶所在的区域 Region: aws.String(conf.Region), // The area where the bucket is located
Credentials: credential, Credentials: credential,
}) })

View File

@ -15,10 +15,24 @@
package cont package cont
const ( const (
// hashPath defines the storage path for hash data within the 'openim' directory.
hashPath = "openim/data/hash/" hashPath = "openim/data/hash/"
// tempPath specifies the directory for temporary files in the 'openim' structure.
tempPath = "openim/temp/" tempPath = "openim/temp/"
// DirectPath indicates the directory for direct uploads or access within the 'openim' structure.
DirectPath = "openim/direct" DirectPath = "openim/direct"
UploadTypeMultipart = 1 // 分片上传
UploadTypePresigned = 2 // 预签名上传 // UploadTypeMultipart represents the identifier for multipart uploads,
// allowing large files to be uploaded in chunks.
UploadTypeMultipart = 1
// UploadTypePresigned signifies the use of presigned URLs for uploads,
// facilitating secure, authorized file transfers without requiring direct access to the storage credentials.
UploadTypePresigned = 2
// partSeparator is used as a delimiter in multipart upload processes,
// separating individual file parts.
partSeparator = "," partSeparator = ","
) )

View File

@ -114,7 +114,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
return nil, err return nil, err
} }
if size <= partSize { if size <= partSize {
// 预签名上传 // Pre-signed upload
key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID())) key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID()))
rawURL, err := c.impl.PresignedPutObject(ctx, key, expire) rawURL, err := c.impl.PresignedPutObject(ctx, key, expire)
if err != nil { if err != nil {
@ -139,7 +139,7 @@ func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64
}, },
}, nil }, nil
} else { } else {
// 分片上传 // Fragment upload
upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash)) upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash))
if err != nil { if err != nil {
return nil, err return nil, err
@ -206,7 +206,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
ETag: part, ETag: part,
} }
} }
// todo: 验证大小 // todo: Validation size
result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts) result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts)
if err != nil { if err != nil {
return nil, err return nil, err
@ -225,7 +225,7 @@ func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHa
if md5val := hex.EncodeToString(md5Sum[:]); md5val != upload.Hash { if md5val := hex.EncodeToString(md5Sum[:]); md5val != upload.Hash {
return nil, errs.ErrArgs.Wrap(fmt.Sprintf("md5 mismatching %s != %s", md5val, upload.Hash)) return nil, errs.ErrArgs.Wrap(fmt.Sprintf("md5 mismatching %s != %s", md5val, upload.Hash))
} }
// 防止在这个时候,并发操作,导致文件被覆盖 // Prevents concurrent operations at this time that cause files to be overwritten
copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID()) copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID())
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -17,9 +17,14 @@ package cont
import "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3" import "github.com/openimsdk/open-im-server/v3/pkg/common/db/s3"
type InitiateUploadResult struct { type InitiateUploadResult struct {
UploadID string `json:"uploadID"` // 上传ID // UploadID uniquely identifies the upload session for tracking and management purposes.
PartSize int64 `json:"partSize"` // 分片大小 UploadID string `json:"uploadID"`
Sign *s3.AuthSignResult `json:"sign"` // 分片信息
// PartSize specifies the size of each part in a multipart upload. This is relevant for breaking down large uploads into manageable pieces.
PartSize int64 `json:"partSize"`
// Sign contains the authentication and signature information necessary for securely uploading each part. This could include signed URLs or tokens.
Sign *s3.AuthSignResult `json:"sign"`
} }
type UploadResult struct { type UploadResult struct {

View File

@ -42,35 +42,38 @@ func ImageWidthHeight(img image.Image) (int, int) {
return bounds.X, bounds.Y return bounds.X, bounds.Y
} }
// resizeImage resizes an image to a specified maximum width and height, maintaining the aspect ratio.
// If both maxWidth and maxHeight are set to 0, the original image is returned.
// If both are non-zero, the image is scaled to fit within the constraints while maintaining aspect ratio.
// If only one of maxWidth or maxHeight is non-zero, the image is scaled accordingly.
func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image { func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
bounds := img.Bounds() bounds := img.Bounds()
imgWidth := bounds.Max.X imgWidth, imgHeight := bounds.Dx(), bounds.Dy()
imgHeight := bounds.Max.Y
// 计算缩放比例 // Return original image if no resizing is needed.
scaleWidth := float64(maxWidth) / float64(imgWidth)
scaleHeight := float64(maxHeight) / float64(imgHeight)
// 如果都为0则不缩放返回原始图片
if maxWidth == 0 && maxHeight == 0 { if maxWidth == 0 && maxHeight == 0 {
return img return img
} }
// 如果宽度和高度都大于0则选择较小的缩放比例以保持宽高比 var scale float64 = 1
if maxWidth > 0 && maxHeight > 0 { if maxWidth > 0 && maxHeight > 0 {
scale := scaleWidth scaleWidth := float64(maxWidth) / float64(imgWidth)
if scaleHeight < scaleWidth { scaleHeight := float64(maxHeight) / float64(imgHeight)
scale = scaleHeight // Choose the smaller scale to fit both constraints.
scale = min(scaleWidth, scaleHeight)
} else if maxWidth > 0 {
scale = float64(maxWidth) / float64(imgWidth)
} else if maxHeight > 0 {
scale = float64(maxHeight) / float64(imgHeight)
} }
// 计算缩略图尺寸 newWidth := int(float64(imgWidth) * scale)
thumbnailWidth := int(float64(imgWidth) * scale) newHeight := int(float64(imgHeight) * scale)
thumbnailHeight := int(float64(imgHeight) * scale)
// 使用"image"库的Resample方法生成缩略图 // Resize the image by creating a new image and manually copying pixels.
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight)) thumbnail := image.NewRGBA(image.Rect(0, 0, newWidth, newHeight))
for y := 0; y < thumbnailHeight; y++ { for y := 0; y < newHeight; y++ {
for x := 0; x < thumbnailWidth; x++ { for x := 0; x < newWidth; x++ {
srcX := int(float64(x) / scale) srcX := int(float64(x) / scale)
srcY := int(float64(y) / scale) srcY := int(float64(y) / scale)
thumbnail.Set(x, y, img.At(srcX, srcY)) thumbnail.Set(x, y, img.At(srcX, srcY))
@ -78,43 +81,12 @@ func resizeImage(img image.Image, maxWidth, maxHeight int) image.Image {
} }
return thumbnail return thumbnail
} }
// 如果只指定了宽度或高度,则根据最大不超过的规则生成缩略图 // min returns the smaller of x or y.
if maxWidth > 0 { func min(x, y float64) float64 {
thumbnailWidth := maxWidth if x < y {
thumbnailHeight := int(float64(imgHeight) * scaleWidth) return x
}
// 使用"image"库的Resample方法生成缩略图 return y
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
srcX := int(float64(x) / scaleWidth)
srcY := int(float64(y) / scaleWidth)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
}
return thumbnail
}
if maxHeight > 0 {
thumbnailWidth := int(float64(imgWidth) * scaleHeight)
thumbnailHeight := maxHeight
// 使用"image"库的Resample方法生成缩略图
thumbnail := image.NewRGBA(image.Rect(0, 0, thumbnailWidth, thumbnailHeight))
for y := 0; y < thumbnailHeight; y++ {
for x := 0; x < thumbnailWidth; x++ {
srcX := int(float64(x) / scaleHeight)
srcY := int(float64(y) / scaleHeight)
thumbnail.Set(x, y, img.At(srcX, srcY))
}
}
return thumbnail
}
// 默认情况下,返回原始图片
return img
} }

View File

@ -30,6 +30,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/OpenIMSDK/tools/errs"
"github.com/aliyun/aliyun-oss-go-sdk/oss" "github.com/aliyun/aliyun-oss-go-sdk/oss"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
@ -60,7 +61,7 @@ const successCode = http.StatusOK
func NewOSS() (s3.Interface, error) { func NewOSS() (s3.Interface, error) {
conf := config.Config.Object.Oss conf := config.Config.Object.Oss
if conf.BucketURL == "" { if conf.BucketURL == "" {
return nil, errors.New("bucket url is empty") return nil, errs.Wrap(errors.New("bucket url is empty"))
} }
client, err := oss.New(conf.Endpoint, conf.AccessKeyID, conf.AccessKeySecret) client, err := oss.New(conf.Endpoint, conf.AccessKeyID, conf.AccessKeySecret)
if err != nil { if err != nil {
@ -68,7 +69,7 @@ func NewOSS() (s3.Interface, error) {
} }
bucket, err := client.Bucket(conf.Bucket) bucket, err := client.Bucket(conf.Bucket)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "ali-oss bucket error")
} }
if conf.BucketURL[len(conf.BucketURL)-1] != '/' { if conf.BucketURL[len(conf.BucketURL)-1] != '/' {
conf.BucketURL += "/" conf.BucketURL += "/"
@ -138,10 +139,10 @@ func (o *OSS) CompleteMultipartUpload(ctx context.Context, uploadID string, name
func (o *OSS) PartSize(ctx context.Context, size int64) (int64, error) { func (o *OSS) PartSize(ctx context.Context, size int64) (int64, error) {
if size <= 0 { if size <= 0 {
return 0, errors.New("size must be greater than 0") return 0, errs.Wrap(errors.New("size must be greater than 0"))
} }
if size > maxPartSize*maxNumSize { if size > maxPartSize*maxNumSize {
return 0, fmt.Errorf("OSS size must be less than the maximum allowed limit") return 0, errs.Wrap(errors.New("size must be less than the maximum allowed limit"))
} }
if size <= minPartSize*maxNumSize { if size <= minPartSize*maxNumSize {
return minPartSize, nil return minPartSize, nil
@ -196,25 +197,25 @@ func (o *OSS) StatObject(ctx context.Context, name string) (*s3.ObjectInfo, erro
} }
res := &s3.ObjectInfo{Key: name} res := &s3.ObjectInfo{Key: name}
if res.ETag = strings.ToLower(strings.ReplaceAll(header.Get("ETag"), `"`, ``)); res.ETag == "" { if res.ETag = strings.ToLower(strings.ReplaceAll(header.Get("ETag"), `"`, ``)); res.ETag == "" {
return nil, errors.New("StatObject etag not found") return nil, errs.Wrap(errors.New("StatObject etag not found"))
} }
if contentLengthStr := header.Get("Content-Length"); contentLengthStr == "" { if contentLengthStr := header.Get("Content-Length"); contentLengthStr == "" {
return nil, errors.New("StatObject content-length not found") return nil, errors.New("StatObject content-length not found")
} else { } else {
res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64) res.Size, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil { if err != nil {
return nil, fmt.Errorf("StatObject content-length parse error: %w", err) return nil, errs.Wrap(err, "StatObject content-length parse error")
} }
if res.Size < 0 { if res.Size < 0 {
return nil, errors.New("StatObject content-length must be greater than 0") return nil, errs.Wrap(errors.New("StatObject content-length must be greater than 0"))
} }
} }
if lastModified := header.Get("Last-Modified"); lastModified == "" { if lastModified := header.Get("Last-Modified"); lastModified == "" {
return nil, errors.New("StatObject last-modified not found") return nil, errs.Wrap(errors.New("StatObject last-modified not found"))
} else { } else {
res.LastModified, err = time.Parse(http.TimeFormat, lastModified) res.LastModified, err = time.Parse(http.TimeFormat, lastModified)
if err != nil { if err != nil {
return nil, fmt.Errorf("StatObject last-modified parse error: %w", err) return nil, errs.Wrap(err, "StatObject last-modified parse error")
} }
} }
return res, nil return res, nil
@ -227,7 +228,7 @@ func (o *OSS) DeleteObject(ctx context.Context, name string) error {
func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) { func (o *OSS) CopyObject(ctx context.Context, src string, dst string) (*s3.CopyObjectInfo, error) {
result, err := o.bucket.CopyObject(src, dst) result, err := o.bucket.CopyObject(src, dst)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "CopyObject error")
} }
return &s3.CopyObjectInfo{ return &s3.CopyObjectInfo{
Key: dst, Key: dst,
@ -261,7 +262,7 @@ func (o *OSS) ListUploadedParts(ctx context.Context, uploadID string, name strin
Bucket: o.bucket.BucketName, Bucket: o.bucket.BucketName,
}, oss.MaxUploads(100), oss.MaxParts(maxParts), oss.PartNumberMarker(partNumberMarker)) }, oss.MaxUploads(100), oss.MaxParts(maxParts), oss.PartNumberMarker(partNumberMarker))
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "ListUploadedParts error")
} }
res := &s3.ListUploadedPartsResult{ res := &s3.ListUploadedPartsResult{
Key: result.Key, Key: result.Key,
@ -286,7 +287,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
var opts []oss.Option var opts []oss.Option
if opt != nil { if opt != nil {
if opt.Image != nil { if opt.Image != nil {
// 文档地址: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji // Docs Address: https://help.aliyun.com/zh/oss/user-guide/resize-images-4?spm=a2c4g.11186623.0.0.4b3b1e4fWW6yji
var format string var format string
switch opt.Image.Format { switch opt.Image.Format {
case case
@ -329,7 +330,7 @@ func (o *OSS) AccessURL(ctx context.Context, name string, expire time.Duration,
} }
rawParams, err := oss.GetRawParams(opts) rawParams, err := oss.GetRawParams(opts)
if err != nil { if err != nil {
return "", err return "", errs.Wrap(err, "AccessURL error")
} }
params := getURLParams(*o.bucket.Client.Conn, rawParams) params := getURLParams(*o.bucket.Client.Conn, rawParams)
return getURL(o.um, o.bucket.BucketName, name, params).String(), nil return getURL(o.um, o.bucket.BucketName, name, params).String(), nil
@ -351,12 +352,12 @@ func (o *OSS) FormData(ctx context.Context, name string, size int64, contentType
} }
policyJson, err := json.Marshal(policy) policyJson, err := json.Marshal(policy)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "Marshal json error")
} }
policyStr := base64.StdEncoding.EncodeToString(policyJson) policyStr := base64.StdEncoding.EncodeToString(policyJson)
h := hmac.New(sha1.New, []byte(o.credentials.GetAccessKeySecret())) h := hmac.New(sha1.New, []byte(o.credentials.GetAccessKeySecret()))
if _, err := io.WriteString(h, policyStr); err != nil { if _, err := io.WriteString(h, policyStr); err != nil {
return nil, err return nil, errs.Wrap(err, "WriteString error")
} }
fd := &s3.FormData{ fd := &s3.FormData{
URL: o.bucketURL, URL: o.bucketURL,

View File

@ -46,8 +46,8 @@ type GroupModelInterface interface {
Find(ctx context.Context, groupIDs []string) (groups []*GroupModel, err error) Find(ctx context.Context, groupIDs []string) (groups []*GroupModel, err error)
Take(ctx context.Context, groupID string) (group *GroupModel, err error) Take(ctx context.Context, groupID string) (group *GroupModel, err error)
Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*GroupModel, err error) Search(ctx context.Context, keyword string, pagination pagination.Pagination) (total int64, groups []*GroupModel, err error)
// 获取群总数 // Get Group total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error) CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内群增量 // Get Group total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
} }

View File

@ -62,9 +62,9 @@ type UserModelInterface interface {
Exist(ctx context.Context, userID string) (exist bool, err error) Exist(ctx context.Context, userID string) (exist bool, err error)
GetAllUserID(ctx context.Context, pagination pagination.Pagination) (count int64, userIDs []string, err error) GetAllUserID(ctx context.Context, pagination pagination.Pagination) (count int64, userIDs []string, err error)
GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error) GetUserGlobalRecvMsgOpt(ctx context.Context, userID string) (opt int, err error)
// 获取用户总数 // Get user total quantity
CountTotal(ctx context.Context, before *time.Time) (count int64, err error) CountTotal(ctx context.Context, before *time.Time) (count int64, err error)
// 获取范围内用户增量 // Get user total quantity every day
CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error) CountRangeEverydayTotal(ctx context.Context, start time.Time, end time.Time) (map[string]int64, error)
//CRUD user command //CRUD user command
AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error AddUserCommand(ctx context.Context, userID string, Type int32, UUID string, value string, ex string) error

View File

@ -21,8 +21,6 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/OpenIMSDK/tools/log"
"github.com/OpenIMSDK/protocol/msg" "github.com/OpenIMSDK/protocol/msg"
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
@ -35,6 +33,7 @@ import (
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/errs"
"github.com/OpenIMSDK/tools/log"
table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation" table "github.com/openimsdk/open-im-server/v3/pkg/common/db/table/unrelation"
) )
@ -122,13 +121,7 @@ func (m *MsgMongoDriver) UpdateMsgContent(ctx context.Context, docID string, ind
return nil return nil
} }
func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc( func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(ctx context.Context, docID string, msg *sdkws.MsgData, seqIndex int, status int32) error {
ctx context.Context,
docID string,
msg *sdkws.MsgData,
seqIndex int,
status int32,
) error {
msg.Status = status msg.Status = status
bytes, err := proto.Marshal(msg) bytes, err := proto.Marshal(msg)
if err != nil { if err != nil {
@ -140,7 +133,7 @@ func (m *MsgMongoDriver) UpdateMsgStatusByIndexInOneDoc(
bson.M{"$set": bson.M{fmt.Sprintf("msgs.%d.msg", seqIndex): bytes}}, bson.M{"$set": bson.M{fmt.Sprintf("msgs.%d.msg", seqIndex): bytes}},
) )
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err, fmt.Sprintf("docID is %s, seqIndex is %d", docID, seqIndex))
} }
return nil return nil
} }
@ -166,7 +159,7 @@ func (m *MsgMongoDriver) GetMsgDocModelByIndex(
findOpts, findOpts,
) )
if err != nil { if err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, fmt.Sprintf("conversationID is %s", conversationID))
} }
var msgs []table.MsgDocModel var msgs []table.MsgDocModel
err = cursor.All(ctx, &msgs) err = cursor.All(ctx, &msgs)
@ -222,7 +215,7 @@ func (m *MsgMongoDriver) DeleteMsgsInOneDocByIndex(ctx context.Context, docID st
} }
_, err := m.MsgCollection.UpdateMany(ctx, bson.M{"doc_id": docID}, updates) _, err := m.MsgCollection.UpdateMany(ctx, bson.M{"doc_id": docID}, updates)
if err != nil { if err != nil {
return errs.Wrap(err) return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
} }
return nil return nil
} }
@ -289,7 +282,7 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
defer cur.Close(ctx) defer cur.Close(ctx)
var msgDocModel []table.MsgDocModel var msgDocModel []table.MsgDocModel
if err := cur.All(ctx, &msgDocModel); err != nil { if err := cur.All(ctx, &msgDocModel); err != nil {
return nil, errs.Wrap(err) return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
} }
if len(msgDocModel) == 0 { if len(msgDocModel) == 0 {
return nil, errs.Wrap(mongo.ErrNoDocuments) return nil, errs.Wrap(mongo.ErrNoDocuments)
@ -316,14 +309,14 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
} }
data, err := json.Marshal(&revokeContent) data, err := json.Marshal(&revokeContent)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
} }
elem := sdkws.NotificationElem{ elem := sdkws.NotificationElem{
Detail: string(data), Detail: string(data),
} }
content, err := json.Marshal(&elem) content, err := json.Marshal(&elem)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, fmt.Sprintf("docID is %s, seqs is %v", docID, seqs))
} }
msg.Msg.ContentType = constant.MsgRevokeNotification msg.Msg.ContentType = constant.MsgRevokeNotification
msg.Msg.Content = string(content) msg.Msg.Content = string(content)
@ -336,17 +329,12 @@ func (m *MsgMongoDriver) GetMsgBySeqIndexIn1Doc(
func (m *MsgMongoDriver) IsExistDocID(ctx context.Context, docID string) (bool, error) { func (m *MsgMongoDriver) IsExistDocID(ctx context.Context, docID string) (bool, error) {
count, err := m.MsgCollection.CountDocuments(ctx, bson.M{"doc_id": docID}) count, err := m.MsgCollection.CountDocuments(ctx, bson.M{"doc_id": docID})
if err != nil { if err != nil {
return false, errs.Wrap(err) return false, errs.Wrap(err, fmt.Sprintf("docID is %s", docID))
} }
return count > 0, nil return count > 0, nil
} }
func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead( func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(ctx context.Context, userID string, docID string, indexes []int64) error {
ctx context.Context,
userID string,
docID string,
indexes []int64,
) error {
updates := []mongo.WriteModel{} updates := []mongo.WriteModel{}
for _, index := range indexes { for _, index := range indexes {
filter := bson.M{ filter := bson.M{
@ -366,7 +354,7 @@ func (m *MsgMongoDriver) MarkSingleChatMsgsAsRead(
updates = append(updates, updateModel) updates = append(updates, updateModel)
} }
_, err := m.MsgCollection.BulkWrite(ctx, updates) _, err := m.MsgCollection.BulkWrite(ctx, updates)
return err return errs.Wrap(err, fmt.Sprintf("docID is %s, indexes is %v", docID, indexes))
} }
// RangeUserSendCount // RangeUserSendCount
@ -662,7 +650,7 @@ func (m *MsgMongoDriver) RangeUserSendCount(
"$dateToString": bson.M{ "$dateToString": bson.M{
"format": "%Y-%m-%d", "format": "%Y-%m-%d",
"date": bson.M{ "date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳 "$toDate": "$$item.msg.send_time", // Millisecond timestamp
}, },
}, },
}, },
@ -911,7 +899,7 @@ func (m *MsgMongoDriver) RangeGroupSendCount(
"$dateToString": bson.M{ "$dateToString": bson.M{
"format": "%Y-%m-%d", "format": "%Y-%m-%d",
"date": bson.M{ "date": bson.M{
"$toDate": "$$item.msg.send_time", // 毫秒时间戳 "$toDate": "$$item.msg.send_time", // Millisecond timestamp
}, },
}, },
}, },
@ -1076,6 +1064,7 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var pipe mongo.Pipeline var pipe mongo.Pipeline
condition := bson.A{} condition := bson.A{}
if req.SendTime != "" { if req.SendTime != "" {
// Changed to keyed fields for bson.M to avoid govet errors
condition = append(condition, bson.M{"$eq": bson.A{bson.M{"$dateToString": bson.M{"format": "%Y-%m-%d", "date": bson.M{"$toDate": "$$item.msg.send_time"}}}, req.SendTime}}) condition = append(condition, bson.M{"$eq": bson.A{bson.M{"$dateToString": bson.M{"format": "%Y-%m-%d", "date": bson.M{"$toDate": "$$item.msg.send_time"}}}, req.SendTime}})
} }
if req.MsgType != 0 { if req.MsgType != 0 {
@ -1092,62 +1081,26 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
} }
or := bson.A{ or := bson.A{
bson.M{ bson.M{"doc_id": bson.M{"$regex": "^si_", "$options": "i"}},
"doc_id": bson.M{ bson.M{"doc_id": bson.M{"$regex": "^g_", "$options": "i"}},
"$regex": "^si_", bson.M{"doc_id": bson.M{"$regex": "^sg_", "$options": "i"}},
"$options": "i",
},
},
} }
or = append(or,
bson.M{
"doc_id": bson.M{
"$regex": "^g_",
"$options": "i",
},
},
bson.M{
"doc_id": bson.M{
"$regex": "^sg_",
"$options": "i",
},
},
)
// Use bson.D with keyed fields to specify the order explicitly
pipe = mongo.Pipeline{ pipe = mongo.Pipeline{
{ {{"$match", bson.D{{Key: "$or", Value: or}}}},
{"$match", bson.D{ {{"$project", bson.D{
{ {Key: "msgs", Value: bson.D{
"$or", or, {Key: "$filter", Value: bson.D{
}, {Key: "input", Value: "$msgs"},
{Key: "as", Value: "item"},
{Key: "cond", Value: bson.D{{Key: "$and", Value: condition}}},
}}, }},
},
{
{"$project", bson.D{
{
"msgs", bson.D{
{
"$filter", bson.D{
{"input", "$msgs"},
{"as", "item"},
{
"cond", bson.D{
{"$and", condition},
},
},
},
},
},
},
{"doc_id", 1},
}}, }},
}, {Key: "doc_id", Value: 1},
{ }}},
{"$unwind", bson.M{"path": "$msgs"}}, {{"$unwind", bson.M{"path": "$msgs"}}},
}, {{"$sort", bson.M{"msgs.msg.send_time": -1}}},
{
{"$sort", bson.M{"msgs.msg.send_time": -1}},
},
} }
cursor, err := m.MsgCollection.Aggregate(ctx, pipe) cursor, err := m.MsgCollection.Aggregate(ctx, pipe)
if err != nil { if err != nil {
@ -1160,12 +1113,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
var msgsDocs []docModel var msgsDocs []docModel
err = cursor.All(ctx, &msgsDocs) err = cursor.All(ctx, &msgsDocs)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, errs.Wrap(err, "cursor.All msgsDocs")
} }
log.ZDebug(ctx, "query mongoDB", "result", msgsDocs) log.ZDebug(ctx, "query mongoDB", "result", msgsDocs)
msgs := make([]*table.MsgInfoModel, 0) msgs := make([]*table.MsgInfoModel, 0)
for index := range msgsDocs { for _, doc := range msgsDocs {
msgInfo := msgsDocs[index].Msg msgInfo := doc.Msg
if msgInfo == nil || msgInfo.Msg == nil { if msgInfo == nil || msgInfo.Msg == nil {
continue continue
} }
@ -1185,14 +1138,12 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
} }
data, err := json.Marshal(&revokeContent) data, err := json.Marshal(&revokeContent)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, errs.Wrap(err, "json.Marshal revokeContent")
}
elem := sdkws.NotificationElem{
Detail: string(data),
} }
elem := sdkws.NotificationElem{Detail: string(data)}
content, err := json.Marshal(&elem) content, err := json.Marshal(&elem)
if err != nil { if err != nil {
return 0, nil, err return 0, nil, errs.Wrap(err, "json.Marshal elem")
} }
msgInfo.Msg.ContentType = constant.MsgRevokeNotification msgInfo.Msg.ContentType = constant.MsgRevokeNotification
msgInfo.Msg.Content = string(content) msgInfo.Msg.Content = string(content)
@ -1203,7 +1154,8 @@ func (m *MsgMongoDriver) searchMessage(ctx context.Context, req *msg.SearchMessa
n := int32(len(msgs)) n := int32(len(msgs))
if start >= n { if start >= n {
return n, []*table.MsgInfoModel{}, nil return n, []*table.MsgInfoModel{}, nil
} else if start+req.Pagination.ShowNumber < n { }
if start+req.Pagination.ShowNumber < n {
msgs = msgs[start : start+req.Pagination.ShowNumber] msgs = msgs[start : start+req.Pagination.ShowNumber]
} else { } else {
msgs = msgs[start:] msgs = msgs[start:]

View File

@ -105,7 +105,7 @@ func (cd *ConnDirect) GetConns(ctx context.Context,
} }
if len(connections) == 0 { if len(connections) == 0 {
return nil, fmt.Errorf("no connections found for service: %s", serviceName) return nil, errs.Wrap(errors.New("no connections found for service"), "serviceName", serviceName)
} }
return connections, nil return connections, nil
} }
@ -155,10 +155,11 @@ func (cd *ConnDirect) dialService(ctx context.Context, address string, opts ...g
conn, err := grpc.DialContext(ctx, cd.resolverDirect.Scheme()+":///"+address, options...) conn, err := grpc.DialContext(ctx, cd.resolverDirect.Scheme()+":///"+address, options...)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "address", address)
} }
return conn, nil return conn, nil
} }
func (cd *ConnDirect) dialServiceWithoutResolver(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) { func (cd *ConnDirect) dialServiceWithoutResolver(ctx context.Context, address string, opts ...grpc.DialOption) (*grpc.ClientConn, error) {
options := append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) options := append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
conn, err := grpc.DialContext(ctx, address, options...) conn, err := grpc.DialContext(ctx, address, options...)

View File

@ -24,6 +24,7 @@ import (
"github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister/zookeeper" "github.com/openimsdk/open-im-server/v3/pkg/common/discoveryregister/zookeeper"
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
) )
// NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type. // NewDiscoveryRegister creates a new service discovery and registry client based on the provided environment type.
@ -41,6 +42,6 @@ func NewDiscoveryRegister(envType string) (discoveryregistry.SvcDiscoveryRegistr
case "direct": case "direct":
return direct.NewConnDirect() return direct.NewConnDirect()
default: default:
return nil, errors.New("envType not correct") return nil, errs.Wrap(errors.New("envType not correct"))
} }
} }

View File

@ -39,7 +39,7 @@ func TestNewDiscoveryRegister(t *testing.T) {
expectedResult bool expectedResult bool
}{ }{
{"zookeeper", false, true}, {"zookeeper", false, true},
{"k8s", false, true}, // 假设 k8s 配置也已正确设置 {"k8s", false, true}, // Assume that the k8s configuration is also set up correctly
{"direct", false, true}, {"direct", false, true},
{"invalid", true, false}, {"invalid", true, false},
} }

View File

@ -66,12 +66,12 @@ func Post(ctx context.Context, url string, header map[string]string, data any, t
jsonStr, err := json.Marshal(data) jsonStr, err := json.Marshal(data)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "Post: JSON marshal failed")
} }
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonStr)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonStr))
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "Post: NewRequestWithContext failed")
} }
if operationID, _ := ctx.Value(constant.OperationID).(string); operationID != "" { if operationID, _ := ctx.Value(constant.OperationID).(string); operationID != "" {
@ -84,13 +84,13 @@ func Post(ctx context.Context, url string, header map[string]string, data any, t
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "Post: client.Do failed")
} }
defer resp.Body.Close() defer resp.Body.Close()
result, err := io.ReadAll(resp.Body) result, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "Post: ReadAll failed")
} }
return result, nil return result, nil
@ -102,7 +102,10 @@ func PostReturn(ctx context.Context, url string, header map[string]string, input
return err return err
} }
err = json.Unmarshal(b, output) err = json.Unmarshal(b, output)
return err if err != nil {
return errs.Wrap(err, "PostReturn: JSON unmarshal failed")
}
return nil
} }
func callBackPostReturn(ctx context.Context, url, command string, input interface{}, output callbackstruct.CallbackResp, callbackConfig config.CallBackConfig) error { func callBackPostReturn(ctx context.Context, url, command string, input interface{}, output callbackstruct.CallbackResp, callbackConfig config.CallBackConfig) error {
@ -127,7 +130,6 @@ func callBackPostReturn(ctx context.Context, url, command string, input interfac
} }
if err := output.Parse(); err != nil { if err := output.Parse(); err != nil {
log.ZWarn(ctx, "callback parse failed", err, "url", url, "input", input, "response", string(b)) log.ZWarn(ctx, "callback parse failed", err, "url", url, "input", input, "response", string(b))
return err
} }
log.ZInfo(ctx, "callback success", "url", url, "input", input, "response", string(b)) log.ZInfo(ctx, "callback success", "url", url, "input", input, "response", string(b))
return nil return nil

View File

@ -17,9 +17,8 @@ package kafka
import ( import (
"sync" "sync"
"github.com/openimsdk/open-im-server/v3/pkg/common/config"
"github.com/IBM/sarama" "github.com/IBM/sarama"
"github.com/OpenIMSDK/tools/errs"
) )
type Consumer struct { type Consumer struct {
@ -30,28 +29,33 @@ type Consumer struct {
Consumer sarama.Consumer Consumer sarama.Consumer
} }
func NewKafkaConsumer(addr []string, topic string) *Consumer { func NewKafkaConsumer(addr []string, topic string, kafkaConfig *sarama.Config) (*Consumer, error) {
p := Consumer{} p := Consumer{
p.Topic = topic Topic: topic,
p.addr = addr addr: addr,
consumerConfig := sarama.NewConfig()
if config.Config.Kafka.Username != "" && config.Config.Kafka.Password != "" {
consumerConfig.Net.SASL.Enable = true
consumerConfig.Net.SASL.User = config.Config.Kafka.Username
consumerConfig.Net.SASL.Password = config.Config.Kafka.Password
} }
SetupTLSConfig(consumerConfig)
consumer, err := sarama.NewConsumer(p.addr, consumerConfig) if kafkaConfig.Net.SASL.User != "" && kafkaConfig.Net.SASL.Password != "" {
kafkaConfig.Net.SASL.Enable = true
}
err := SetupTLSConfig(kafkaConfig)
if err != nil { if err != nil {
panic(err.Error()) return nil, err
}
consumer, err := sarama.NewConsumer(p.addr, kafkaConfig)
if err != nil {
return nil, errs.Wrap(err, "NewKafkaConsumer: creating consumer failed")
} }
p.Consumer = consumer p.Consumer = consumer
partitionList, err := consumer.Partitions(p.Topic) partitionList, err := consumer.Partitions(p.Topic)
if err != nil { if err != nil {
panic(err.Error()) return nil, errs.Wrap(err, "NewKafkaConsumer: getting partitions failed")
} }
p.PartitionList = partitionList p.PartitionList = partitionList
return &p return &p, nil
} }

View File

@ -90,11 +90,10 @@ func NewKafkaProducer(addr []string, topic string) (*Producer, error) {
for i := 0; i <= maxRetry; i++ { for i := 0; i <= maxRetry; i++ {
p.producer, err = sarama.NewSyncProducer(p.addr, p.config) p.producer, err = sarama.NewSyncProducer(p.addr, p.config)
if err == nil { if err == nil {
return &p, nil return &p, errs.Wrap(err)
} }
time.Sleep(1 * time.Second) // Wait before retrying time.Sleep(1 * time.Second) // Wait before retrying
} }
// Panic if unable to create producer after retries // Panic if unable to create producer after retries
if err != nil { if err != nil {
return nil, errs.Wrap(errors.New("failed to create Kafka producer: " + err.Error())) return nil, errs.Wrap(errors.New("failed to create Kafka producer: " + err.Error()))
@ -179,7 +178,7 @@ func (p *Producer) SendMessage(ctx context.Context, key string, msg proto.Messag
// Attach context metadata as headers // Attach context metadata as headers
header, err := GetMQHeaderWithContext(ctx) header, err := GetMQHeaderWithContext(ctx)
if err != nil { if err != nil {
return 0, 0, errs.Wrap(err) return 0, 0, err
} }
kMsg.Headers = header kMsg.Headers = header

View File

@ -26,16 +26,21 @@ import (
) )
// SetupTLSConfig set up the TLS config from config file. // SetupTLSConfig set up the TLS config from config file.
func SetupTLSConfig(cfg *sarama.Config) { func SetupTLSConfig(cfg *sarama.Config) error {
if config.Config.Kafka.TLS != nil { if config.Config.Kafka.TLS != nil {
cfg.Net.TLS.Enable = true cfg.Net.TLS.Enable = true
cfg.Net.TLS.Config = tls.NewTLSConfig( tlsConfig, err := tls.NewTLSConfig(
config.Config.Kafka.TLS.ClientCrt, config.Config.Kafka.TLS.ClientCrt,
config.Config.Kafka.TLS.ClientKey, config.Config.Kafka.TLS.ClientKey,
config.Config.Kafka.TLS.CACrt, config.Config.Kafka.TLS.CACrt,
[]byte(config.Config.Kafka.TLS.ClientKeyPwd), []byte(config.Config.Kafka.TLS.ClientKeyPwd),
) )
if err != nil {
return err
} }
cfg.Net.TLS.Config = tlsConfig
}
return nil
} }
// getEnvOrConfig returns the value of the environment variable if it exists, // getEnvOrConfig returns the value of the environment variable if it exists,

View File

@ -24,7 +24,6 @@ import (
) )
func NewGrpcPromObj(cusMetrics []prometheus.Collector) (*prometheus.Registry, *gp.ServerMetrics, error) { func NewGrpcPromObj(cusMetrics []prometheus.Collector) (*prometheus.Registry, *gp.ServerMetrics, error) {
////////////////////////////////////////////////////////
reg := prometheus.NewRegistry() reg := prometheus.NewRegistry()
grpcMetrics := gp.NewServerMetrics() grpcMetrics := gp.NewServerMetrics()
grpcMetrics.EnableHandlingTimeHistogram() grpcMetrics.EnableHandlingTimeHistogram()

View File

@ -70,7 +70,7 @@ func Start(
defer listener.Close() defer listener.Close()
client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery) client, err := kdisc.NewDiscoveryRegister(config.Config.Envs.Discovery)
if err != nil { if err != nil {
return errs.Wrap(err) return err
} }
defer client.Close() defer client.Close()

26
pkg/common/tls/tls.go Executable file → Normal file
View File

@ -21,6 +21,8 @@ import (
"errors" "errors"
"os" "os"
"github.com/OpenIMSDK/tools/errs"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
@ -49,37 +51,41 @@ func readEncryptablePEMBlock(path string, pwd []byte) ([]byte, error) {
} }
// NewTLSConfig setup the TLS config from general config file. // NewTLSConfig setup the TLS config from general config file.
func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) *tls.Config { func NewTLSConfig(clientCertFile, clientKeyFile, caCertFile string, keyPwd []byte) (*tls.Config, error) {
tlsConfig := tls.Config{} var tlsConfig tls.Config
if clientCertFile != "" && clientKeyFile != "" { if clientCertFile != "" && clientKeyFile != "" {
certPEMBlock, err := os.ReadFile(clientCertFile) certPEMBlock, err := os.ReadFile(clientCertFile)
if err != nil { if err != nil {
panic(err) return nil, errs.Wrap(err, "NewTLSConfig: failed to read client cert file")
} }
keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd) keyPEMBlock, err := readEncryptablePEMBlock(clientKeyFile, keyPwd)
if err != nil { if err != nil {
panic(err) return nil, err
} }
cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock) cert, err := tls.X509KeyPair(certPEMBlock, keyPEMBlock)
if err != nil { if err != nil {
panic(err) return nil, errs.Wrap(err, "NewTLSConfig: failed to create X509 key pair")
} }
tlsConfig.Certificates = []tls.Certificate{cert} tlsConfig.Certificates = []tls.Certificate{cert}
} }
if caCertFile != "" {
caCert, err := os.ReadFile(caCertFile) caCert, err := os.ReadFile(caCertFile)
if err != nil { if err != nil {
panic(err) return nil, errs.Wrap(err, "NewTLSConfig: failed to read CA cert file")
} }
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(caCert) if ok := caCertPool.AppendCertsFromPEM(caCert); !ok {
if !ok { return nil, errors.New("NewTLSConfig: not a valid CA cert")
panic(errors.New("not a valid CA cert"))
} }
tlsConfig.RootCAs = caCertPool tlsConfig.RootCAs = caCertPool
}
tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify tlsConfig.InsecureSkipVerify = config.Config.Kafka.TLS.InsecureSkipVerify
return &tlsConfig return &tlsConfig, nil
} }

View File

@ -20,6 +20,7 @@ import (
"github.com/OpenIMSDK/protocol/constant" "github.com/OpenIMSDK/protocol/constant"
"github.com/OpenIMSDK/protocol/sdkws" "github.com/OpenIMSDK/protocol/sdkws"
"github.com/OpenIMSDK/tools/errs"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -188,7 +189,7 @@ func (s MsgBySeq) Swap(i, j int) {
func Pb2String(pb proto.Message) (string, error) { func Pb2String(pb proto.Message) (string, error) {
s, err := proto.Marshal(pb) s, err := proto.Marshal(pb)
if err != nil { if err != nil {
return "", err return "", errs.Wrap(err)
} }
return string(s), nil return string(s), nil
} }

View File

@ -23,12 +23,13 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth { func NewAuth(discov discoveryregistry.SvcDiscoveryRegistry) *Auth {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImAuthName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImAuthName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := auth.NewAuthClient(conn) client := auth.NewAuthClient(conn)
return &Auth{discov: discov, conn: conn, Client: client} return &Auth{discov: discov, conn: conn, Client: client}

View File

@ -24,6 +24,8 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs" "github.com/OpenIMSDK/tools/errs"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
) )
@ -36,7 +38,7 @@ type Conversation struct {
func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversation { func NewConversation(discov discoveryregistry.SvcDiscoveryRegistry) *Conversation {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImConversationName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImConversationName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := pbconversation.NewConversationClient(conn) client := pbconversation.NewConversationClient(conn)
return &Conversation{discov: discov, conn: conn, Client: client} return &Conversation{discov: discov, conn: conn, Client: client}
@ -114,11 +116,7 @@ func (c *ConversationRpcClient) GetConversationsByConversationID(ctx context.Con
return resp.Conversations, nil return resp.Conversations, nil
} }
func (c *ConversationRpcClient) GetConversations( func (c *ConversationRpcClient) GetConversations(ctx context.Context, ownerUserID string, conversationIDs []string) ([]*pbconversation.Conversation, error) {
ctx context.Context,
ownerUserID string,
conversationIDs []string,
) ([]*pbconversation.Conversation, error) {
if len(conversationIDs) == 0 { if len(conversationIDs) == 0 {
return nil, nil return nil, nil
} }

View File

@ -24,6 +24,7 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
type Friend struct { type Friend struct {
@ -35,7 +36,7 @@ type Friend struct {
func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend { func NewFriend(discov discoveryregistry.SvcDiscoveryRegistry) *Friend {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImFriendName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImFriendName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := friend.NewFriendClient(conn) client := friend.NewFriendClient(conn)
return &Friend{discov: discov, conn: conn, Client: client} return &Friend{discov: discov, conn: conn, Client: client}
@ -62,7 +63,7 @@ func (f *FriendRpcClient) GetFriendsInfo(
return return
} }
// possibleFriendUserID是否在userID的好友中. // possibleFriendUserID Is PossibleFriendUserId's friends.
func (f *FriendRpcClient) IsFriend(ctx context.Context, possibleFriendUserID, userID string) (bool, error) { func (f *FriendRpcClient) IsFriend(ctx context.Context, possibleFriendUserID, userID string) (bool, error) {
resp, err := f.Client.IsFriend(ctx, &friend.IsFriendReq{UserID1: userID, UserID2: possibleFriendUserID}) resp, err := f.Client.IsFriend(ctx, &friend.IsFriendReq{UserID1: userID, UserID2: possibleFriendUserID})
if err != nil { if err != nil {

View File

@ -28,6 +28,7 @@ import (
"github.com/OpenIMSDK/tools/utils" "github.com/OpenIMSDK/tools/utils"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
type Group struct { type Group struct {
@ -39,7 +40,7 @@ type Group struct {
func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group { func NewGroup(discov discoveryregistry.SvcDiscoveryRegistry) *Group {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImGroupName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImGroupName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := group.NewGroupClient(conn) client := group.NewGroupClient(conn)
return &Group{discov: discov, conn: conn, Client: client} return &Group{discov: discov, conn: conn, Client: client}

View File

@ -147,14 +147,24 @@ func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageR
return MessageRpcClient(*NewMessage(discov)) return MessageRpcClient(*NewMessage(discov))
} }
// SendMsg sends a message through the gRPC client and returns the response.
// It wraps any encountered error for better error handling and context understanding.
func (m *MessageRpcClient) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { func (m *MessageRpcClient) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) {
resp, err := m.Client.SendMsg(ctx, req) resp, err := m.Client.SendMsg(ctx, req)
return resp, err if err != nil {
return nil, err
}
return resp, nil
} }
// GetMaxSeq retrieves the maximum sequence number from the gRPC client.
// Errors during the gRPC call are wrapped to provide additional context.
func (m *MessageRpcClient) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { func (m *MessageRpcClient) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) {
resp, err := m.Client.GetMaxSeq(ctx, req) resp, err := m.Client.GetMaxSeq(ctx, req)
return resp, err if err != nil {
return nil, err
}
return resp, nil
} }
func (m *MessageRpcClient) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { func (m *MessageRpcClient) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) {
@ -181,9 +191,15 @@ func (m *MessageRpcClient) GetMsgByConversationIDs(ctx context.Context, docIDs [
return resp.MsgDatas, err return resp.MsgDatas, err
} }
// PullMessageBySeqList retrieves messages by their sequence numbers using the gRPC client.
// It directly forwards the request to the gRPC client and returns the response along with any error encountered.
func (m *MessageRpcClient) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) { func (m *MessageRpcClient) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) {
resp, err := m.Client.PullMessageBySeqs(ctx, req) resp, err := m.Client.PullMessageBySeqs(ctx, req)
return resp, err if err != nil {
// Wrap the error to provide more context if the gRPC call fails.
return nil, err
}
return resp, nil
} }
func (m *MessageRpcClient) GetConversationMaxSeq(ctx context.Context, conversationID string) (int64, error) { func (m *MessageRpcClient) GetConversationMaxSeq(ctx context.Context, conversationID string) (int64, error) {

View File

@ -31,7 +31,7 @@ func NewConversationNotificationSender(msgRpcClient *rpcclient.MessageRpcClient)
return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))} return &ConversationNotificationSender{rpcclient.NewNotificationSender(rpcclient.WithRpcClient(msgRpcClient))}
} }
// SetPrivate调用. // SetPrivate invote.
func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx context.Context, sendID, recvID string, func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx context.Context, sendID, recvID string,
isPrivateChat bool, conversationID string, isPrivateChat bool, conversationID string,
) error { ) error {
@ -45,7 +45,6 @@ func (c *ConversationNotificationSender) ConversationSetPrivateNotification(ctx
return c.Notification(ctx, sendID, recvID, constant.ConversationPrivateChatNotification, tips) return c.Notification(ctx, sendID, recvID, constant.ConversationPrivateChatNotification, tips)
} }
// 会话改变.
func (c *ConversationNotificationSender) ConversationChangeNotification(ctx context.Context, userID string, conversationIDs []string) error { func (c *ConversationNotificationSender) ConversationChangeNotification(ctx context.Context, userID string, conversationIDs []string) error {
tips := &sdkws.ConversationUpdateTips{ tips := &sdkws.ConversationUpdateTips{
UserID: userID, UserID: userID,
@ -55,7 +54,6 @@ func (c *ConversationNotificationSender) ConversationChangeNotification(ctx cont
return c.Notification(ctx, userID, userID, constant.ConversationChangeNotification, tips) return c.Notification(ctx, userID, userID, constant.ConversationChangeNotification, tips)
} }
// 会话未读数同步.
func (c *ConversationNotificationSender) ConversationUnreadChangeNotification( func (c *ConversationNotificationSender) ConversationUnreadChangeNotification(
ctx context.Context, ctx context.Context,
userID, conversationID string, userID, conversationID string,

View File

@ -31,7 +31,7 @@ import (
type FriendNotificationSender struct { type FriendNotificationSender struct {
*rpcclient.NotificationSender *rpcclient.NotificationSender
// 找不到报错 // Target not found err
getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error) getUsersInfo func(ctx context.Context, userIDs []string) ([]CommonUser, error)
// db controller // db controller
db controller.FriendDatabase db controller.FriendDatabase

View File

@ -23,6 +23,7 @@ import (
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
type Push struct { type Push struct {
@ -34,7 +35,7 @@ type Push struct {
func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push { func NewPush(discov discoveryregistry.SvcDiscoveryRegistry) *Push {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImPushName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
return &Push{ return &Push{
discov: discov, discov: discov,
@ -49,9 +50,6 @@ func NewPushRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) PushRpcClie
return PushRpcClient(*NewPush(discov)) return PushRpcClient(*NewPush(discov))
} }
func (p *PushRpcClient) DelUserPushToken( func (p *PushRpcClient) DelUserPushToken(ctx context.Context, req *push.DelUserPushTokenReq) (*push.DelUserPushTokenResp, error) {
ctx context.Context,
req *push.DelUserPushTokenReq,
) (*push.DelUserPushTokenResp, error) {
return p.Client.DelUserPushToken(ctx, req) return p.Client.DelUserPushToken(ctx, req)
} }

View File

@ -24,8 +24,10 @@ import (
"github.com/OpenIMSDK/protocol/third" "github.com/OpenIMSDK/protocol/third"
"github.com/OpenIMSDK/tools/discoveryregistry" "github.com/OpenIMSDK/tools/discoveryregistry"
"github.com/OpenIMSDK/tools/errs"
"github.com/openimsdk/open-im-server/v3/pkg/common/config" "github.com/openimsdk/open-im-server/v3/pkg/common/config"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
) )
type Third struct { type Third struct {
@ -38,35 +40,42 @@ type Third struct {
func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third { func NewThird(discov discoveryregistry.SvcDiscoveryRegistry) *Third {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImThirdName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImThirdName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := third.NewThirdClient(conn) client := third.NewThirdClient(conn)
minioClient, err := minioInit() minioClient, err := minioInit()
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient} return &Third{discov: discov, Client: client, conn: conn, MinioClient: minioClient}
} }
func minioInit() (*minio.Client, error) { func minioInit() (*minio.Client, error) {
minioClient := &minio.Client{} // Retrieve MinIO configuration details
initUrl := config.Config.Object.Minio.Endpoint endpoint := config.Config.Object.Minio.Endpoint
minioUrl, err := url.Parse(initUrl) accessKeyID := config.Config.Object.Minio.AccessKeyID
secretAccessKey := config.Config.Object.Minio.SecretAccessKey
// Parse the MinIO URL to determine if the connection should be secure
minioURL, err := url.Parse(endpoint)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "minioInit: failed to parse MinIO endpoint URL")
} }
// Determine the security of the connection based on the scheme
secure := minioURL.Scheme == "https"
// Setup MinIO client options
opts := &minio.Options{ opts := &minio.Options{
Creds: credentials.NewStaticV4(config.Config.Object.Minio.AccessKeyID, config.Config.Object.Minio.SecretAccessKey, ""), Creds: credentials.NewStaticV4(accessKeyID, secretAccessKey, ""),
// Region: config.Config.Credential.Minio.Location, Secure: secure,
} }
if minioUrl.Scheme == "http" {
opts.Secure = false // Initialize MinIO client
} else if minioUrl.Scheme == "https" { minioClient, err := minio.New(minioURL.Host, opts)
opts.Secure = true
}
minioClient, err = minio.New(minioUrl.Host, opts)
if err != nil { if err != nil {
return nil, err return nil, errs.Wrap(err, "minioInit: failed to create MinIO client")
} }
return minioClient, nil return minioClient, nil
} }

View File

@ -19,6 +19,7 @@ import (
"strings" "strings"
"github.com/openimsdk/open-im-server/v3/pkg/authverify" "github.com/openimsdk/open-im-server/v3/pkg/authverify"
util "github.com/openimsdk/open-im-server/v3/pkg/util/genutil"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -42,7 +43,7 @@ type User struct {
func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User { func NewUser(discov discoveryregistry.SvcDiscoveryRegistry) *User {
conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImUserName) conn, err := discov.GetConn(context.Background(), config.Config.RpcRegisterName.OpenImUserName)
if err != nil { if err != nil {
panic(err) util.ExitWithError(err)
} }
client := user.NewUserClient(conn) client := user.NewUserClient(conn)
return &User{Discov: discov, Client: client, conn: conn} return &User{Discov: discov, Client: client, conn: conn}

View File

@ -18,6 +18,8 @@ import (
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"github.com/OpenIMSDK/tools/errs"
) )
// OutDir creates the absolute path name from path and checks path exists. // OutDir creates the absolute path name from path and checks path exists.
@ -25,16 +27,16 @@ import (
func OutDir(path string) (string, error) { func OutDir(path string) (string, error) {
outDir, err := filepath.Abs(path) outDir, err := filepath.Abs(path)
if err != nil { if err != nil {
return "", err return "", errs.Wrap(err, "output directory %s does not exist", path)
} }
stat, err := os.Stat(outDir) stat, err := os.Stat(outDir)
if err != nil { if err != nil {
return "", err return "", errs.Wrap(err, "output directory %s does not exist", outDir)
} }
if !stat.IsDir() { if !stat.IsDir() {
return "", fmt.Errorf("output directory %s is not a directory", outDir) return "", errs.Wrap(err, "output directory %s is not a directory", outDir)
} }
outDir += "/" outDir += "/"
return outDir, nil return outDir, nil

View File

@ -103,14 +103,14 @@ function openim::tools::start_service() {
printf "Specifying prometheus port: %s\n" "${prometheus_port}" printf "Specifying prometheus port: %s\n" "${prometheus_port}"
cmd="${cmd} --prometheus_port ${prometheus_port}" cmd="${cmd} --prometheus_port ${prometheus_port}"
fi fi
openim::log::status "Starting ${binary_name}..." openim::log::status "Starting binary ${binary_name}..."
${cmd} | tee -a "${LOG_FILE}" ${cmd} | tee -a "${LOG_FILE}"
} }
function openim::tools::start() { function openim::tools::start() {
openim::log::info "Starting OpenIM Tools..." openim::log::info "Starting OpenIM Tools..."
for tool in "${OPENIM_TOOLS_NAME_LISTARIES[@]}"; do for tool in "${OPENIM_TOOLS_NAME_LISTARIES[@]}"; do
openim::log::info "Starting ${tool}..." openim::log::info "Starting tool ${tool}..."
# openim::tools::start_service ${tool} # openim::tools::start_service ${tool}
sleep 0.2 sleep 0.2
done done
@ -120,7 +120,7 @@ function openim::tools::start() {
function openim::tools::pre-start() { function openim::tools::pre-start() {
openim::log::info "Preparing to start OpenIM Tools..." openim::log::info "Preparing to start OpenIM Tools..."
for tool in "${OPENIM_TOOLS_PRE_START_NAME_LISTARIES[@]}"; do for tool in "${OPENIM_TOOLS_PRE_START_NAME_LISTARIES[@]}"; do
openim::log::info "Starting ${tool}..." openim::log::info "Starting tool ${tool}..."
openim::tools::start_service ${tool} ${OPNEIM_CONFIG} openim::tools::start_service ${tool} ${OPNEIM_CONFIG}
done done
} }
@ -128,7 +128,7 @@ function openim::tools::pre-start() {
function openim::tools::post-start() { function openim::tools::post-start() {
openim::log::info "Post-start actions for OpenIM Tools..." openim::log::info "Post-start actions for OpenIM Tools..."
for tool in "${OPENIM_TOOLS_POST_START_NAME_LISTARIES[@]}"; do for tool in "${OPENIM_TOOLS_POST_START_NAME_LISTARIES[@]}"; do
openim::log::info "Starting ${tool}..." openim::log::info "Starting tool ${tool}..."
openim::tools::start_service ${tool} openim::tools::start_service ${tool}
done done
} }

View File

@ -17,7 +17,7 @@
# #
GO := go GO := go
GO_SUPPORTED_VERSIONS ?= 1.19|1.20|1.21|1.22 GO_SUPPORTED_VERSIONS ?= 1.19|1.20|1.21|1.22|1.23
GO_LDFLAGS += -X $(VERSION_PACKAGE).gitVersion=$(GIT_TAG) \ GO_LDFLAGS += -X $(VERSION_PACKAGE).gitVersion=$(GIT_TAG) \
-X $(VERSION_PACKAGE).gitCommit=$(GIT_COMMIT) \ -X $(VERSION_PACKAGE).gitCommit=$(GIT_COMMIT) \

1
test/codescan/main.go Normal file
View File

@ -0,0 +1 @@
package main

View File

@ -128,6 +128,8 @@ Open issue: https://github.com/openimsdk/open-im-server/issues/new/choose, choos
The E2E test suite is integrated with CI, which runs the tests automatically on each code commit. The results are reported back to the pull request or commit to provide immediate feedback on the impact of the changes. The E2E test suite is integrated with CI, which runs the tests automatically on each code commit. The results are reported back to the pull request or commit to provide immediate feedback on the impact of the changes.
[![OpenIM Linux System E2E Test](https://github.com/openimsdk/open-im-server/actions/workflows/e2e-test.yml/badge.svg)](https://github.com/openimsdk/open-im-server/actions/workflows/e2e-test.yml)
## Contact ## Contact

Some files were not shown because too many files have changed in this diff Show More