mirror of
https://github.com/openimsdk/open-im-server.git
synced 2025-05-24 13:59:17 +08:00
Merge branch 'v2.3.0release' into del
This commit is contained in:
commit
2016a1656b
@ -43,7 +43,7 @@ func main() {
|
|||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
f, _ := os.Create("../logs/api.log")
|
f, _ := os.Create("../logs/api.log")
|
||||||
gin.DefaultWriter = io.MultiWriter(f)
|
gin.DefaultWriter = io.MultiWriter(f)
|
||||||
gin.SetMode(gin.DebugMode)
|
// gin.SetMode(gin.DebugMode)
|
||||||
r := gin.Default()
|
r := gin.Default()
|
||||||
r.Use(utils.CorsHandler())
|
r.Use(utils.CorsHandler())
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package apiThird
|
package apiThird
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
api "Open_IM/pkg/base_info"
|
||||||
"Open_IM/pkg/common/db"
|
"Open_IM/pkg/common/db"
|
||||||
"Open_IM/pkg/common/log"
|
"Open_IM/pkg/common/log"
|
||||||
"Open_IM/pkg/common/token_verify"
|
"Open_IM/pkg/common/token_verify"
|
||||||
@ -10,18 +11,10 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
/**
|
|
||||||
* FCM第三方上报Token
|
|
||||||
*/
|
|
||||||
type FcmUpdateTokenReq struct {
|
|
||||||
OperationID string `json:"operationID"`
|
|
||||||
Platform int `json:"platform" binding:"required,min=1,max=2"` //only for ios + android
|
|
||||||
FcmToken string `json:"fcmToken"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func FcmUpdateToken(c *gin.Context) {
|
func FcmUpdateToken(c *gin.Context) {
|
||||||
var (
|
var (
|
||||||
req FcmUpdateTokenReq
|
req api.FcmUpdateTokenReq
|
||||||
|
resp api.FcmUpdateTokenResp
|
||||||
)
|
)
|
||||||
if err := c.Bind(&req); err != nil {
|
if err := c.Bind(&req); err != nil {
|
||||||
log.NewError("0", utils.GetSelfFuncName(), "BindJSON failed ", err.Error())
|
log.NewError("0", utils.GetSelfFuncName(), "BindJSON failed ", err.Error())
|
||||||
@ -34,7 +27,9 @@ func FcmUpdateToken(c *gin.Context) {
|
|||||||
if !ok {
|
if !ok {
|
||||||
errMsg := req.OperationID + " " + "GetUserIDFromToken failed " + errInfo + " token:" + c.Request.Header.Get("token")
|
errMsg := req.OperationID + " " + "GetUserIDFromToken failed " + errInfo + " token:" + c.Request.Header.Get("token")
|
||||||
log.NewError(req.OperationID, errMsg)
|
log.NewError(req.OperationID, errMsg)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"errCode": 500, "errMsg": errMsg})
|
resp.ErrCode = 500
|
||||||
|
resp.ErrMsg = errMsg
|
||||||
|
c.JSON(http.StatusInternalServerError, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.NewInfo(req.OperationID, utils.GetSelfFuncName(), req, UserId)
|
log.NewInfo(req.OperationID, utils.GetSelfFuncName(), req, UserId)
|
||||||
@ -43,10 +38,12 @@ func FcmUpdateToken(c *gin.Context) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := req.OperationID + " " + "SetFcmToken failed " + err.Error() + " token:" + c.Request.Header.Get("token")
|
errMsg := req.OperationID + " " + "SetFcmToken failed " + err.Error() + " token:" + c.Request.Header.Get("token")
|
||||||
log.NewError(req.OperationID, errMsg)
|
log.NewError(req.OperationID, errMsg)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"errCode": 500, "errMsg": errMsg})
|
resp.ErrCode = 500
|
||||||
|
resp.ErrMsg = errMsg
|
||||||
|
c.JSON(http.StatusInternalServerError, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//逻辑处理完毕
|
//逻辑处理完毕
|
||||||
c.JSON(http.StatusOK, gin.H{"errCode": 0, "errMsg": ""})
|
c.JSON(http.StatusOK, resp)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -12,9 +12,10 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/gob"
|
"encoding/gob"
|
||||||
|
"strings"
|
||||||
|
|
||||||
go_redis "github.com/go-redis/redis/v8"
|
go_redis "github.com/go-redis/redis/v8"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"strings"
|
|
||||||
|
|
||||||
//"gopkg.in/errgo.v2/errors"
|
//"gopkg.in/errgo.v2/errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
@ -133,6 +134,7 @@ func (ws *WServer) MultiTerminalLoginRemoteChecker(userID string, platformID int
|
|||||||
resp, err := client.MultiTerminalLoginCheck(context.Background(), req)
|
resp, err := client.MultiTerminalLoginCheck(context.Background(), req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(operationID, "MultiTerminalLoginCheck failed ", err.Error())
|
log.Error(operationID, "MultiTerminalLoginCheck failed ", err.Error())
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
if resp.ErrCode != 0 {
|
if resp.ErrCode != 0 {
|
||||||
log.Error(operationID, "MultiTerminalLoginCheck errCode, errMsg: ", resp.ErrCode, resp.ErrMsg)
|
log.Error(operationID, "MultiTerminalLoginCheck errCode, errMsg: ", resp.ErrCode, resp.ErrMsg)
|
||||||
@ -237,7 +239,7 @@ func (ws *WServer) MultiTerminalLoginChecker(uid string, platformID int, newConn
|
|||||||
log.NewError(operationID, utils.GetSelfFuncName(), "callbackUserOffline failed", callbackResp)
|
log.NewError(operationID, utils.GetSelfFuncName(), "callbackUserOffline failed", callbackResp)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.NewWarn(operationID, "normal uid-conn ", uid, platformID, oldConnMap[platformID])
|
log.Debug(operationID, "normal uid-conn ", uid, platformID, oldConnMap[platformID])
|
||||||
}
|
}
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
@ -407,7 +407,7 @@ func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(sess sarama.ConsumerG
|
|||||||
rwLock.Unlock()
|
rwLock.Unlock()
|
||||||
split := 1000
|
split := 1000
|
||||||
triggerID = utils.OperationIDGenerator()
|
triggerID = utils.OperationIDGenerator()
|
||||||
log.NewWarn(triggerID, "timer trigger msg consumer start", len(ccMsg))
|
log.Debug(triggerID, "timer trigger msg consumer start", len(ccMsg))
|
||||||
for i := 0; i < len(ccMsg)/split; i++ {
|
for i := 0; i < len(ccMsg)/split; i++ {
|
||||||
//log.Debug()
|
//log.Debug()
|
||||||
och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{
|
och.msgDistributionCh <- Cmd2Value{Cmd: ConsumerMsgs, Value: TriggerChannelValue{
|
||||||
@ -419,9 +419,8 @@ func (och *OnlineHistoryRedisConsumerHandler) ConsumeClaim(sess sarama.ConsumerG
|
|||||||
}
|
}
|
||||||
//sess.MarkMessage(ccMsg[len(cMsg)-1], "")
|
//sess.MarkMessage(ccMsg[len(cMsg)-1], "")
|
||||||
|
|
||||||
log.NewWarn(triggerID, "timer trigger msg consumer end", len(cMsg))
|
log.Debug(triggerID, "timer trigger msg consumer end", len(cMsg))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -13,6 +13,7 @@ import (
|
|||||||
"Open_IM/pkg/common/log"
|
"Open_IM/pkg/common/log"
|
||||||
pbChat "Open_IM/pkg/proto/msg"
|
pbChat "Open_IM/pkg/proto/msg"
|
||||||
pbPush "Open_IM/pkg/proto/push"
|
pbPush "Open_IM/pkg/proto/push"
|
||||||
|
"Open_IM/pkg/utils"
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/golang/protobuf/proto"
|
"github.com/golang/protobuf/proto"
|
||||||
)
|
)
|
||||||
@ -43,6 +44,11 @@ func (ms *PushConsumerHandler) handleMs2PsChat(msg []byte) {
|
|||||||
MsgData: msgFromMQ.MsgData,
|
MsgData: msgFromMQ.MsgData,
|
||||||
PushToUserID: msgFromMQ.PushToUserID,
|
PushToUserID: msgFromMQ.PushToUserID,
|
||||||
}
|
}
|
||||||
|
sec := msgFromMQ.MsgData.SendTime / 1000
|
||||||
|
nowSec := utils.GetCurrentTimestampBySecond()
|
||||||
|
if nowSec-sec > 10 {
|
||||||
|
return
|
||||||
|
}
|
||||||
switch msgFromMQ.MsgData.SessionType {
|
switch msgFromMQ.MsgData.SessionType {
|
||||||
case constant.SuperGroupChatType:
|
case constant.SuperGroupChatType:
|
||||||
MsgToSuperGroupUser(pbData)
|
MsgToSuperGroupUser(pbData)
|
||||||
@ -59,6 +65,7 @@ func (ms *PushConsumerHandler) ConsumeClaim(sess sarama.ConsumerGroupSession,
|
|||||||
for msg := range claim.Messages() {
|
for msg := range claim.Messages() {
|
||||||
log.NewDebug("", "kafka get info to mysql", "msgTopic", msg.Topic, "msgPartition", msg.Partition, "msg", string(msg.Value))
|
log.NewDebug("", "kafka get info to mysql", "msgTopic", msg.Topic, "msgPartition", msg.Partition, "msg", string(msg.Value))
|
||||||
ms.msgHandle[msg.Topic](msg.Value)
|
ms.msgHandle[msg.Topic](msg.Value)
|
||||||
|
sess.MarkMessage(msg, "")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -305,29 +305,6 @@ func (s *organizationServer) CreateOrganizationUser(ctx context.Context, req *rp
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *organizationServer) UpdateOrganizationUser(ctx context.Context, req *rpc.UpdateOrganizationUserReq) (*rpc.UpdateOrganizationUserResp, error) {
|
func (s *organizationServer) UpdateOrganizationUser(ctx context.Context, req *rpc.UpdateOrganizationUserReq) (*rpc.UpdateOrganizationUserResp, error) {
|
||||||
authReq := &pbAuth.UserRegisterReq{UserInfo: &open_im_sdk.UserInfo{}}
|
|
||||||
utils.CopyStructFields(authReq.UserInfo, req.OrganizationUser)
|
|
||||||
authReq.OperationID = req.OperationID
|
|
||||||
etcdConn := getcdv3.GetConn(config.Config.Etcd.EtcdSchema, strings.Join(config.Config.Etcd.EtcdAddr, ","), config.Config.RpcRegisterName.OpenImAuthName, req.OperationID)
|
|
||||||
if etcdConn == nil {
|
|
||||||
errMsg := req.OperationID + "getcdv3.GetConn == nil"
|
|
||||||
log.NewError(req.OperationID, errMsg)
|
|
||||||
return &rpc.UpdateOrganizationUserResp{ErrCode: constant.ErrInternal.ErrCode, ErrMsg: errMsg}, nil
|
|
||||||
}
|
|
||||||
client := pbAuth.NewAuthClient(etcdConn)
|
|
||||||
|
|
||||||
reply, err := client.UserRegister(context.Background(), authReq)
|
|
||||||
if err != nil {
|
|
||||||
errMsg := "UserRegister failed " + err.Error()
|
|
||||||
log.NewError(req.OperationID, errMsg)
|
|
||||||
return &rpc.UpdateOrganizationUserResp{ErrCode: constant.ErrDB.ErrCode, ErrMsg: errMsg}, nil
|
|
||||||
}
|
|
||||||
if reply.CommonResp.ErrCode != 0 {
|
|
||||||
errMsg := "UserRegister failed " + reply.CommonResp.ErrMsg
|
|
||||||
log.NewError(req.OperationID, errMsg)
|
|
||||||
return &rpc.UpdateOrganizationUserResp{ErrCode: constant.ErrDB.ErrCode, ErrMsg: errMsg}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.NewInfo(req.OperationID, utils.GetSelfFuncName(), " rpc args ", req.String())
|
log.NewInfo(req.OperationID, utils.GetSelfFuncName(), " rpc args ", req.String())
|
||||||
if !token_verify.IsManagerUserID(req.OpUserID) {
|
if !token_verify.IsManagerUserID(req.OpUserID) {
|
||||||
errMsg := req.OperationID + " " + req.OpUserID + " is not app manager"
|
errMsg := req.OperationID + " " + req.OpUserID + " is not app manager"
|
||||||
@ -342,7 +319,7 @@ func (s *organizationServer) UpdateOrganizationUser(ctx context.Context, req *rp
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug(req.OperationID, "src ", *req.OrganizationUser, "dst ", organizationUser)
|
log.Debug(req.OperationID, "src ", *req.OrganizationUser, "dst ", organizationUser)
|
||||||
err = imdb.UpdateOrganizationUser(&organizationUser, nil)
|
err := imdb.UpdateOrganizationUser(&organizationUser, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errMsg := req.OperationID + " " + "CreateOrganizationUser failed " + err.Error()
|
errMsg := req.OperationID + " " + "CreateOrganizationUser failed " + err.Error()
|
||||||
log.Error(req.OperationID, errMsg, organizationUser)
|
log.Error(req.OperationID, errMsg, organizationUser)
|
||||||
|
@ -203,7 +203,7 @@ type SetGroupInfoReq struct {
|
|||||||
FaceURL string `json:"faceURL"`
|
FaceURL string `json:"faceURL"`
|
||||||
Ex string `json:"ex"`
|
Ex string `json:"ex"`
|
||||||
OperationID string `json:"operationID" binding:"required"`
|
OperationID string `json:"operationID" binding:"required"`
|
||||||
NeedVerification *int32 `json:"needVerification" `
|
NeedVerification *int32 `json:"needVerification"`
|
||||||
LookMemberInfo *int32 `json:"lookMemberInfo"`
|
LookMemberInfo *int32 `json:"lookMemberInfo"`
|
||||||
ApplyMemberFriend *int32 `json:"applyMemberFriend"`
|
ApplyMemberFriend *int32 `json:"applyMemberFriend"`
|
||||||
}
|
}
|
||||||
|
@ -99,3 +99,16 @@ type GetRTCInvitationInfoStartAppReq struct {
|
|||||||
type GetRTCInvitationInfoStartAppResp struct {
|
type GetRTCInvitationInfoStartAppResp struct {
|
||||||
GetRTCInvitationInfoResp
|
GetRTCInvitationInfoResp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* FCM第三方上报Token
|
||||||
|
*/
|
||||||
|
type FcmUpdateTokenReq struct {
|
||||||
|
OperationID string `json:"operationID" binding:"required"`
|
||||||
|
Platform int `json:"platform" binding:"required,min=1,max=2"` //only for ios + android
|
||||||
|
FcmToken string `json:"fcmToken" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type FcmUpdateTokenResp struct {
|
||||||
|
CommResp
|
||||||
|
}
|
||||||
|
@ -535,7 +535,7 @@ func init() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
bytes, err = ioutil.ReadFile(filepath.Join(Root, "config", "config.yaml"))
|
bytes, err = ioutil.ReadFile(filepath.Join(Root, "config", "config.yaml"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err.Error())
|
panic(err.Error() + " config: " + filepath.Join(cfgName, "config", "config.yaml"))
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
Root = cfgName
|
Root = cfgName
|
||||||
@ -552,5 +552,4 @@ func init() {
|
|||||||
panic(err.Error())
|
panic(err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -185,14 +185,14 @@ func (d *DataBases) GetMessageListBySeq(userID string, seqList []uint32, operati
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
errResult = err
|
errResult = err
|
||||||
failedSeqList = append(failedSeqList, v)
|
failedSeqList = append(failedSeqList, v)
|
||||||
log2.NewWarn(operationID, "redis get message error:", err.Error(), v)
|
log2.Debug(operationID, "redis get message error: ", err.Error(), v)
|
||||||
} else {
|
} else {
|
||||||
msg := pbCommon.MsgData{}
|
msg := pbCommon.MsgData{}
|
||||||
err = jsonpb.UnmarshalString(result, &msg)
|
err = jsonpb.UnmarshalString(result, &msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errResult = err
|
errResult = err
|
||||||
failedSeqList = append(failedSeqList, v)
|
failedSeqList = append(failedSeqList, v)
|
||||||
log2.NewWarn(operationID, "Unmarshal err", result, err.Error())
|
log2.NewWarn(operationID, "Unmarshal err ", result, err.Error())
|
||||||
} else {
|
} else {
|
||||||
log2.NewDebug(operationID, "redis get msg is ", msg.String())
|
log2.NewDebug(operationID, "redis get msg is ", msg.String())
|
||||||
seqMsg = append(seqMsg, &msg)
|
seqMsg = append(seqMsg, &msg)
|
||||||
|
@ -6,9 +6,10 @@ import (
|
|||||||
commonDB "Open_IM/pkg/common/db"
|
commonDB "Open_IM/pkg/common/db"
|
||||||
"Open_IM/pkg/common/log"
|
"Open_IM/pkg/common/log"
|
||||||
"Open_IM/pkg/utils"
|
"Open_IM/pkg/utils"
|
||||||
|
"time"
|
||||||
|
|
||||||
go_redis "github.com/go-redis/redis/v8"
|
go_redis "github.com/go-redis/redis/v8"
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
//var (
|
//var (
|
||||||
@ -27,13 +28,14 @@ type Claims struct {
|
|||||||
|
|
||||||
func BuildClaims(uid, platform string, ttl int64) Claims {
|
func BuildClaims(uid, platform string, ttl int64) Claims {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
before := now.Add(-time.Minute * 5)
|
||||||
return Claims{
|
return Claims{
|
||||||
UID: uid,
|
UID: uid,
|
||||||
Platform: platform,
|
Platform: platform,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
|
||||||
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
IssuedAt: jwt.NewNumericDate(now), //Issuing time
|
||||||
NotBefore: jwt.NewNumericDate(now), //Begin Effective time
|
NotBefore: jwt.NewNumericDate(before), //Begin Effective time
|
||||||
}}
|
}}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -99,23 +101,22 @@ func GetClaimFromToken(tokensString string) (*Claims, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if ve, ok := err.(*jwt.ValidationError); ok {
|
if ve, ok := err.(*jwt.ValidationError); ok {
|
||||||
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
|
||||||
return nil, constant.ErrTokenMalformed
|
return nil, utils.Wrap(constant.ErrTokenMalformed, "")
|
||||||
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
|
||||||
return nil, constant.ErrTokenExpired
|
return nil, utils.Wrap(constant.ErrTokenExpired, "")
|
||||||
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
|
||||||
return nil, constant.ErrTokenNotValidYet
|
return nil, utils.Wrap(constant.ErrTokenNotValidYet, "")
|
||||||
} else {
|
} else {
|
||||||
return nil, constant.ErrTokenUnknown
|
return nil, utils.Wrap(constant.ErrTokenUnknown, "")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
return nil, constant.ErrTokenNotValidYet
|
return nil, utils.Wrap(constant.ErrTokenNotValidYet, "")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||||
//log.NewDebug("", claims.UID, claims.Platform)
|
|
||||||
return claims, nil
|
return claims, nil
|
||||||
}
|
}
|
||||||
return nil, constant.ErrTokenNotValidYet
|
return nil, utils.Wrap(constant.ErrTokenNotValidYet, "")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
184
pkg/tools/retry/retry.go
Normal file
184
pkg/tools/retry/retry.go
Normal file
@ -0,0 +1,184 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrorAbort = errors.New("stop retry")
|
||||||
|
ErrorTimeout = errors.New("retry timeout")
|
||||||
|
ErrorContextDeadlineExceed = errors.New("context deadline exceeded")
|
||||||
|
ErrorEmptyRetryFunc = errors.New("empty retry function")
|
||||||
|
ErrorTimeFormat = errors.New("time out err")
|
||||||
|
)
|
||||||
|
|
||||||
|
type RetriesFunc func() error
|
||||||
|
type Option func(c *Config)
|
||||||
|
type HookFunc func()
|
||||||
|
type RetriesChecker func(err error) (needRetry bool)
|
||||||
|
type Config struct {
|
||||||
|
MaxRetryTimes int
|
||||||
|
Timeout time.Duration
|
||||||
|
RetryChecker RetriesChecker
|
||||||
|
Strategy Strategy
|
||||||
|
RecoverPanic bool
|
||||||
|
BeforeTry HookFunc
|
||||||
|
AfterTry HookFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
DefaultMaxRetryTimes = 3
|
||||||
|
DefaultTimeout = time.Minute
|
||||||
|
DefaultInterval = time.Second * 2
|
||||||
|
DefaultRetryChecker = func(err error) bool {
|
||||||
|
return !errors.Is(err, ErrorAbort) // not abort error, should continue retry
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func newDefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
MaxRetryTimes: DefaultMaxRetryTimes,
|
||||||
|
RetryChecker: DefaultRetryChecker,
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
Strategy: NewLinear(DefaultInterval),
|
||||||
|
BeforeTry: func() {},
|
||||||
|
AfterTry: func() {},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithTimeout(timeout time.Duration) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.Timeout = timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithMaxRetryTimes(times int) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.MaxRetryTimes = times
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRecoverPanic() Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.RecoverPanic = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBeforeHook(hook HookFunc) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.BeforeTry = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithAfterHook(hook HookFunc) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.AfterTry = hook
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithRetryChecker(checker RetriesChecker) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.RetryChecker = checker
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithBackOffStrategy(s BackoffStrategy, duration time.Duration) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
switch s {
|
||||||
|
case StrategyConstant:
|
||||||
|
c.Strategy = NewConstant(duration)
|
||||||
|
case StrategyLinear:
|
||||||
|
c.Strategy = NewLinear(duration)
|
||||||
|
case StrategyFibonacci:
|
||||||
|
c.Strategy = NewFibonacci(duration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithCustomStrategy(s Strategy) Option {
|
||||||
|
return func(c *Config) {
|
||||||
|
c.Strategy = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Do(ctx context.Context, fn RetriesFunc, opts ...Option) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrorEmptyRetryFunc
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
abort = make(chan struct{}, 1) // caller choose to abort retry
|
||||||
|
run = make(chan error, 1)
|
||||||
|
panicInfoChan = make(chan string, 1)
|
||||||
|
|
||||||
|
timer *time.Timer
|
||||||
|
runErr error
|
||||||
|
)
|
||||||
|
config := newDefaultConfig()
|
||||||
|
for _, o := range opts {
|
||||||
|
o(config)
|
||||||
|
}
|
||||||
|
if config.Timeout > 0 {
|
||||||
|
timer = time.NewTimer(config.Timeout)
|
||||||
|
} else {
|
||||||
|
return ErrorTimeFormat
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
var err error
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e == nil {
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
panicInfoChan <- fmt.Sprintf("retry function panic has occured, err=%v, stack:%s", e, string(debug.Stack()))
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
for i := 0; i < config.MaxRetryTimes; i++ {
|
||||||
|
config.BeforeTry()
|
||||||
|
err = fn()
|
||||||
|
config.AfterTry()
|
||||||
|
if err == nil {
|
||||||
|
run <- nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// check whether to retry
|
||||||
|
if config.RetryChecker != nil {
|
||||||
|
needRetry := config.RetryChecker(err)
|
||||||
|
if !needRetry {
|
||||||
|
abort <- struct{}{}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if config.Strategy != nil {
|
||||||
|
interval := config.Strategy.Sleep(i + 1)
|
||||||
|
<-time.After(interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
run <- err
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// context deadline exceed
|
||||||
|
return ErrorContextDeadlineExceed
|
||||||
|
case <-timer.C:
|
||||||
|
// timeout
|
||||||
|
return ErrorTimeout
|
||||||
|
case <-abort:
|
||||||
|
// caller abort
|
||||||
|
return ErrorAbort
|
||||||
|
case msg := <-panicInfoChan:
|
||||||
|
// panic occurred
|
||||||
|
if !config.RecoverPanic {
|
||||||
|
panic(msg)
|
||||||
|
}
|
||||||
|
runErr = fmt.Errorf("panic occurred=%s", msg)
|
||||||
|
case e := <-run:
|
||||||
|
// normal run
|
||||||
|
if e != nil {
|
||||||
|
runErr = fmt.Errorf("retry failed, err=%w", e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return runErr
|
||||||
|
}
|
56
pkg/tools/retry/stratey.go
Normal file
56
pkg/tools/retry/stratey.go
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
package retry
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type BackoffStrategy int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StrategyConstant BackoffStrategy = iota
|
||||||
|
StrategyLinear
|
||||||
|
StrategyFibonacci
|
||||||
|
)
|
||||||
|
|
||||||
|
type Strategy interface {
|
||||||
|
Sleep(times int) time.Duration
|
||||||
|
}
|
||||||
|
type Constant struct {
|
||||||
|
startInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConstant(d time.Duration) *Constant {
|
||||||
|
return &Constant{startInterval: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Linear struct {
|
||||||
|
startInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLinear(d time.Duration) *Linear {
|
||||||
|
return &Linear{startInterval: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Fibonacci struct {
|
||||||
|
startInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFibonacci(d time.Duration) *Fibonacci {
|
||||||
|
return &Fibonacci{startInterval: d}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Constant) Sleep(_ int) time.Duration {
|
||||||
|
return c.startInterval
|
||||||
|
}
|
||||||
|
func (l *Linear) Sleep(times int) time.Duration {
|
||||||
|
return l.startInterval * time.Duration(times)
|
||||||
|
|
||||||
|
}
|
||||||
|
func (f *Fibonacci) Sleep(times int) time.Duration {
|
||||||
|
return f.startInterval * time.Duration(fibonacciNumber(times))
|
||||||
|
|
||||||
|
}
|
||||||
|
func fibonacciNumber(n int) int {
|
||||||
|
if n == 0 || n == 1 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
return fibonacciNumber(n-1) + fibonacciNumber(n-2)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user