package cont

import (
	"context"
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"github.com/OpenIMSDK/Open-IM-Server/pkg/common/db/s3"
	"github.com/google/uuid"
	"path"
	"strings"
	"time"
)

// New returns a Controller whose storage operations are all delegated to impl.
func New(impl s3.Interface) *Controller {
	c := new(Controller)
	c.impl = impl
	return c
}

// Controller manages uploads of objects addressed by their content hash,
// delegating storage work to an s3.Interface backend. Construct it with New.
type Controller struct {
	// impl is the storage backend every method forwards to.
	impl s3.Interface
}

// HashPath returns the object key used for content with the given hash:
// the package-level hashPath prefix joined with md5 using slash separators.
func (c *Controller) HashPath(md5 string) string {
	key := path.Join(hashPath, md5)
	return key
}

// NowPath returns the current local time as a slash-separated key fragment
// of the form YYYY/MM/DD/hh/mm/ss (zero-padded, 24-hour clock). It is used
// as part of temporary upload keys.
func (c *Controller) NowPath() string {
	const layout = "2006/01/02/15/04/05"
	return time.Now().Format(layout)
}

// UUID returns a freshly generated UUID rendered as 32 lowercase hex
// characters with no dashes.
func (c *Controller) UUID() string {
	id := uuid.New()
	buf := make([]byte, hex.EncodedLen(len(id)))
	hex.Encode(buf, id[:])
	return string(buf)
}

// PartSize asks the backend which part size to use for an object of the
// given size in bytes.
func (c *Controller) PartSize(ctx context.Context, size int64) (int64, error) {
	backend := c.impl
	return backend.PartSize(ctx, size)
}

// PartLimit reports the backend's multipart limits as returned by the
// underlying s3.Interface.
func (c *Controller) PartLimit() *s3.PartLimit {
	backend := c.impl
	return backend.PartLimit()
}

// GetHashObject stats the object stored under the hash path for hash.
// Backend "not found" errors are returned unchanged; test them with IsNotFound.
func (c *Controller) GetHashObject(ctx context.Context, hash string) (*s3.ObjectInfo, error) {
	key := c.HashPath(hash)
	return c.impl.StatObject(ctx, key)
}

