mirror of
				https://github.com/openimsdk/open-im-server.git
				synced 2025-11-04 19:32:17 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			242 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			242 lines
		
	
	
		
			7.6 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package controller
 | 
						|
 | 
						|
import (
 | 
						|
	"context"
 | 
						|
	"github.com/golang-jwt/jwt/v4"
 | 
						|
	"github.com/openimsdk/open-im-server/v3/pkg/authverify"
 | 
						|
	"github.com/openimsdk/open-im-server/v3/pkg/common/config"
 | 
						|
	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
 | 
						|
	"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
 | 
						|
	"github.com/openimsdk/protocol/constant"
 | 
						|
	"github.com/openimsdk/tools/errs"
 | 
						|
	"github.com/openimsdk/tools/log"
 | 
						|
	"github.com/openimsdk/tools/tokenverify"
 | 
						|
)
 | 
						|
 | 
						|
type AuthDatabase interface {
 | 
						|
	// If the result is empty, no error is returned.
 | 
						|
	GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
 | 
						|
	// Create token
 | 
						|
	CreateToken(ctx context.Context, userID string, platformID int) (string, error)
 | 
						|
 | 
						|
	BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error
 | 
						|
 | 
						|
	SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
 | 
						|
}
 | 
						|
 | 
						|
type multiLoginConfig struct {
 | 
						|
	Policy            int
 | 
						|
	MaxNumOneEnd      int
 | 
						|
	CustomizeLoginNum map[int]int
 | 
						|
}
 | 
						|
 | 
						|
type authDatabase struct {
 | 
						|
	cache        cache.TokenModel
 | 
						|
	accessSecret string
 | 
						|
	accessExpire int64
 | 
						|
	multiLogin   multiLoginConfig
 | 
						|
}
 | 
						|
 | 
						|
func NewAuthDatabase(cache cache.TokenModel, accessSecret string, accessExpire int64, multiLogin config.MultiLogin) AuthDatabase {
 | 
						|
	return &authDatabase{cache: cache, accessSecret: accessSecret, accessExpire: accessExpire, multiLogin: multiLoginConfig{
 | 
						|
		Policy:       multiLogin.Policy,
 | 
						|
		MaxNumOneEnd: multiLogin.MaxNumOneEnd,
 | 
						|
		CustomizeLoginNum: map[int]int{
 | 
						|
			constant.IOSPlatformID:        multiLogin.CustomizeLoginNum.IOS,
 | 
						|
			constant.AndroidPlatformID:    multiLogin.CustomizeLoginNum.Android,
 | 
						|
			constant.WindowsPlatformID:    multiLogin.CustomizeLoginNum.Windows,
 | 
						|
			constant.OSXPlatformID:        multiLogin.CustomizeLoginNum.OSX,
 | 
						|
			constant.WebPlatformID:        multiLogin.CustomizeLoginNum.Web,
 | 
						|
			constant.MiniWebPlatformID:    multiLogin.CustomizeLoginNum.MiniWeb,
 | 
						|
			constant.LinuxPlatformID:      multiLogin.CustomizeLoginNum.Linux,
 | 
						|
			constant.AndroidPadPlatformID: multiLogin.CustomizeLoginNum.APad,
 | 
						|
			constant.IPadPlatformID:       multiLogin.CustomizeLoginNum.IPad,
 | 
						|
			constant.AdminPlatformID:      multiLogin.CustomizeLoginNum.Admin,
 | 
						|
		},
 | 
						|
	}}
 | 
						|
}
 | 
						|
 | 
						|
// If the result is empty.
 | 
						|
func (a *authDatabase) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
 | 
						|
	return a.cache.GetTokensWithoutError(ctx, userID, platformID)
 | 
						|
}
 | 
						|
 | 
						|
func (a *authDatabase) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
 | 
						|
	return a.cache.SetTokenMapByUidPid(ctx, userID, platformID, m)
 | 
						|
}
 | 
						|
 | 
						|
