diff --git a/utils/cors_middleware.go b/utils/cors_middleware.go new file mode 100644 index 000000000..1d75c4ead --- /dev/null +++ b/utils/cors_middleware.go @@ -0,0 +1,23 @@ +package utils + +import ( + "github.com/gin-gonic/gin" + "net/http" +) + +func CorsHandler() gin.HandlerFunc { + return func(context *gin.Context) { + context.Writer.Header().Set("Access-Control-Allow-Origin", "*") + context.Header("Access-Control-Allow-Methods", "*") + context.Header("Access-Control-Allow-Headers", "*") + context.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers,Cache-Control,Content-Language,Content-Type,Expires,Last-Modified,Pragma,FooBar") // 跨域关键设置 让浏览器可以解析 + context.Header("Access-Control-Max-Age", "172800") // 缓存请求信息 单位为秒 + context.Header("Access-Control-Allow-Credentials", "false") // 跨域请求是否需要带cookie信息 默认设置为true + context.Header("content-type", "application/json") // 设置返回格式是json + //Release all option pre-requests + if context.Request.Method == http.MethodOptions { + context.JSON(http.StatusOK, "Options Request!") + } + context.Next() + } +} diff --git a/utils/file.go b/utils/file.go new file mode 100644 index 000000000..15ce153b0 --- /dev/null +++ b/utils/file.go @@ -0,0 +1,22 @@ +package utils + +import "os" + +// Determine whether the given path is a folder +func IsDir(path string) bool { + s, err := os.Stat(path) + if err != nil { + return false + } + return s.IsDir() +} + +// Determine whether the given path is a file +func IsFile(path string) bool { + return !IsDir(path) +} + +// Create a directory +func MkDir(path string) error { + return os.MkdirAll(path, os.ModePerm) +} diff --git a/utils/get_server_ip.go b/utils/get_server_ip.go new file mode 100644 index 000000000..21092ffa1 --- /dev/null +++ b/utils/get_server_ip.go @@ -0,0 +1,35 @@ +package utils + +import ( + "Open_IM/src/common/config" + "net" +) + +var ServerIP = "" + +func init() { + //fixme In the configuration file, ip takes precedence, if not, get the valid network card ip of the machine + if config.Config.ServerIP != "" { + ServerIP = config.Config.ServerIP + return + } + //fixme Get the ip of the local network card + netInterfaces, err := net.Interfaces() + if err != nil { + panic(err) + } + for i := 0; i < len(netInterfaces); i++ { + //Exclude useless network cards by judging the net.flag Up flag + if (netInterfaces[i].Flags & net.FlagUp) != 0 { + address, _ := netInterfaces[i].Addrs() + for _, addr := range address { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + ServerIP = ipNet.IP.String() + return + } + } + } + } + } +} diff --git a/utils/image.go b/utils/image.go new file mode 100644 index 000000000..393136e63 --- /dev/null +++ b/utils/image.go @@ -0,0 +1,56 @@ +package utils + +import ( + "errors" + "github.com/nfnt/resize" + "golang.org/x/image/bmp" + "image" + "image/gif" + "image/jpeg" + "image/png" + "io" + "os" +) + +func GenSmallImage(src, dst string) error { + fIn, _ := os.Open(src) + defer fIn.Close() + + fOut, _ := os.Create(dst) + defer fOut.Close() + + if err := scale(fIn, fOut, 0, 0, 0); err != nil { + return err + } + return nil +} + +func scale(in io.Reader, out io.Writer, width, height, quality int) error { + origin, fm, err := image.Decode(in) + if err != nil { + return err + } + if width == 0 || height == 0 { + width = origin.Bounds().Max.X / 2 + height = origin.Bounds().Max.Y / 2 + } + if quality == 0 { + quality = 25 + } + canvas := resize.Thumbnail(uint(width), uint(height), origin, resize.Lanczos3) + + switch fm { + case "jpeg": + return jpeg.Encode(out, canvas, &jpeg.Options{quality}) + case "png": + return png.Encode(out, canvas) + case "gif": + return gif.Encode(out, canvas, &gif.Options{}) + case "bmp": + return bmp.Encode(out, canvas) + default: + return errors.New("ERROR FORMAT") + } + + return nil +} diff --git a/utils/jwt_token.go b/utils/jwt_token.go new file mode 100644 index 000000000..f60cc201e --- /dev/null +++ b/utils/jwt_token.go @@ -0,0 +1,192 @@ +package utils + +import ( + "Open_IM/src/common/config" + "Open_IM/src/common/db" + "errors" + "time" +) + +var ( + TokenExpired = errors.New("token is timed out, please log in again") + TokenInvalid = errors.New("token has been invalidated") + TokenNotValidYet = errors.New("token not active yet") + TokenMalformed = errors.New("that's not even a token") + TokenUnknown = errors.New("couldn't handle this token") +) + +type Claims struct { + UID string + Platform string //login platform + jwt.StandardClaims +} + +func BuildClaims(uid, accountAddr, platform string, ttl int64) Claims { + now := time.Now().Unix() + //if ttl=-1 Permanent token + if ttl == -1 { + return Claims{ + UID: uid, + Platform: platform, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: -1, + IssuedAt: now, + NotBefore: now, + }} + } + return Claims{ + UID: uid, + Platform: platform, + StandardClaims: jwt.StandardClaims{ + ExpiresAt: now + ttl, //Expiration time + IssuedAt: now, //Issuing time + NotBefore: now, //Begin Effective time + }} +} + +func CreateToken(userID, accountAddr string, platform int32) (string, int64, error) { + claims := BuildClaims(userID, accountAddr, PlatformIDToName(platform), config.Config.TokenPolicy.AccessExpire) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret)) + + return tokenString, claims.ExpiresAt, err +} + +func secret() jwt.Keyfunc { + return func(token *jwt.Token) (interface{}, error) { + return []byte(config.Config.TokenPolicy.AccessSecret), nil + } +} + +func ParseToken(tokensString string) (claims *Claims, err error) { + token, err := jwt.ParseWithClaims(tokensString, &Claims{}, secret()) + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + if ve.Errors&jwt.ValidationErrorMalformed != 0 { + return nil, TokenMalformed + } else if ve.Errors&jwt.ValidationErrorExpired != 0 { + return nil, TokenExpired + } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { + return nil, TokenNotValidYet + } else { + return nil, TokenUnknown + } + } + } + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + // 1.check userid and platform class 0 not exists and 1 exists + existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform]) + if err != nil { + return nil, err + } + exists := existsInterface.(int64) + //get config multi login policy + if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess { + //OnlyOneTerminalAccess policy need to check all terminal + //When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login, + //mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token. + if Platform2class[claims.Platform] == "PC" { + existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "Mobile") + if err != nil { + return nil, err + } + exists = existsInterface.(int64) + if exists == 1 { + res, err := MakeTheTokenInvalid(*claims, "Mobile") + if err != nil { + return nil, err + } + if res { + return nil, TokenInvalid + } + } + } else { + existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "PC") + if err != nil { + return nil, err + } + exists = existsInterface.(int64) + if exists == 1 { + res, err := MakeTheTokenInvalid(*claims, "PC") + if err != nil { + return nil, err + } + if res { + return nil, TokenInvalid + } + } + } + + if exists == 1 { + res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) + if err != nil { + return nil, err + } + if res { + return nil, TokenInvalid + } + } + + } else if config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther { + if exists == 1 { + res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) + if err != nil { + return nil, err + } + if res { + return nil, TokenInvalid + } + } + } + return claims, nil + } + return nil, TokenUnknown +} + +func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, error) { + storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass) + if err != nil { + return false, err + } + storedRedisPlatformClaims, err := ParseRedisInterfaceToken(storedRedisTokenInterface) + if err != nil { + return false, err + } + //if issue time less than redis token then make this token invalid + if currentClaims.IssuedAt < storedRedisPlatformClaims.IssuedAt { + return true, TokenInvalid + } + return false, nil +} +func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) { + token, err := jwt.ParseWithClaims(string(redisToken.([]uint8)), &Claims{}, secret()) + if err != nil { + if ve, ok := err.(*jwt.ValidationError); ok { + if ve.Errors&jwt.ValidationErrorMalformed != 0 { + return nil, TokenMalformed + } else if ve.Errors&jwt.ValidationErrorExpired != 0 { + return nil, TokenExpired + } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { + return nil, TokenNotValidYet + } else { + return nil, TokenInvalid + } + } + } + if claims, ok := token.Claims.(*Claims); ok && token.Valid { + return claims, nil + } + return nil, err +} + +//Validation token, false means failure, true means successful verification +func VerifyToken(token, uid string) bool { + claims, err := ParseToken(token) + if err != nil { + return false + } else if claims.UID != uid { + return false + } else { + return true + } +} diff --git a/utils/map.go b/utils/map.go new file mode 100644 index 000000000..bec647464 --- /dev/null +++ b/utils/map.go @@ -0,0 +1,119 @@ +package utils + +import ( + "encoding/json" + "sync" +) + +type Map struct { + sync.RWMutex + m map[interface{}]interface{} +} + +func (m *Map) init() { + if m.m == nil { + m.m = make(map[interface{}]interface{}) + } +} + +func (m *Map) UnsafeGet(key interface{}) interface{} { + if m.m == nil { + return nil + } else { + return m.m[key] + } +} + +func (m *Map) Get(key interface{}) interface{} { + m.RLock() + defer m.RUnlock() + return m.UnsafeGet(key) +} + +func (m *Map) UnsafeSet(key interface{}, value interface{}) { + m.init() + m.m[key] = value +} + +func (m *Map) Set(key interface{}, value interface{}) { + m.Lock() + defer m.Unlock() + m.UnsafeSet(key, value) +} + +func (m *Map) TestAndSet(key interface{}, value interface{}) interface{} { + m.Lock() + defer m.Unlock() + + m.init() + + if v, ok := m.m[key]; ok { + return v + } else { + m.m[key] = value + return nil + } +} + +func (m *Map) UnsafeDel(key interface{}) { + m.init() + delete(m.m, key) +} + +func (m *Map) Del(key interface{}) { + m.Lock() + defer m.Unlock() + m.UnsafeDel(key) +} + +func (m *Map) UnsafeLen() int { + if m.m == nil { + return 0 + } else { + return len(m.m) + } +} + +func (m *Map) Len() int { + m.RLock() + defer m.RUnlock() + return m.UnsafeLen() +} + +func (m *Map) UnsafeRange(f func(interface{}, interface{})) { + if m.m == nil { + return + } + for k, v := range m.m { + f(k, v) + } +} + +func (m *Map) RLockRange(f func(interface{}, interface{})) { + m.RLock() + defer m.RUnlock() + m.UnsafeRange(f) +} + +func (m *Map) LockRange(f func(interface{}, interface{})) { + m.Lock() + defer m.Unlock() + m.UnsafeRange(f) +} + +func MapToJsonString(param map[string]interface{}) string { + dataType, _ := json.Marshal(param) + dataString := string(dataType) + return dataString +} +func JsonStringToMap(str string) map[string]interface{} { + var tempMap map[string]interface{} + _ = json.Unmarshal([]byte(str), &tempMap) + return tempMap +} +func GetSwitchFromOptions(Options map[string]interface{}, key string) (result bool) { + if flag, ok := Options[key]; !ok || flag == 1 { + return true + } + return false +} diff --git a/utils/md5.go b/utils/md5.go new file mode 100644 index 000000000..8e3531668 --- /dev/null +++ b/utils/md5.go @@ -0,0 +1,13 @@ +package utils + +import ( + "crypto/md5" + "encoding/hex" +) + +func Md5(s string) string { + h := md5.New() + h.Write([]byte(s)) + cipher := h.Sum(nil) + return hex.EncodeToString(cipher) +} diff --git a/utils/platform_number_id_to_name.go b/utils/platform_number_id_to_name.go new file mode 100644 index 000000000..e9a3f7fed --- /dev/null +++ b/utils/platform_number_id_to_name.go @@ -0,0 +1,66 @@ +package utils + +// fixme 1<--->IOS 2<--->Android 3<--->Windows +//fixme 4<--->OSX 5<--->Web 6<--->MiniWeb 7<--->Linux + +const ( + //Platform ID + IOSPlatformID = 1 + AndroidPlatformID = 2 + WindowsPlatformID = 3 + OSXPlatformID = 4 + WebPlatformID = 5 + MiniWebPlatformID = 6 + LinuxPlatformID = 7 + + //Platform string match to Platform ID + IOSPlatformStr = "IOS" + AndroidPlatformStr = "Android" + WindowsPlatformStr = "Windows" + OSXPlatformStr = "OSX" + WebPlatformStr = "Web" + MiniWebPlatformStr = "MiniWeb" + LinuxPlatformStr = "Linux" + + //terminal types + TerminalPC = "PC" + TerminalMobile = "Mobile" +) + +var PlatformID2Name = map[int32]string{ + IOSPlatformID: IOSPlatformStr, + AndroidPlatformID: AndroidPlatformStr, + WindowsPlatformID: WindowsPlatformStr, + OSXPlatformID: OSXPlatformStr, + WebPlatformID: WebPlatformStr, + MiniWebPlatformID: MiniWebPlatformStr, + LinuxPlatformID: LinuxPlatformStr, +} +var PlatformName2ID = map[string]int32{ + IOSPlatformStr: IOSPlatformID, + AndroidPlatformStr: AndroidPlatformID, + WindowsPlatformStr: WindowsPlatformID, + OSXPlatformStr: OSXPlatformID, + WebPlatformStr: WebPlatformID, + MiniWebPlatformStr: MiniWebPlatformID, + LinuxPlatformStr: LinuxPlatformID, +} +var Platform2class = map[string]string{ + IOSPlatformStr: TerminalMobile, + AndroidPlatformStr: TerminalMobile, + MiniWebPlatformStr: TerminalMobile, + WindowsPlatformStr: TerminalPC, + OSXPlatformStr: TerminalPC, + WebPlatformStr: TerminalPC, + LinuxPlatformStr: TerminalPC, +} + +func PlatformIDToName(num int32) string { + return PlatformID2Name[num] +} +func PlatformNameToID(name string) int32 { + return PlatformName2ID[name] +} +func PlatformNameToClass(name string) string { + return Platform2class[name] +} diff --git a/utils/strings.go b/utils/strings.go new file mode 100644 index 000000000..e117f3cc9 --- /dev/null +++ b/utils/strings.go @@ -0,0 +1,41 @@ +/* +** description(""). +** copyright('tuoyun,www.tuoyun.net'). +** author("fg,Gordon@tuoyun.net"). +** time(2021/4/8 15:09). + */ +package utils + +import "strconv" + +func IntToString(i int) string { + return strconv.FormatInt(int64(i), 10) +} + +func StringToInt(i string) int { + j, _ := strconv.Atoi(i) + return j +} +func StringToInt64(i string) int64 { + j, _ := strconv.ParseInt(i, 10, 64) + return j +} + +//judge a string whether in the string list +func IsContain(target string, List []string) bool { + + for _, element := range List { + + if target == element { + return true + } + } + return false + +} +func InterfaceArrayToStringArray(data []interface{}) (i []string) { + for _, param := range data { + i = append(i, param.(string)) + } + return i +} diff --git a/utils/time_format.go b/utils/time_format.go new file mode 100644 index 000000000..91e6ddd3b --- /dev/null +++ b/utils/time_format.go @@ -0,0 +1,72 @@ +/* +** description(""). +** copyright('tuoyun,www.tuoyun.net'). +** author("fg,Gordon@tuoyun.net"). +** time(2021/2/22 11:52). + */ +package utils + +import ( + "strconv" + "time" +) + +const ( + TimeOffset = 8 * 3600 //8 hour offset + HalfOffset = 12 * 3600 //Half-day hourly offset +) + +//Get the current timestamp by Second +func GetCurrentTimestampBySecond() int64 { + return time.Now().Unix() +} + +//Convert timestamp to time.Time type +func UnixSecondToTime(second int64) time.Time { + return time.Unix(second, 0) +} + +//Get the current timestamp by Nano +func GetCurrentTimestampByNano() int64 { + return time.Now().UnixNano() +} + +//Get the current timestamp by Mill +func GetCurrentTimestampByMill() int64 { + return time.Now().UnixNano() / 1e6 +} + +//Get the timestamp at 0 o'clock of the day +func GetCurDayZeroTimestamp() int64 { + timeStr := time.Now().Format("2006-01-02") + t, _ := time.Parse("2006-01-02", timeStr) + return t.Unix() - TimeOffset +} + +//Get the timestamp at 12 o'clock on the day +func GetCurDayHalfTimestamp() int64 { + return GetCurDayZeroTimestamp() + HalfOffset + +} + +//Get the formatted time at 0 o'clock of the day, the format is "2006-01-02_00-00-00" +func GetCurDayZeroTimeFormat() string { + return time.Unix(GetCurDayZeroTimestamp(), 0).Format("2006-01-02_15-04-05") +} + +//Get the formatted time at 12 o'clock of the day, the format is "2006-01-02_12-00-00" +func GetCurDayHalfTimeFormat() string { + return time.Unix(GetCurDayZeroTimestamp()+HalfOffset, 0).Format("2006-01-02_15-04-05") +} +func GetTimeStampByFormat(datetime string) string { + timeLayout := "2006-01-02 15:04:05" + loc, _ := time.LoadLocation("Local") + tmp, _ := time.ParseInLocation(timeLayout, datetime, loc) + timestamp := tmp.Unix() + return strconv.FormatInt(timestamp, 10) +} + +func TimeStringFormatTimeUnix(timeFormat string, timeSrc string) int64 { + tm, _ := time.Parse(timeFormat, timeSrc) + return tm.Unix() +}