// InitiateUpload starts uploading an object identified by its MD5 hash.
//
// hash must be the hex encoding of a 16-byte MD5 value and size must be
// non-negative. If an object already exists at HashPath(hash), a
// *HashAlreadyExistsError carrying its info is returned and no upload is
// started.
//
// When size fits in a single part (size <= part size), the upload is a
// presigned PUT to a unique temporary key, valid for expire. The result
// carries that single signed part.
//
// Otherwise a multipart upload is initiated on HashPath(hash), and maxParts
// controls how many part URLs are signed up front (valid for 24 hours):
//   - maxParts < 0 signs every part;
//   - maxParts == 0 signs none (sign later with AuthSign);
//   - maxParts > 0 signs parts 1..maxParts.
//
// A positive maxParts larger than the computed number of parts is rejected.
func (c *Controller) InitiateUpload(ctx context.Context, hash string, size int64, expire time.Duration, maxParts int) (*InitiateUploadResult, error) {
	if size < 0 {
		return nil, errors.New("invalid size")
	}
	if hashBytes, err := hex.DecodeString(hash); err != nil {
		return nil, err
	} else if len(hashBytes) != md5.Size {
		return nil, errors.New("invalid md5")
	}
	partSize, err := c.impl.PartSize(ctx, size)
	if err != nil {
		return nil, err
	}
	// Guard the integer divisions below against a misbehaving backend.
	if partSize <= 0 {
		return nil, fmt.Errorf("invalid part size: %d", partSize)
	}
	partNumber := int(size / partSize)
	if size%partSize > 0 {
		partNumber++
	}
	// The caller may not ask for more signed parts than the upload has.
	if maxParts > 0 && partNumber > 0 && partNumber < maxParts {
		return nil, fmt.Errorf("too many parts: %d requested, upload has %d", maxParts, partNumber)
	}
	if info, err := c.impl.StatObject(ctx, c.HashPath(hash)); err == nil {
		return nil, &HashAlreadyExistsError{Object: info}
	} else if !c.impl.IsNotFound(err) {
		return nil, err
	}

	if size <= partSize {
		// Presigned single-part upload to a unique temporary key. The
		// timestamp prefix and random UUID keep concurrent uploads apart.
		key := path.Join(tempPath, c.NowPath(), fmt.Sprintf("%s_%d_%s.presigned", hash, size, c.UUID()))
		rawURL, err := c.impl.PresignedPutObject(ctx, key, expire)
		if err != nil {
			return nil, err
		}
		return &InitiateUploadResult{
			UploadID: newMultipartUploadID(multipartUploadID{
				Type: UploadTypePresigned,
				ID:   "",
				Key:  key,
				Size: size,
				Hash: hash,
			}),
			PartSize: partSize,
			Sign: &s3.AuthSignResult{
				Parts: []s3.SignPart{
					{
						PartNumber: 1,
						URL:        rawURL,
					},
				},
			},
		}, nil
	}

	// Multipart upload directly on the final hash path.
	upload, err := c.impl.InitiateMultipartUpload(ctx, c.HashPath(hash))
	if err != nil {
		return nil, err
	}
	if maxParts < 0 {
		maxParts = partNumber
	}
	var authSign *s3.AuthSignResult
	if maxParts > 0 {
		// Size the slice by maxParts. Sizing it by partNumber would leave
		// trailing zero entries that AuthSign would receive as part 0.
		partNumbers := make([]int, maxParts)
		for i := range partNumbers {
			partNumbers[i] = i + 1 // part numbers are 1-based
		}
		authSign, err = c.impl.AuthSign(ctx, upload.UploadID, upload.Key, time.Hour*24, partNumbers)
		if err != nil {
			return nil, err
		}
	}
	return &InitiateUploadResult{
		UploadID: newMultipartUploadID(multipartUploadID{
			Type: UploadTypeMultipart,
			ID:   upload.UploadID,
			Key:  upload.Key,
			Size: size,
			Hash: hash,
		}),
		PartSize: partSize,
		Sign:     authSign,
	}, nil
}