func (a *authDatabase) BatchSetTokenMapByUidPid(ctx context.Context, tokens []string) error {
 | 
						|
	setMap := make(map[string]map[string]int)
 | 
						|
	for _, token := range tokens {
 | 
						|
		claims, err := tokenverify.GetClaimFromToken(token, authverify.Secret(a.accessSecret))
 | 
						|
		key := cachekey.GetTokenKey(claims.UserID, claims.PlatformID)
 | 
						|
		if err != nil {
 | 
						|
			continue
 | 
						|
		} else {
 | 
						|
			if v, ok := setMap[key]; ok {
 | 
						|
				v[token] = constant.KickedToken
 | 
						|
			} else {
 | 
						|
				setMap[key] = map[string]int{
 | 
						|
					token: constant.KickedToken,
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if err := a.cache.BatchSetTokenMapByUidPid(ctx, setMap); err != nil {
 | 
						|
		return err
 | 
						|
	}
 | 
						|
	return nil
 | 
						|
}
 | 
						|
 | 
						|
// Create Token.
 | 
						|
func (a *authDatabase) CreateToken(ctx context.Context, userID string, platformID int) (string, error) {
 | 
						|
	tokens, err := a.cache.GetAllTokensWithoutError(ctx, userID)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	deleteTokenKey, kickedTokenKey, err := a.checkToken(ctx, tokens, platformID)
 | 
						|
	if err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	if len(deleteTokenKey) != 0 {
 | 
						|
		err = a.cache.DeleteTokenByUidPid(ctx, userID, platformID, deleteTokenKey)
 | 
						|
		if err != nil {
 | 
						|
			return "", err
 | 
						|
		}
 | 
						|
	}
 | 
						|
	if len(kickedTokenKey) != 0 {
 | 
						|
		for _, k := range kickedTokenKey {
 | 
						|
			err := a.cache.SetTokenFlagEx(ctx, userID, platformID, k, constant.KickedToken)
 | 
						|
			if err != nil {
 | 
						|
				return "", err
 | 
						|
			}
 | 
						|
			log.ZDebug(ctx, "kicked token in create token", "token", k)
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	claims := tokenverify.BuildClaims(userID, platformID, a.accessExpire)
 | 
						|
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
 | 
						|
	tokenString, err := token.SignedString([]byte(a.accessSecret))
 | 
						|
	if err != nil {
 | 
						|
		return "", errs.WrapMsg(err, "token.SignedString")
 | 
						|
	}
 | 
						|
 | 
						|
	if err = a.cache.SetTokenFlagEx(ctx, userID, platformID, tokenString, constant.NormalToken); err != nil {
 | 
						|
		return "", err
 | 
						|
	}
 | 
						|
	return tokenString, nil
 | 
						|
}
 | 
						|
 | 
						|
func (a *authDatabase) checkToken(ctx context.Context, tokens map[int]map[string]int, platformID int) ([]string, []string, error) {
 | 
						|
	// todo: Move the logic for handling old data to another location.
 | 
						|
	var (
 | 
						|
		loginTokenMap  = make(map[int][]string) // The length of the value of the map must be greater than 0
 | 
						|
		deleteToken    = make([]string, 0)
 | 
						|
		kickToken      = make([]string, 0)
 | 
						|
		adminToken     = make([]string, 0)
 | 
						|
		unkickTerminal = ""
 | 
						|
	)
 | 
						|
 | 
						|
	for plfID, tks := range tokens {
 | 
						|
		for k, v := range tks {
 | 
						|
			_, err := tokenverify.GetClaimFromToken(k, authverify.Secret(a.accessSecret))
 | 
						|
			if err != nil || v != constant.NormalToken {
 | 
						|
				deleteToken = append(deleteToken, k)
 | 
						|
			} else {
 | 
						|
				if plfID != constant.AdminPlatformID {
 | 
						|
					loginTokenMap[plfID] = append(loginTokenMap[plfID], k)
 | 
						|
				} else {
 | 
						|
					adminToken = append(adminToken, k)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	}
 | 
						|
 | 
						|
	switch a.multiLogin.Policy {
 | 
						|
	case constant.DefalutNotKick:
 | 
						|
		for plt, ts := range loginTokenMap {
 | 
						|
			l := len(ts)
 | 
						|
			if platformID == plt {
 | 
						|
				l++
 | 
						|
			}
 | 
						|
			limit := a.multiLogin.MaxNumOneEnd
 | 
						|
			if l > limit {
 | 
						|
				kickToken = append(kickToken, ts[:l-limit]...)
 | 
						|
			}
 | 
						|
		}
 | 
						|
	case constant.AllLoginButSameTermKick:
 | 
						|
		for plt, ts := range loginTokenMap {
 | 
						|
			kickToken = append(kickToken, ts[:len(ts)-1]...)
 | 
						|
			if plt == platformID {
 | 
						|
				kickToken = append(kickToken, ts[len(ts)-1])
 | 
						|
			}
 | 
						|
		}
 | 
						|
	case constant.PCAndOther:
 | 
						|
		unkickTerminal = constant.TerminalPC
 | 
						|
		if constant.PlatformIDToClass(platformID) != unkickTerminal {
 | 
						|
			for plt, ts := range loginTokenMap {
 | 
						|
				if constant.PlatformIDToClass(plt) != unkickTerminal {
 | 
						|
					kickToken = append(kickToken, ts...)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		} else {
 | 
						|
			var (
 | 
						|
				preKick   []string
 | 
						|
				isReserve = true
 | 
						|
			)
 | 
						|
			for plt, ts := range loginTokenMap {
 | 
						|
				if constant.PlatformIDToClass(plt) != unkickTerminal {
 | 
						|
					// Keep a token from another end
 | 
						|
					if isReserve {
 | 
						|
						isReserve = false
 | 
						|
						kickToken = append(kickToken, ts[:len(ts)-1]...)
 | 
						|
						preKick = append(preKick, ts[len(ts)-1])
 | 
						|
						continue
 | 
						|
					} else {
 | 
						|
						// Prioritize keeping Android
 | 
						|
						if plt == constant.AndroidPlatformID {
 | 
						|
							kickToken = append(kickToken, preKick...)
 | 
						|
							kickToken = append(kickToken, ts[:len(ts)-1]...)
 | 
						|
						} else {
 | 
						|
							kickToken = append(kickToken, ts...)
 | 
						|
						}
 | 
						|
					}
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	case constant.AllLoginButSameClassKick:
 | 
						|
		var (
 | 
						|
			reserved = make(map[string]struct{})
 | 
						|
		)
 | 
						|
 | 
						|
		for plt, ts := range loginTokenMap {
 | 
						|
			if constant.PlatformIDToClass(plt) == constant.PlatformIDToClass(platformID) {
 | 
						|
				kickToken = append(kickToken, ts...)
 | 
						|
			} else {
 | 
						|
				if _, ok := reserved[constant.PlatformIDToClass(plt)]; !ok {
 | 
						|
					reserved[constant.PlatformIDToClass(plt)] = struct{}{}
 | 
						|
					kickToken = append(kickToken, ts[:len(ts)-1]...)
 | 
						|
					continue
 | 
						|
				} else {
 | 
						|
					kickToken = append(kickToken, ts...)
 | 
						|
				}
 | 
						|
			}
 | 
						|
		}
 | 
						|
	default:
 | 
						|
		return nil, nil, errs.New("unknown multiLogin policy").Wrap()
 | 
						|
	}
 | 
						|
 | 
						|
	var adminTokenMaxNum = a.multiLogin.MaxNumOneEnd
 | 
						|
	if a.multiLogin.Policy == constant.Customize {
 | 
						|
		adminTokenMaxNum = a.multiLogin.CustomizeLoginNum[constant.AdminPlatformID]
 | 
						|
	}
 | 
						|
	l := len(adminToken)
 | 
						|
	if platformID == constant.AdminPlatformID {
 | 
						|
		l++
 | 
						|
	}
 | 
						|
	if l > adminTokenMaxNum {
 | 
						|
		kickToken = append(kickToken, adminToken[:l-adminTokenMaxNum]...)
 | 
						|
	}
 | 
						|
	return deleteToken, kickToken, nil
 | 
						|
}
 |