// CompleteUpload finalizes an upload started by InitiateUpload.
//
// partHashs are the ETags of the uploaded parts in part order. The MD5 of
// their comma-joined string must equal the hash the upload was initiated
// with. If an object already exists at the hash path, its info is returned
// directly.
//
// For multipart uploads the parts are assembled on the backend. For
// presigned uploads the temporary object is verified (size and hash),
// snapshotted, and copied to the hash path. Temporary objects created or
// consumed here are deleted when the function returns (best effort).
func (c *Controller) CompleteUpload(ctx context.Context, uploadID string, partHashs []string) (*UploadResult, error) {
	upload, err := parseMultipartUploadID(uploadID)
	if err != nil {
		return nil, err
	}
	md5Sum := md5.Sum([]byte(strings.Join(partHashs, ",")))
	if sum := hex.EncodeToString(md5Sum[:]); sum != upload.Hash {
		// Report the values in the error instead of printing them to stdout.
		return nil, fmt.Errorf("md5 mismatching: parts sum %s, upload hash %s", sum, upload.Hash)
	}
	if info, err := c.impl.StatObject(ctx, c.HashPath(upload.Hash)); err == nil {
		return &UploadResult{
			Key:  info.Key,
			Size: info.Size,
			Hash: info.ETag,
		}, nil
	} else if !c.impl.IsNotFound(err) {
		return nil, err
	}

	// Temporary object keys to delete on return, on success or failure.
	// Deletion is best effort, so its errors are deliberately ignored.
	// NOTE(review): this reuses ctx, so cleanup is skipped if ctx is already
	// cancelled.
	cleanObject := make(map[string]struct{})
	defer func() {
		for key := range cleanObject {
			_ = c.impl.DeleteObject(ctx, key)
		}
	}()

	var targetKey string
	switch upload.Type {
	case UploadTypeMultipart:
		parts := make([]s3.Part, len(partHashs))
		for i, part := range partHashs {
			parts[i] = s3.Part{
				PartNumber: i + 1,
				ETag:       part,
			}
		}
		// TODO: verify the size of the completed object.
		result, err := c.impl.CompleteMultipartUpload(ctx, upload.ID, upload.Key, parts)
		if err != nil {
			return nil, err
		}
		targetKey = result.Key
	case UploadTypePresigned:
		uploadInfo, err := c.impl.StatObject(ctx, upload.Key)
		if err != nil {
			return nil, err
		}
		// The presigned object is temporary whatever happens next.
		cleanObject[uploadInfo.Key] = struct{}{}
		if uploadInfo.Size != upload.Size {
			return nil, errors.New("upload size mismatching")
		}
		// Same rule as the multipart case with a single part: the upload
		// hash is the MD5 of the object's ETag string.
		etagSum := md5.Sum([]byte(uploadInfo.ETag))
		if hex.EncodeToString(etagSum[:]) != upload.Hash {
			return nil, errors.New("upload md5 mismatching")
		}
		// Snapshot the verified object to a private key first. This stops a
		// concurrent PUT to the presigned URL from overwriting the content
		// between the check above and the final copy.
		copyInfo, err := c.impl.CopyObject(ctx, uploadInfo.Key, upload.Key+"."+c.UUID())
		if err != nil {
			return nil, err
		}
		cleanObject[copyInfo.Key] = struct{}{}
		// The snapshot must carry the ETag that was just verified.
		// Otherwise the object was replaced before it was copied.
		if copyInfo.ETag != uploadInfo.ETag {
			return nil, errors.New("copy md5 mismatching")
		}
		hashInfo, err := c.impl.CopyObject(ctx, copyInfo.Key, c.HashPath(upload.Hash))
		if err != nil {
			return nil, err
		}
		// copyInfo.Key is removed by the deferred cleanup, so return the
		// permanent hash-path object instead.
		targetKey = hashInfo.Key
	default:
		return nil, errors.New("invalid upload id type")
	}
	return &UploadResult{
		Key:  targetKey,
		Size: upload.Size,
		Hash: upload.Hash,
	}, nil
}

// AuthSign signs the given part numbers of a multipart upload so the client
// can upload those parts directly. The signatures are valid for 24 hours.
// Upload IDs from presigned (single-part) uploads and unknown types are
// rejected.
func (c *Controller) AuthSign(ctx context.Context, uploadID string, partNumbers []int) (*s3.AuthSignResult, error) {
	upload, err := parseMultipartUploadID(uploadID)
	if err != nil {
		return nil, err
	}
	const signExpire = 24 * time.Hour
	if upload.Type == UploadTypeMultipart {
		return c.impl.AuthSign(ctx, upload.ID, upload.Key, signExpire, partNumbers)
	}
	if upload.Type == UploadTypePresigned {
		return nil, errors.New("presigned id not support auth sign")
	}
	return nil, errors.New("invalid upload id type")
}

// IsNotFound reports whether err is the backend's "object not found" error,
// as decided by the underlying s3.Interface.
func (c *Controller) IsNotFound(err error) bool {
	notFound := c.impl.IsNotFound(err)
	return notFound
}

// AccessURL asks the backend for a URL through which the object name can be
// accessed for the duration expire. opt is passed through unchanged.
func (c *Controller) AccessURL(ctx context.Context, name string, expire time.Duration, opt *s3.AccessURLOption) (string, error) {
	backend := c.impl
	return backend.AccessURL(ctx, name, expire, opt)
}