2026-05-15 11:31:46 +08:00

1108 lines
39 KiB
Go

package redpacket
import (
"context"
"encoding/hex"
"fmt"
"math/big"
"strings"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
"github.com/google/uuid"
"github.com/openimsdk/open-im-server/v3/internal/rpc/redpacket/chain"
"github.com/openimsdk/open-im-server/v3/pkg/common/servererrs"
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/model"
pbredpacket "github.com/openimsdk/protocol/redpacket"
"github.com/openimsdk/tools/errs"
"github.com/openimsdk/tools/log"
"github.com/openimsdk/tools/mcontext"
)
// CreateOrder registers a new red packet order in PENDING state and returns
// its freshly generated business ID.
//
// Validation runs in this order: caller identity, chain type, scope fields,
// then the per-packet-type create hook. When the request leaves the chain ID
// (EVM only) or the contract address empty, the value is taken from the
// configured chain client, if there is one.
func (s *redPacketServer) CreateOrder(ctx context.Context, req *pbredpacket.CreateOrderReq) (*pbredpacket.CreateOrderResp, error) {
	opUserID := mcontext.GetOpUserID(ctx)
	if opUserID == "" {
		return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}
	orderID := uuid.NewString()

	chainKind, err := normalizeChainType(req.ChainType)
	if err != nil {
		return nil, err
	}
	scope := normalizeScopeType(req.ScopeType)
	if err = validateCreateScope(scope, req.GroupID, req.ReceiverUserID, req.ReceiverUserIDs); err != nil {
		return nil, err
	}
	if err = s.validateCreateHook(ctx, req); err != nil {
		return nil, err
	}

	// Fill in chain defaults from whichever client matches the chain type.
	resolvedChainID := req.ChainID
	contract := strings.TrimSpace(req.ContractAddress)
	switch {
	case chainKind == "EVM" && s.chainClient != nil:
		if resolvedChainID == 0 {
			if id := s.chainClient.ChainID(); id != nil {
				resolvedChainID = id.Int64()
			}
		}
		if contract == "" {
			contract = s.chainClient.ContractAddress().Hex()
		}
	case chainKind == "TRON" && s.tronClient != nil && contract == "":
		contract = s.tronClient.ContractAddress()
	}

	record := &model.RedPacket{
		BizID:           orderID,
		ChainType:       chainKind,
		ChainID:         resolvedChainID,
		ContractAddress: contract,
		CreatorUserID:   opUserID,
		CreatorWallet:   req.CreatorWallet,
		GroupID:         req.GroupID,
		ScopeType:       scope,
		ReceiverUserID:  req.ReceiverUserID,
		// Copy the slice so the stored record does not alias the request.
		ReceiverUserIDs: append([]string(nil), req.ReceiverUserIDs...),
		PacketType:      req.PacketType,
		Token:           req.Token,
		TotalAmount:     req.TotalAmount,
		TotalShares:     req.TotalShares,
		ExpiryAt:        req.ExpiryAt,
		Status:          "PENDING",
		CreatedAt:       time.Now(),
		UpdatedAt:       time.Now(),
	}
	if err = s.db.CreateRedPacket(ctx, record); err != nil {
		// The underlying error is logged here; the caller only sees a generic database error.
		log.ZError(ctx, "create redpacket failed", err, "bizID", orderID)
		return nil, servererrs.ErrDatabase.WrapMsg("failed to create red packet")
	}

	resp := &pbredpacket.CreateOrderResp{BizID: orderID}
	return resp, nil
}
// CreatedCallback is called by the creator after the on-chain creation
// transaction has been sent. It resolves the created packet (from the chain
// receipt, or from the caller-supplied packet_id when no chain client is
// configured) and stores the packet as ACTIVE.
//
// Scope fields (group, scope type, receivers) in the request override the
// stored ones when non-empty; the merged values are re-checked with
// validateCreateScope only.
//
// NOTE(review): overridden scope values are not passed through
// validateCreateHook again, and rp.Status is not inspected here. Confirm that
// UpdateRedPacketCreated guards against re-activating a packet that is no
// longer pending.
func (s *redPacketServer) CreatedCallback(ctx context.Context, req *pbredpacket.CreatedCallbackReq) (*pbredpacket.CreatedCallbackResp, error) {
	callerID := mcontext.GetOpUserID(ctx)
	if callerID == "" {
		return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}
	if strings.TrimSpace(req.BizID) == "" || strings.TrimSpace(req.TxHash) == "" {
		return nil, errs.ErrArgs.WrapMsg("biz_id and tx_hash are required")
	}

	stored, err := s.db.GetRedPacketByBizID(ctx, req.BizID)
	if err != nil {
		return nil, err
	}
	if stored.CreatorUserID != callerID {
		return nil, servererrs.ErrNoPermission.WrapMsg("only the creator can submit the creation callback")
	}

	// Request values win over stored values when they are non-blank.
	mergedGroupID := firstNonEmpty(req.GroupID, stored.GroupID)
	mergedScope := normalizeScopeType(firstNonEmpty(req.ScopeType, stored.ScopeType))
	mergedReceiver := firstNonEmpty(req.ReceiverUserID, stored.ReceiverUserID)
	mergedReceivers := stored.ReceiverUserIDs
	if len(req.ReceiverUserIDs) > 0 {
		mergedReceivers = append([]string(nil), req.ReceiverUserIDs...)
	}
	if err = validateCreateScope(mergedScope, mergedGroupID, mergedReceiver, mergedReceivers); err != nil {
		return nil, err
	}

	snapshot, err := s.resolveCreatedPacket(ctx, stored, req.TxHash, req.PacketID)
	if err != nil {
		return nil, err
	}

	update := &model.RedPacket{
		BizID:           req.BizID,
		ChainType:       stored.ChainType,
		PacketID:        snapshot.PacketID,
		ChainID:         snapshot.ChainID,
		ContractAddress: snapshot.ContractAddress,
		CreatorWallet:   snapshot.CreatorWallet,
		PacketType:      snapshot.PacketType,
		Token:           snapshot.Token,
		TotalAmount:     snapshot.TotalAmount,
		TotalShares:     snapshot.TotalShares,
		ExpiryAt:        snapshot.ExpiryAt,
		TxHash:          req.TxHash,
		GroupID:         mergedGroupID,
		ScopeType:       mergedScope,
		ReceiverUserID:  mergedReceiver,
		ReceiverUserIDs: mergedReceivers,
		Status:          "ACTIVE",
	}
	if err = s.db.UpdateRedPacketCreated(ctx, update); err != nil {
		return nil, err
	}
	return &pbredpacket.CreatedCallbackResp{}, nil
}
// GetDetail returns the stored red packet identified by req.PacketID together
// with its claim records.
//
// The packet lookup is mandatory and its error is returned to the caller. The
// claim list is best-effort: if loading it fails, the failure is logged and the
// response is returned with no claims, so the packet itself can still be shown.
func (s *redPacketServer) GetDetail(ctx context.Context, req *pbredpacket.GetDetailReq) (*pbredpacket.GetDetailResp, error) {
	if strings.TrimSpace(req.PacketID) == "" {
		return nil, errs.ErrArgs.WrapMsg("packet_id is required")
	}
	rp, err := s.db.GetRedPacketByPacketID(ctx, req.PacketID)
	if err != nil {
		return nil, err
	}
	claims, err := s.db.GetClaimsByPacketID(ctx, req.PacketID)
	if err != nil {
		// Keep the best-effort behaviour, but leave a trace instead of
		// silently dropping the error.
		log.ZWarn(ctx, "get claims by packet id failed", err, "packetID", req.PacketID)
		claims = nil
	}
	return &pbredpacket.GetDetailResp{
		Record: redPacketModelToProto(rp),
		Claims: claimsModelToProto(claims),
	}, nil
}
// IssueClaimSign checks that the caller may claim the packet and, if so,
// returns a server signature authorising the claim.
//
// Steps:
//  1. require an authenticated caller plus packet_id and claimer;
//  2. run the eligibility checks in canClaim;
//  3. parse packet_id and the optional random_seed as non-negative base-10
//     integers (a missing or "0" seed is replaced with the current UnixNano);
//  4. obtain the digest from the EVM chain client, or hash a local fallback
//     string when no chain client is configured;
//  5. sign the digest with the server key and persist the authorisation.
//
// The authorisation expires five minutes after it is issued.
//
// NOTE(review): the digest is always requested from s.chainClient and the
// claimer is converted with common.HexToAddress regardless of the packet's
// chain type — confirm this is intended for TRON packets.
func (s *redPacketServer) IssueClaimSign(ctx context.Context, req *pbredpacket.IssueClaimSignReq) (*pbredpacket.IssueClaimSignResp, error) {
	currentUserID := mcontext.GetOpUserID(ctx)
	if currentUserID == "" {
		return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}
	if strings.TrimSpace(req.PacketID) == "" || strings.TrimSpace(req.Claimer) == "" {
		return nil, errs.ErrArgs.WrapMsg("packet_id and claimer are required")
	}
	if err := s.canClaim(ctx, req.PacketID, req.Claimer, currentUserID); err != nil {
		return nil, err
	}

	// Packet IDs are derived from on-chain unsigned values, so a negative
	// number can never identify a real packet.
	packetIDBig, ok := new(big.Int).SetString(req.PacketID, 10)
	if !ok || packetIDBig.Sign() < 0 {
		return nil, errs.ErrArgs.WrapMsg("invalid packet_id", "packetID", req.PacketID)
	}
	claimerAddr := common.HexToAddress(req.Claimer)

	randomSeedBig := new(big.Int)
	if req.RandomSeed != "" && req.RandomSeed != "0" {
		// Reject negative seeds up front: the signed value and the seed echoed
		// back to the client must be the same number.
		if _, ok := randomSeedBig.SetString(req.RandomSeed, 10); !ok || randomSeedBig.Sign() < 0 {
			return nil, errs.ErrArgs.WrapMsg("invalid random_seed", "randomSeed", req.RandomSeed)
		}
	} else {
		randomSeedBig.SetInt64(time.Now().UnixNano())
	}

	// Fail fast before any chain RPC: without a key nothing can be signed.
	if s.signerKey == nil {
		return nil, errs.ErrInternalServer.WrapMsg("signer key not configured; cannot issue claim signature")
	}

	// Build the nonce directly from the integer instead of formatting it and
	// parsing it back with an unchecked SetString.
	issuedAt := time.Now()
	authNonceBig := big.NewInt(issuedAt.UnixNano())
	nonce := authNonceBig.String()
	deadline := issuedAt.Add(5 * time.Minute).Unix()
	deadlineBig := big.NewInt(deadline)

	var digest [32]byte
	var err error
	if s.chainClient != nil {
		digest, err = s.chainClient.GetSignMessage(ctx, packetIDBig, claimerAddr, authNonceBig, randomSeedBig, deadlineBig)
		if err != nil {
			return nil, errs.ErrInternalServer.WrapMsg("getSignMessage failed: " + err.Error())
		}
	} else {
		// No chain client configured: hash a locally built string instead.
		digest = crypto.Keccak256Hash([]byte(fmt.Sprintf("%s:%s:%s:%s:%d", req.PacketID, req.Claimer, nonce, randomSeedBig.String(), deadline)))
	}

	signature, err := crypto.Sign(digest[:], s.signerKey)
	if err != nil {
		return nil, errs.ErrInternalServer.WrapMsg("sign failed: " + err.Error())
	}
	// Shift the trailing recovery byte from 0/1 to 27/28.
	if len(signature) == 65 && signature[64] < 27 {
		signature[64] += 27
	}
	sigHex := "0x" + hex.EncodeToString(signature)

	auth := &model.RedPacketClaimAuth{
		PacketID:   req.PacketID,
		Claimer:    req.Claimer,
		AuthNonce:  nonce,
		RandomSeed: randomSeedBig.String(),
		Deadline:   deadline,
		Signature:  sigHex,
		CreatedAt:  time.Now(),
	}
	if err := s.db.CreateClaimAuth(ctx, auth); err != nil {
		return nil, servererrs.ErrDatabase.WrapMsg("save claim auth failed: " + err.Error())
	}
	return &pbredpacket.IssueClaimSignResp{
		AuthNonce:  nonce,
		Deadline:   deadline,
		Signature:  sigHex,
		RandomSeed: randomSeedBig.String(),
	}, nil
}
// ClaimResult records the outcome of a claim transaction reported by the
// caller.
//
// A PENDING claim row is saved first. The transaction receipt is then parsed:
//   - if the receipt cannot be parsed, the claim stays PENDING and an empty
//     response is returned;
//   - if the transaction failed, has no usable PacketClaimed event, or the
//     event cannot be resolved, the claim is marked FAILED and an empty
//     response is returned;
//   - if the event's claimer differs from req.Claimer, the claim is marked
//     FAILED and an argument error is returned;
//   - otherwise the claim is saved as CONFIRMED, the claim authorisation is
//     marked used (best-effort), and the packet's claim progress is updated.
//
// NOTE(review): unlike canClaim, this path does not call ensureWalletBinding
// for req.Claimer — confirm whether that is intentional.
func (s *redPacketServer) ClaimResult(ctx context.Context, req *pbredpacket.ClaimResultReq) (*pbredpacket.ClaimResultResp, error) {
	opUserID := mcontext.GetOpUserID(ctx)
	if opUserID == "" {
		return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}
	for _, required := range []string{req.PacketID, req.Claimer, req.TxHash} {
		if strings.TrimSpace(required) == "" {
			return nil, errs.ErrArgs.WrapMsg("packet_id, claimer and tx_hash are required")
		}
	}

	packet, err := s.db.GetRedPacketByPacketID(ctx, req.PacketID)
	if err != nil {
		return nil, err
	}
	if err = validateClaimBase(packet, opUserID, req.Claimer); err != nil {
		return nil, err
	}

	pending := &model.RedPacketClaim{
		PacketID:      req.PacketID,
		UserID:        opUserID,
		ClaimerWallet: req.Claimer,
		ClaimTxHash:   req.TxHash,
		Status:        "PENDING",
		CreatedAt:     time.Now(),
		UpdatedAt:     time.Now(),
	}
	if err = s.db.SaveClaim(ctx, pending); err != nil {
		return nil, err
	}

	// recordFailure flips the claim to FAILED; a failure to do so is only logged.
	recordFailure := func() {
		if markErr := s.markClaimFailed(ctx, req.PacketID, opUserID, req.Claimer, req.TxHash); markErr != nil {
			log.ZWarn(ctx, "mark claim failed status failed", markErr, "txHash", req.TxHash)
		}
	}

	succeeded, parsedEvents, err := s.parseChainReceiptWithStatus(ctx, packet, req.TxHash)
	if err != nil {
		// Receipt unavailable: leave the claim PENDING and report success.
		log.ZWarn(ctx, "parse claim receipt failed", err, "txHash", req.TxHash)
		return &pbredpacket.ClaimResultResp{}, nil
	}
	if !succeeded {
		recordFailure()
		return &pbredpacket.ClaimResultResp{}, nil
	}

	claimed, err := resolveClaimedEventFromParsedEvents(packet, parsedEvents)
	if err != nil {
		log.ZWarn(ctx, "resolve claim event failed", err, "txHash", req.TxHash)
	}
	if err != nil || claimed == nil {
		recordFailure()
		return &pbredpacket.ClaimResultResp{}, nil
	}
	if !strings.EqualFold(claimed.ClaimerWallet, req.Claimer) {
		recordFailure()
		return nil, errs.ErrArgs.WrapMsg(fmt.Sprintf("claim event claimer mismatch: got %s want %s", claimed.ClaimerWallet, req.Claimer))
	}

	confirmed := &model.RedPacketClaim{
		PacketID:      req.PacketID,
		UserID:        opUserID,
		ClaimerWallet: claimed.ClaimerWallet,
		AuthNonce:     claimed.AuthNonce,
		ClaimTxHash:   req.TxHash,
		ClaimedAmount: claimed.Amount,
		BlockNumber:   claimed.BlockNumber,
		Status:        "CONFIRMED",
		CreatedAt:     time.Now(),
		UpdatedAt:     time.Now(),
	}
	if err = s.db.SaveClaim(ctx, confirmed); err != nil {
		return nil, err
	}

	if nonce := claimed.AuthNonce; nonce != "" {
		if usedErr := s.db.MarkClaimAuthUsed(ctx, nonce); usedErr != nil {
			log.ZWarn(ctx, "mark claim auth used failed", usedErr, "authNonce", nonce)
		}
	}

	// An empty status lets the DB layer derive COMPLETED/ACTIVE on its own.
	// req.TxHash is passed as the idempotency key so that concurrent indexer
	// processing of the same transaction cannot double-count the claim.
	if err = s.db.UpdateRedPacketClaimProgress(ctx, req.PacketID, claimed.Amount, "", req.TxHash); err != nil {
		return nil, err
	}
	return &pbredpacket.ClaimResultResp{}, nil
}
// parseChainReceiptWithStatus fetches and parses the receipt of txHash using
// the client that matches rp.ChainType. It returns the transaction success
// flag and the parsed events as reported by that client.
//
// An internal error is returned when the matching client is not configured,
// and an argument error when the chain type is neither "EVM" nor "TRON".
func (s *redPacketServer) parseChainReceiptWithStatus(ctx context.Context, rp *model.RedPacket, txHash string) (bool, []*chain.ParsedEvent, error) {
	if rp.ChainType == "EVM" {
		if s.chainClient == nil {
			return false, nil, errs.ErrInternalServer.WrapMsg("evm client is unavailable")
		}
		return s.chainClient.ParseTransactionReceiptWithStatus(ctx, common.HexToHash(txHash))
	}
	if rp.ChainType == "TRON" {
		if s.tronClient == nil {
			return false, nil, errs.ErrInternalServer.WrapMsg("tron client is unavailable")
		}
		return s.tronClient.ParseTransactionReceiptWithStatus(ctx, txHash)
	}
	return false, nil, errs.ErrArgs.WrapMsg("unsupported chain_type: " + rp.ChainType)
}
// markClaimFailed saves a claim row with status FAILED for the given packet,
// user, claimer wallet and transaction hash. Only UpdatedAt is stamped;
// CreatedAt is left at its zero value.
func (s *redPacketServer) markClaimFailed(ctx context.Context, packetID, userID, claimer, txHash string) error {
	failed := &model.RedPacketClaim{
		PacketID:      packetID,
		UserID:        userID,
		ClaimerWallet: claimer,
		ClaimTxHash:   txHash,
		Status:        "FAILED",
		UpdatedAt:     time.Now(),
	}
	return s.db.SaveClaim(ctx, failed)
}
// canClaim decides whether userID may claim packetID with the claimer wallet
// (formerly RedPacketService.CanClaim).
//
// It loads the packet, applies the shared checks in validateClaimBase,
// requires the wallet to be bound to the user for the packet's chain type, and
// finally delegates to the validator for the packet type
// (0 = fixed, 1 = random, 2 = transfer). A nil return means the claim is allowed.
func (s *redPacketServer) canClaim(ctx context.Context, packetID, claimer, userID string) error {
	packet, err := s.db.GetRedPacketByPacketID(ctx, packetID)
	if err != nil {
		return err
	}
	if err = validateClaimBase(packet, userID, claimer); err != nil {
		return err
	}
	if err = s.ensureWalletBinding(ctx, userID, claimer, packet.ChainType); err != nil {
		return err
	}

	// Pick the per-type validator, then run it once.
	var validate func(context.Context, *model.RedPacket, string, string) error
	switch packet.PacketType {
	case 0:
		validate = s.validateFixedPacketClaim
	case 1:
		validate = s.validateRandomPacketClaim
	case 2:
		validate = s.validateTransferPacketClaim
	default:
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("unsupported packet_type: %d", packet.PacketType))
	}
	return validate(ctx, packet, userID, claimer)
}
// claimedEventSnapshot holds the values extracted from a PacketClaimed chain
// event by resolveClaimedEventFromParsedEvents.
type claimedEventSnapshot struct {
	ClaimerWallet string // lower-cased hex address from the event's "claimer" field
	AuthNonce     string // string form of the event's "authNonce" value
	Amount        string // string form of the amount reported by the event
	BlockNumber   uint64 // block number carried by the parsed event
}
// refundedEventSnapshot holds the values extracted from a PacketRefunded chain
// event by resolveRefundedEventFromParsedEvents.
type refundedEventSnapshot struct {
	RefundTo    string // lower-cased hex address from the event's "refundTo" field
	Amount      string // string form of the amount reported by the event
	BlockNumber uint64 // block number carried by the parsed event
}
// createdPacketSnapshot describes a created red packet as resolved by
// resolveCreatedPacket. It is filled either from a PacketCreated chain event
// (buildCreatedPacketSnapshot) or from the stored order when no chain client
// is configured (buildFallbackCreatedPacket).
type createdPacketSnapshot struct {
	PacketID        string // packet identifier (from the event, or supplied by the caller in fallback mode)
	ChainID         int64  // chain ID; starts from the stored order and may be overridden by the EVM client
	ContractAddress string // contract address the packet belongs to
	CreatorWallet   string // lower-cased creator wallet address
	PacketType      int32  // packet type code; the validators use 0 = fixed, 1 = random, 2 = transfer
	Token           string // lower-cased token address
	TotalAmount     string // total amount as a decimal string
	TotalShares     int32  // number of shares
	ExpiryAt        int64  // expiry as a unix timestamp; validateCreatedPacket only compares it when the stored value is > 0
}
// resolveCreatedPacket determines the on-chain details of a created packet.
//
// When the client for rp.ChainType is configured, the receipt of txHashHex is
// parsed, the transaction must have succeeded, and the first PacketCreated
// event is turned into a snapshot and validated against the stored order.
// When the client is not configured (offline mode), the caller must supply
// fallbackPacketID and the snapshot is built from the stored order instead.
//
// An argument error is returned for unsupported chain types, a failed
// transaction, or a missing fallback packet ID; an internal error is returned
// when the receipt cannot be parsed or contains no PacketCreated event.
func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.RedPacket, txHashHex, fallbackPacketID string) (*createdPacketSnapshot, error) {
	// firstCreated returns the first PacketCreated event in evts, or nil.
	firstCreated := func(evts []*chain.ParsedEvent) *chain.ParsedEvent {
		for _, evt := range evts {
			if evt.Name == "PacketCreated" {
				return evt
			}
		}
		return nil
	}

	switch rp.ChainType {
	case "EVM":
		if s.chainClient == nil {
			// Offline mode: the packet ID has to come from the caller.
			if fallbackPacketID == "" {
				return nil, errs.ErrArgs.WrapMsg("packet_id is required when EVM client is unavailable")
			}
			return buildFallbackCreatedPacket(rp, fallbackPacketID), nil
		}
		ok, evts, err := s.chainClient.ParseTransactionReceiptWithStatus(ctx, common.HexToHash(txHashHex))
		if err != nil {
			return nil, errs.ErrInternalServer.WrapMsg("parse created tx failed: " + err.Error())
		}
		if !ok {
			return nil, errs.ErrArgs.WrapMsg("created tx execution failed on chain")
		}
		evt := firstCreated(evts)
		if evt == nil {
			return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in tx: " + txHashHex)
		}
		snapshot := buildCreatedPacketSnapshot(rp, evt)
		// The live client's chain ID and contract address take precedence
		// over the values stored with the order.
		if id := s.chainClient.ChainID(); id != nil {
			snapshot.ChainID = id.Int64()
		}
		snapshot.ContractAddress = s.chainClient.ContractAddress().Hex()
		if err := validateCreatedPacket(rp, snapshot); err != nil {
			return nil, err
		}
		return snapshot, nil

	case "TRON":
		if s.tronClient == nil {
			// Offline mode: the packet ID has to come from the caller.
			if fallbackPacketID == "" {
				return nil, errs.ErrArgs.WrapMsg("packet_id is required when TRON client is unavailable")
			}
			return buildFallbackCreatedPacket(rp, fallbackPacketID), nil
		}
		ok, evts, err := s.tronClient.ParseTransactionReceiptWithStatus(ctx, txHashHex)
		if err != nil {
			return nil, errs.ErrInternalServer.WrapMsg("parse tron created tx failed: " + err.Error())
		}
		if !ok {
			return nil, errs.ErrArgs.WrapMsg("created tx execution failed on chain")
		}
		evt := firstCreated(evts)
		if evt == nil {
			return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in TRON tx: " + txHashHex)
		}
		snapshot := buildCreatedPacketSnapshot(rp, evt)
		// Prefer the client's contract address, falling back to the stored one.
		snapshot.ContractAddress = firstNonEmpty(s.tronClient.ContractAddress(), rp.ContractAddress)
		if err := validateCreatedPacket(rp, snapshot); err != nil {
			return nil, err
		}
		return snapshot, nil
	}

	return nil, errs.ErrArgs.WrapMsg("unsupported chain_type: " + rp.ChainType)
}
// validateCreateHook is the single entry point for create-time validation. It
// dispatches on req.PacketType (0 = fixed, 1 = random, 2 = transfer) and
// rejects any other value with an argument error.
func (s *redPacketServer) validateCreateHook(ctx context.Context, req *pbredpacket.CreateOrderReq) error {
	if req.PacketType == 0 {
		return s.validateFixedPacketCreate(ctx, req)
	}
	if req.PacketType == 1 {
		return s.validateRandomPacketCreate(ctx, req)
	}
	if req.PacketType == 2 {
		return s.validateTransferPacketCreate(ctx, req)
	}
	return errs.ErrArgs.WrapMsg(fmt.Sprintf("unsupported packet_type: %d", req.PacketType))
}
// validateCreateBaseFields checks the request fields common to every packet
// type and returns total_amount parsed as a big integer.
//
// Requirements: creator_wallet and total_amount are non-blank, total_amount is
// a positive base-10 integer, and expiry_at is either 0 or strictly in the
// future. Creator identity and scope are not examined here; the per-type
// hooks take care of those.
func validateCreateBaseFields(req *pbredpacket.CreateOrderReq) (*big.Int, error) {
	if strings.TrimSpace(req.CreatorWallet) == "" {
		return nil, errs.ErrArgs.WrapMsg("creator_wallet is required")
	}
	if strings.TrimSpace(req.TotalAmount) == "" {
		return nil, errs.ErrArgs.WrapMsg("total_amount is required")
	}

	amount := new(big.Int)
	if _, parsed := amount.SetString(req.TotalAmount, 10); !parsed || amount.Sign() <= 0 {
		return nil, errs.ErrArgs.WrapMsg("total_amount must be a positive integer string", "totalAmount", req.TotalAmount)
	}

	// 0 means "no expiry supplied"; anything else must lie in the future.
	if expiry := req.ExpiryAt; expiry != 0 && expiry <= time.Now().Unix() {
		return nil, errs.ErrArgs.WrapMsg("expiry_at must be 0 or a future unix timestamp", "expiryAt", req.ExpiryAt)
	}
	return amount, nil
}
// validateCreatorScope checks the creator's relationship to the packet's
// audience, depending on the normalized scope type:
//   - GROUP: the creator must pass ensureGroupEligibility for req.GroupID;
//   - DIRECT: the creator must pass ensureFriendRelationship with
//     receiver_user_id and with every entry of receiver_user_ids (blank
//     entries are skipped);
//   - anything else (PUBLIC): no relationship check.
func (s *redPacketServer) validateCreatorScope(ctx context.Context, req *pbredpacket.CreateOrderReq) error {
	creatorID := mcontext.GetOpUserID(ctx)
	if creatorID == "" {
		return servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}

	scope := normalizeScopeType(req.ScopeType)
	if scope == "GROUP" {
		return s.ensureGroupEligibility(ctx, req.GroupID, creatorID)
	}
	if scope != "DIRECT" {
		return nil
	}

	// Check the single receiver first, then the list, stopping at the first failure.
	candidates := append([]string{req.ReceiverUserID}, req.ReceiverUserIDs...)
	for _, candidate := range candidates {
		if strings.TrimSpace(candidate) == "" {
			continue
		}
		if err := s.ensureFriendRelationship(ctx, creatorID, candidate); err != nil {
			return err
		}
	}
	return nil
}
// validateFixedPacketCreate checks a create request for a fixed red packet:
//   - the shared base fields must be valid;
//   - scope_type must be GROUP (fixed packets are group-only; the claim
//     validators require group_id);
//   - 0 < total_shares <= maxTotalShares;
//   - total_amount must divide evenly by total_shares, so every share is a
//     whole number of minimum units;
//   - the creator must pass the scope check for the group.
func (s *redPacketServer) validateFixedPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error {
	amount, err := validateCreateBaseFields(req)
	if err != nil {
		return err
	}

	switch {
	case normalizeScopeType(req.ScopeType) != "GROUP":
		return errs.ErrArgs.WrapMsg("fixed packet must use scope_type=GROUP")
	case req.TotalShares <= 0:
		return errs.ErrArgs.WrapMsg("total_shares must be positive for fixed packet", "totalShares", req.TotalShares)
	case req.TotalShares > maxTotalShares:
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("total_shares must not exceed %d for fixed packet", maxTotalShares), "totalShares", req.TotalShares)
	}

	shareCount := big.NewInt(int64(req.TotalShares))
	remainder := new(big.Int).Mod(amount, shareCount)
	if remainder.Sign() != 0 {
		return errs.ErrArgs.WrapMsg("total_amount must be divisible by total_shares for fixed packet",
			"totalAmount", req.TotalAmount, "totalShares", req.TotalShares)
	}
	return s.validateCreatorScope(ctx, req)
}
// validateRandomPacketCreate checks a create request for a random (lucky) red
// packet:
//   - the shared base fields must be valid;
//   - scope_type must be GROUP (random packets are group-only; the claim
//     validators require group_id);
//   - 0 < total_shares <= maxTotalShares;
//   - total_amount >= total_shares, i.e. at least one minimum unit per share;
//   - the creator must pass the scope check for the group.
func (s *redPacketServer) validateRandomPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error {
	amount, err := validateCreateBaseFields(req)
	if err != nil {
		return err
	}

	switch {
	case normalizeScopeType(req.ScopeType) != "GROUP":
		return errs.ErrArgs.WrapMsg("random packet must use scope_type=GROUP")
	case req.TotalShares <= 0:
		return errs.ErrArgs.WrapMsg("total_shares must be positive for random packet", "totalShares", req.TotalShares)
	case req.TotalShares > maxTotalShares:
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("total_shares must not exceed %d for random packet", maxTotalShares), "totalShares", req.TotalShares)
	}

	shareCount := big.NewInt(int64(req.TotalShares))
	if amount.Cmp(shareCount) < 0 {
		return errs.ErrArgs.WrapMsg("total_amount must be >= total_shares for random packet",
			"totalAmount", req.TotalAmount, "totalShares", req.TotalShares)
	}
	return s.validateCreatorScope(ctx, req)
}
// validateTransferPacketCreate checks a create request for a transfer packet:
//   - the shared base fields must be valid;
//   - scope_type must be DIRECT (a transfer is a 1-to-1 direct send);
//   - total_shares must equal 1;
//   - receiver_user_ids must be empty and receiver_user_id must be set;
//   - the receiver must differ from the creator (no self-transfer);
//   - creator and receiver must pass ensureFriendRelationship.
func (s *redPacketServer) validateTransferPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error {
	if _, err := validateCreateBaseFields(req); err != nil {
		return err
	}

	switch {
	case normalizeScopeType(req.ScopeType) != "DIRECT":
		return errs.ErrArgs.WrapMsg("transfer packet must use scope_type=DIRECT")
	case req.TotalShares != 1:
		return errs.ErrArgs.WrapMsg("transfer packet must have total_shares == 1", "totalShares", req.TotalShares)
	case len(req.ReceiverUserIDs) > 0:
		// The plural field does not apply to transfers; refuse ambiguous input.
		return errs.ErrArgs.WrapMsg("transfer packet uses receiver_user_id (singular), not receiver_user_ids")
	}

	receiver := strings.TrimSpace(req.ReceiverUserID)
	if receiver == "" {
		return errs.ErrArgs.WrapMsg("receiver_user_id is required for transfer packet")
	}

	creator := mcontext.GetOpUserID(ctx)
	switch creator {
	case "":
		return servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	case receiver:
		return errs.ErrArgs.WrapMsg("transfer packet cannot be sent to yourself")
	}
	return s.ensureFriendRelationship(ctx, creator, receiver)
}
// buildFallbackCreatedPacket builds a created-packet snapshot purely from the
// stored order plus a caller-supplied packet ID. The creator wallet is
// lower-cased and the token is passed through normalizeTokenAddress; all other
// fields are copied as stored.
func buildFallbackCreatedPacket(rp *model.RedPacket, packetID string) *createdPacketSnapshot {
	snapshot := new(createdPacketSnapshot)
	snapshot.PacketID = packetID
	snapshot.ChainID = rp.ChainID
	snapshot.ContractAddress = rp.ContractAddress
	snapshot.CreatorWallet = strings.ToLower(rp.CreatorWallet)
	snapshot.PacketType = rp.PacketType
	snapshot.Token = normalizeTokenAddress(rp.Token)
	snapshot.TotalAmount = rp.TotalAmount
	snapshot.TotalShares = rp.TotalShares
	snapshot.ExpiryAt = rp.ExpiryAt
	return snapshot
}
// buildCreatedPacketSnapshot converts a PacketCreated event into a snapshot.
// Packet ID, creator, packet type, token, total amount, total shares and
// expiry are read from the event; chain ID and contract address start from
// the stored order (callers may override them afterwards). Addresses are
// lower-cased hex strings.
func buildCreatedPacketSnapshot(rp *model.RedPacket, event *chain.ParsedEvent) *createdPacketSnapshot {
	packetID := chain.GetPacketIDFromEvent(event).String()
	creator := strings.ToLower(chain.GetAddressFromEvent(event, "creator").Hex())
	packetType := int32(chain.GetUintFromEvent(event, "packetType").Int64())
	token := strings.ToLower(chain.GetAddressFromEvent(event, "token").Hex())
	totalAmount := chain.GetUintFromEvent(event, "totalAmount").String()
	totalShares := int32(chain.GetUintFromEvent(event, "totalShares").Int64())
	expiryAt := chain.GetUintFromEvent(event, "expiryAt").Int64()

	return &createdPacketSnapshot{
		PacketID:        packetID,
		ChainID:         rp.ChainID,
		ContractAddress: rp.ContractAddress,
		CreatorWallet:   creator,
		PacketType:      packetType,
		Token:           token,
		TotalAmount:     totalAmount,
		TotalShares:     totalShares,
		ExpiryAt:        expiryAt,
	}
}
// validateCreatedPacket verifies that the packet resolved from the chain (or
// from the offline fallback) matches the order stored at creation time.
//
// Checked fields: creator wallet (case-insensitive, skipped when the snapshot
// has none), packet type, total amount, total shares, token (after
// normalizeTokenAddress) and expiry (only when the stored expiry is > 0).
// The first mismatch is returned as an argument error; a nil snapshot is an
// internal error.
//
// Total amounts are compared numerically. The create-time validation accepts
// any base-10 integer spelling (for example "0100" or "+100"), so a purely
// textual comparison would permanently reject an order whose stored spelling
// differs from the canonical on-chain form.
func validateCreatedPacket(rp *model.RedPacket, createdPacket *createdPacketSnapshot) error {
	if createdPacket == nil {
		return errs.ErrInternalServer.WrapMsg("created packet is nil")
	}
	if createdPacket.CreatorWallet != "" && strings.ToLower(rp.CreatorWallet) != createdPacket.CreatorWallet {
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("creator mismatch: got %s want %s", createdPacket.CreatorWallet, rp.CreatorWallet))
	}
	if createdPacket.PacketType != rp.PacketType {
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("packet type mismatch: got %d want %d", createdPacket.PacketType, rp.PacketType))
	}
	if createdPacket.TotalAmount != rp.TotalAmount {
		// The strings differ; they may still denote the same number. If
		// either side is not a valid integer, treat it as a mismatch.
		got, gotOK := new(big.Int).SetString(createdPacket.TotalAmount, 10)
		want, wantOK := new(big.Int).SetString(rp.TotalAmount, 10)
		if !gotOK || !wantOK || got.Cmp(want) != 0 {
			return errs.ErrArgs.WrapMsg(fmt.Sprintf("total amount mismatch: got %s want %s", createdPacket.TotalAmount, rp.TotalAmount))
		}
	}
	if createdPacket.TotalShares != rp.TotalShares {
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("total shares mismatch: got %d want %d", createdPacket.TotalShares, rp.TotalShares))
	}
	expectedToken := normalizeTokenAddress(rp.Token)
	if createdPacket.Token != expectedToken {
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("token mismatch: got %s want %s", createdPacket.Token, expectedToken))
	}
	// A stored expiry of 0 means the order did not pin one, so nothing to compare.
	if rp.ExpiryAt > 0 && createdPacket.ExpiryAt != rp.ExpiryAt {
		return errs.ErrArgs.WrapMsg(fmt.Sprintf("expiry mismatch: got %d want %d", createdPacket.ExpiryAt, rp.ExpiryAt))
	}
	return nil
}
// validateClaimBase runs the claim checks shared by every packet type: the
// packet must exist, userID and claimer must be non-blank, the packet status
// must be ACTIVE, and the packet must not be past its expiry.
func validateClaimBase(rp *model.RedPacket, userID, claimer string) error {
	if rp == nil {
		return servererrs.ErrRecordNotFound.WrapMsg("packet not found")
	}
	if strings.TrimSpace(userID) == "" {
		return errs.ErrArgs.WrapMsg("user_id is required")
	}
	if strings.TrimSpace(claimer) == "" {
		return errs.ErrArgs.WrapMsg("claimer is required")
	}

	// Status comes before expiry so each terminal state gets its own message.
	if status := rp.Status; status != "ACTIVE" {
		if status == "REFUNDED" {
			return errs.ErrArgs.WrapMsg("packet has been refunded")
		}
		if status == "EXPIRED" {
			return errs.ErrArgs.WrapMsg("packet has expired")
		}
		return errs.ErrArgs.WrapMsg("packet is not claimable, current status: " + status)
	}

	// The status can still read ACTIVE after the expiry time has passed.
	pastExpiry := rp.ExpiryAt > 0 && rp.ExpiryAt <= time.Now().Unix()
	if pastExpiry {
		return errs.ErrArgs.WrapMsg("packet has expired")
	}
	return nil
}
// validateFixedPacketClaim checks a claim on a fixed packet: the packet must
// carry a group ID, neither the user nor the wallet may have a non-failed
// claim already, and the user must pass the group eligibility check.
func (s *redPacketServer) validateFixedPacketClaim(ctx context.Context, rp *model.RedPacket, userID, claimer string) error {
	if gid := strings.TrimSpace(rp.GroupID); gid == "" {
		return errs.ErrArgs.WrapMsg("group_id is required for fixed packet claim")
	}
	dupErr := s.ensureNotClaimed(ctx, rp.PacketID, userID, claimer)
	if dupErr != nil {
		return dupErr
	}
	return s.ensureGroupEligibility(ctx, rp.GroupID, userID)
}
// validateRandomPacketClaim checks a claim on a random packet: the packet must
// carry a group ID, neither the user nor the wallet may have a non-failed
// claim already, and the user must pass the group eligibility check.
func (s *redPacketServer) validateRandomPacketClaim(ctx context.Context, rp *model.RedPacket, userID, claimer string) error {
	if gid := strings.TrimSpace(rp.GroupID); gid == "" {
		return errs.ErrArgs.WrapMsg("group_id is required for random packet claim")
	}
	dupErr := s.ensureNotClaimed(ctx, rp.PacketID, userID, claimer)
	if dupErr != nil {
		return dupErr
	}
	return s.ensureGroupEligibility(ctx, rp.GroupID, userID)
}
// validateTransferPacketClaim checks a claim on a transfer packet: there must
// be no prior non-failed claim, the packet must name a receiver, that receiver
// must be the claiming user, and creator and receiver must still pass
// ensureFriendRelationship at claim time.
func (s *redPacketServer) validateTransferPacketClaim(ctx context.Context, rp *model.RedPacket, userID, claimer string) error {
	dupErr := s.ensureNotClaimed(ctx, rp.PacketID, userID, claimer)
	if dupErr != nil {
		return dupErr
	}
	switch {
	case strings.TrimSpace(rp.ReceiverUserID) == "":
		return errs.ErrArgs.WrapMsg("receiver_user_id is required for transfer claim")
	case rp.ReceiverUserID != userID:
		return errs.ErrNoPermission.WrapMsg("user is not the designated receiver")
	}
	return s.ensureFriendRelationship(ctx, rp.CreatorUserID, userID)
}
// ensureNotClaimed rejects a claim when a non-FAILED claim already exists for
// the packet, looked up first by user ID (skipped when userID is blank) and
// then by claimer wallet. A "record not found" lookup result counts as "not
// claimed"; any other lookup error is returned as is.
//
// NOTE(review): the claimer is passed to the lookup exactly as received —
// confirm the storage layer matches wallet addresses case-insensitively.
func (s *redPacketServer) ensureNotClaimed(ctx context.Context, packetID, userID, claimer string) error {
	if strings.TrimSpace(userID) != "" {
		byUser, err := s.db.GetClaimByPacketIDAndUserID(ctx, packetID, userID)
		switch {
		case err != nil:
			if !errs.ErrRecordNotFound.Is(err) {
				return err
			}
		case byUser != nil && byUser.Status != "FAILED":
			return errs.ErrArgs.WrapMsg("user already claimed")
		}
	}

	byWallet, err := s.db.GetClaimByPacketIDAndClaimer(ctx, packetID, claimer)
	switch {
	case err != nil:
		if !errs.ErrRecordNotFound.Is(err) {
			return err
		}
	case byWallet != nil && byWallet.Status != "FAILED":
		return errs.ErrArgs.WrapMsg("already claimed")
	}
	return nil
}
// ensureWalletBinding requires an active wallet binding for the given user,
// chain type and claimer wallet. A missing binding becomes a permission error;
// any other lookup error is returned unchanged.
func (s *redPacketServer) ensureWalletBinding(ctx context.Context, userID, claimer, chainType string) error {
	_, err := s.db.GetActiveWalletBinding(ctx, userID, chainType, claimer)
	if err == nil {
		return nil
	}
	if errs.ErrRecordNotFound.Is(err) {
		return errs.ErrNoPermission.WrapMsg("wallet is not bound to user")
	}
	return err
}
// ensureGroupEligibility checks that userID has a member record in groupID by
// querying the group client. Both IDs are trimmed and must be non-empty. A
// "record not found" answer becomes a permission error; any other client
// error is returned unchanged. The group client must be initialised.
func (s *redPacketServer) ensureGroupEligibility(ctx context.Context, groupID, userID string) error {
	gid, uid := strings.TrimSpace(groupID), strings.TrimSpace(userID)
	switch {
	case gid == "":
		return errs.ErrArgs.WrapMsg("group_id is required for group claim")
	case uid == "":
		return errs.ErrArgs.WrapMsg("user_id is required for group claim")
	case s.groupClient == nil:
		return servererrs.ErrInternalServer.WrapMsg("group client is not initialized")
	}

	_, err := s.groupClient.GetGroupMemberInfo(ctx, gid, uid)
	if err == nil {
		return nil
	}
	if errs.ErrRecordNotFound.Is(err) {
		return errs.ErrNoPermission.WrapMsg("user is not a member of the group", "groupID", gid, "userID", uid)
	}
	return err
}
// ensureFriendRelationship checks, via the relation client's IsFriend call,
// that userA and userB are friends. It has two callers:
//   - validateCreatorScope (DIRECT scope), which checks every listed receiver
//     against the creator. There userA == userB can occur (a creator listing
//     themselves) and is accepted here.
//   - validateTransferPacketClaim, which re-confirms the friendship when the
//     packet is claimed.
//
// Identical IDs pass without consulting the relation client. Call sites that
// must forbid self-sends (e.g. validateTransferPacketCreate) have to reject
// them before calling this function.
func (s *redPacketServer) ensureFriendRelationship(ctx context.Context, userA, userB string) error {
	first, second := strings.TrimSpace(userA), strings.TrimSpace(userB)
	switch {
	case first == "" || second == "":
		return errs.ErrArgs.WrapMsg("both user IDs are required for friend relationship check")
	case first == second:
		return nil
	case s.relationClient == nil:
		return servererrs.ErrInternalServer.WrapMsg("relation client is not initialized")
	}

	friends, err := s.relationClient.IsFriend(ctx, first, second)
	if err != nil {
		return err
	}
	if friends {
		return nil
	}
	return errs.ErrNoPermission.WrapMsg("users are not friends", "userA", first, "userB", second)
}
// resolveClaimedEvent parses the receipt of txHash with the client matching
// rp.ChainType and extracts the PacketClaimed event for the packet.
//
// It returns (nil, nil) when the matching client is not configured or when
// the receipt holds no PacketClaimed event, and an argument error for an
// unsupported chain type.
func (s *redPacketServer) resolveClaimedEvent(ctx context.Context, rp *model.RedPacket, txHash string) (*claimedEventSnapshot, error) {
	var parsed []*chain.ParsedEvent
	var parseErr error

	if rp.ChainType == "EVM" {
		if s.chainClient == nil {
			return nil, nil
		}
		parsed, parseErr = s.chainClient.ParseTransactionReceipt(ctx, common.HexToHash(txHash))
	} else if rp.ChainType == "TRON" {
		if s.tronClient == nil {
			return nil, nil
		}
		parsed, parseErr = s.tronClient.ParseTransactionReceipt(ctx, txHash)
	} else {
		return nil, errs.ErrArgs.WrapMsg("unsupported chain_type: " + rp.ChainType)
	}

	if parseErr != nil {
		return nil, parseErr
	}
	return resolveClaimedEventFromParsedEvents(rp, parsed)
}
// resolveClaimedEventFromParsedEvents looks at the first PacketClaimed event
// in events. If its packet ID equals rp.PacketID, a snapshot of the event is
// returned; otherwise an argument error is returned. When there is no
// PacketClaimed event at all, the result is (nil, nil).
func resolveClaimedEventFromParsedEvents(rp *model.RedPacket, events []*chain.ParsedEvent) (*claimedEventSnapshot, error) {
	for i := range events {
		evt := events[i]
		if evt.Name == "PacketClaimed" {
			eventPacketID := chain.GetPacketIDFromEvent(evt).String()
			wallet := strings.ToLower(chain.GetAddressFromEvent(evt, "claimer").Hex())
			if eventPacketID != rp.PacketID {
				return nil, errs.ErrArgs.WrapMsg(fmt.Sprintf("claim event packet mismatch: got %s want %s", eventPacketID, rp.PacketID))
			}
			snapshot := &claimedEventSnapshot{
				ClaimerWallet: wallet,
				AuthNonce:     chain.GetUintFromEvent(evt, "authNonce").String(),
				Amount:        chain.GetAmountFromEvent(evt).String(),
				BlockNumber:   evt.BlockNumber,
			}
			return snapshot, nil
		}
	}
	return nil, nil
}
// resolveRefundedEventFromParsedEvents looks at the first PacketRefunded event
// in events. If its packet ID equals rp.PacketID, a snapshot of the event is
// returned; otherwise an argument error is returned. When there is no
// PacketRefunded event at all, the result is (nil, nil).
func resolveRefundedEventFromParsedEvents(rp *model.RedPacket, events []*chain.ParsedEvent) (*refundedEventSnapshot, error) {
	for i := range events {
		evt := events[i]
		if evt.Name == "PacketRefunded" {
			eventPacketID := chain.GetPacketIDFromEvent(evt).String()
			if eventPacketID != rp.PacketID {
				return nil, errs.ErrArgs.WrapMsg(fmt.Sprintf("refund event packet mismatch: got %s want %s", eventPacketID, rp.PacketID))
			}
			snapshot := &refundedEventSnapshot{
				RefundTo:    strings.ToLower(chain.GetAddressFromEvent(evt, "refundTo").Hex()),
				Amount:      chain.GetAmountFromEvent(evt).String(),
				BlockNumber: evt.BlockNumber,
			}
			return snapshot, nil
		}
	}
	return nil, nil
}
// maxTotalShares is the upper bound (inclusive) on total_shares accepted when
// creating a fixed or random red packet; it caps the number of shares to
// prevent abuse.
const maxTotalShares = 10_000
// normalizeScopeType trims and upper-cases scopeType and returns it when it is
// one of "GROUP", "DIRECT" or "PUBLIC". Every other input, including the empty
// string, maps to "PUBLIC".
func normalizeScopeType(scopeType string) string {
	canonical := strings.ToUpper(strings.TrimSpace(scopeType))
	if canonical == "GROUP" || canonical == "DIRECT" || canonical == "PUBLIC" {
		return canonical
	}
	return "PUBLIC"
}
// normalizeChainType trims and upper-cases chainType and returns it when it is
// "EVM" or "TRON". Any other value yields an argument error that quotes the
// original, unmodified input.
func normalizeChainType(chainType string) (string, error) {
	canonical := strings.ToUpper(strings.TrimSpace(chainType))
	if canonical == "EVM" || canonical == "TRON" {
		return canonical, nil
	}
	return "", errs.ErrArgs.WrapMsg("unsupported chain_type: " + chainType)
}
// validateCreateScope checks that the fields required by scopeType are present.
//
//   - "GROUP" requires a non-blank groupID.
//   - "DIRECT" requires at least one non-blank receiver, taken either from
//     receiverUserID or from any entry of receiverUserIDs.
//   - Any other scopeType has no requirements.
//
// It returns an ErrArgs error describing the missing field, or nil when the
// combination is acceptable. scopeType is compared as-is, so callers are
// expected to pass an already normalized value.
func validateCreateScope(scopeType, groupID, receiverUserID string, receiverUserIDs []string) error {
	switch scopeType {
	case "GROUP":
		if strings.TrimSpace(groupID) == "" {
			return errs.ErrArgs.WrapMsg("group_id is required when scope_type=GROUP")
		}
	case "DIRECT":
		// A list made only of blank IDs (e.g. [""]) names no receiver, so the
		// entries are inspected instead of trusting the list length.
		hasReceiver := strings.TrimSpace(receiverUserID) != ""
		for i := 0; !hasReceiver && i < len(receiverUserIDs); i++ {
			hasReceiver = strings.TrimSpace(receiverUserIDs[i]) != ""
		}
		if !hasReceiver {
			return errs.ErrArgs.WrapMsg("receiver_user_id or receiver_user_ids is required when scope_type=DIRECT")
		}
	}
	return nil
}
// normalizeTokenAddress returns the lower-cased hex form of token.
//
// A blank token yields the lower-cased hex of the zero address. Surrounding
// whitespace is removed before parsing, so a padded address is parsed as the
// address itself rather than as text with a leading space ahead of the "0x"
// prefix.
// NOTE(review): the blank-token case presumably denotes the chain's native
// currency — confirm against the callers.
func normalizeTokenAddress(token string) string {
	trimmed := strings.TrimSpace(token)
	if trimmed == "" {
		return strings.ToLower(common.Address{}.Hex())
	}
	return strings.ToLower(common.HexToAddress(trimmed).Hex())
}
// firstNonEmpty returns the first argument that still contains text after
// whitespace is trimmed. The chosen argument is returned unmodified (not
// trimmed). It returns "" when there are no arguments or all are blank.
func firstNonEmpty(values ...string) string {
	for i := range values {
		if strings.TrimSpace(values[i]) == "" {
			continue
		}
		return values[i]
	}
	return ""
}
// redPacketModelToProto copies a stored red packet into its protobuf record.
// A nil rp yields a nil record. ReceiverUserIDs is copied into a fresh slice so
// the record does not share backing storage with rp, and the CreatedAt and
// UpdatedAt timestamps are exported as Unix seconds.
func redPacketModelToProto(rp *model.RedPacket) *pbredpacket.RedPacketRecord {
	if rp == nil {
		return nil
	}

	record := new(pbredpacket.RedPacketRecord)

	// Identity and chain location.
	record.BizID = rp.BizID
	record.ChainType = rp.ChainType
	record.PacketID = rp.PacketID
	record.ChainID = rp.ChainID
	record.ContractAddress = rp.ContractAddress

	// Creator and audience.
	record.CreatorUserID = rp.CreatorUserID
	record.CreatorWallet = rp.CreatorWallet
	record.GroupID = rp.GroupID
	record.ScopeType = rp.ScopeType
	record.ReceiverUserID = rp.ReceiverUserID
	record.ReceiverUserIDs = append([]string(nil), rp.ReceiverUserIDs...)

	// Amounts and progress.
	record.PacketType = rp.PacketType
	record.Token = rp.Token
	record.TotalAmount = rp.TotalAmount
	record.TotalShares = rp.TotalShares
	record.ClaimedAmount = rp.ClaimedAmount
	record.ClaimedShares = rp.ClaimedShares

	// Lifecycle.
	record.ExpiryAt = rp.ExpiryAt
	record.TxHash = rp.TxHash
	record.Status = rp.Status
	record.CreatedAt = rp.CreatedAt.Unix()
	record.UpdatedAt = rp.UpdatedAt.Unix()

	return record
}
// RequestRefund lets the creator of an expired red packet submit the on-chain
// refund transaction for it.
//
// The request is rejected when the operator user ID is empty, when packet_id is
// missing, when the caller is not the packet's creator, or when the packet has
// an ExpiryAt that is still in the future. A packet already marked "REFUNDED"
// returns that status with an empty TxHash and submits nothing.
//
// The transaction is sent through the client that matches the packet's
// ChainType ("EVM" uses s.chainClient, "TRON" uses s.tronClient). Records
// without a chain type keep the historical preference: the EVM client when it
// is configured, otherwise the TRON client. If no client matches, an
// ErrInternalServer error is returned and nothing is submitted.
//
// After submission the receipt is inspected:
//   - the receipt cannot be parsed, the "PacketRefunded" event cannot be
//     resolved, or the event is absent: Status "PENDING" is returned with the
//     TxHash and the database is left untouched (the async indexer is expected
//     to reconcile it later);
//   - the transaction did not succeed: Status "FAILED" with the TxHash;
//   - the event is present: the refund row is saved, the packet is marked
//     "REFUNDED" and Status "REFUNDED" is returned. A database error at this
//     stage is returned to the caller even though the transaction was sent.
func (s *redPacketServer) RequestRefund(ctx context.Context, req *pbredpacket.RequestRefundReq) (*pbredpacket.RequestRefundResp, error) {
	currentUserID := mcontext.GetOpUserID(ctx)
	if currentUserID == "" {
		return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty")
	}
	if req.GetPacketID() == "" {
		return nil, errs.ErrArgs.WrapMsg("packet_id is required")
	}
	rp, err := s.db.GetRedPacketByPacketID(ctx, req.GetPacketID())
	if err != nil {
		return nil, err
	}
	if rp.CreatorUserID != currentUserID {
		return nil, errs.ErrNoPermission.WrapMsg("only the creator can request a refund")
	}
	if rp.Status == "REFUNDED" {
		return &pbredpacket.RequestRefundResp{TxHash: "", Status: "REFUNDED"}, nil
	}
	if rp.ExpiryAt > 0 && time.Now().Unix() < rp.ExpiryAt {
		return nil, errs.ErrArgs.WrapMsg("red packet has not expired yet")
	}

	// Pick the chain from the packet itself. Both clients can be configured at
	// the same time, so choosing by "whichever client is non-nil" would send a
	// TRON packet's refund to the EVM contract.
	chainType := strings.ToUpper(strings.TrimSpace(rp.ChainType))
	if chainType == "" {
		// Legacy record without a chain type: keep the previous preference.
		switch {
		case s.chainClient != nil:
			chainType = "EVM"
		case s.tronClient != nil:
			chainType = "TRON"
		}
	}

	// Submit the on-chain refund transaction.
	var txHash string
	switch {
	case chainType == "EVM" && s.chainClient != nil:
		txHash, err = s.chainClient.RefundPacket(ctx, rp.PacketID)
		if err != nil {
			return nil, errs.ErrInternalServer.WrapMsg("submit refund tx failed: " + err.Error())
		}
	case chainType == "TRON" && s.tronClient != nil:
		// The TRON call takes the packet ID as a big integer; the stored ID is
		// its base-10 string form.
		packetIDBig, ok := new(big.Int).SetString(rp.PacketID, 10)
		if !ok {
			return nil, errs.ErrInternalServer.WrapMsg("invalid packet id format")
		}
		txHash, err = s.tronClient.SendAdminTransaction(ctx, "refundPacket", packetIDBig)
		if err != nil {
			return nil, errs.ErrInternalServer.WrapMsg("submit tron refund tx failed: " + err.Error())
		}
	default:
		return nil, errs.ErrInternalServer.WrapMsg(fmt.Sprintf("no blockchain client configured for chain type %q", chainType))
	}
	log.ZInfo(ctx, "redpacket refund submitted", "packetID", rp.PacketID, "txHash", txHash)

	txSuccess, events, parseErr := s.parseChainReceiptWithStatus(ctx, rp, txHash)
	if parseErr != nil {
		// The transaction is already submitted; report it as pending instead of
		// failing the call.
		log.ZWarn(ctx, "parse refund receipt failed, fallback to async indexer", parseErr, "packetID", rp.PacketID, "txHash", txHash)
		return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "PENDING"}, nil
	}
	if !txSuccess {
		return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "FAILED"}, nil
	}
	refundedEvent, err := resolveRefundedEventFromParsedEvents(rp, events)
	if err != nil {
		log.ZWarn(ctx, "resolve refunded event failed, fallback to async indexer", err, "packetID", rp.PacketID, "txHash", txHash)
		return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "PENDING"}, nil
	}
	if refundedEvent == nil {
		return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "PENDING"}, nil
	}

	// The refund is confirmed by the receipt: persist it and close the packet.
	if err := s.db.SaveRefund(ctx, &model.RedPacketRefund{
		PacketID:  rp.PacketID,
		RefundTo:  refundedEvent.RefundTo,
		TxHash:    txHash,
		Amount:    refundedEvent.Amount,
		CreatedAt: time.Now(),
	}); err != nil {
		return nil, err
	}
	if err := s.db.UpdateRedPacketStatus(ctx, rp.PacketID, "REFUNDED"); err != nil {
		return nil, err
	}
	return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "REFUNDED"}, nil
}
// GetRefund returns the stored refund record for the requested packet.
// It fails with ErrArgs when packet_id is empty and passes through any error
// returned by the refund lookup. CreatedAt is reported in Unix seconds.
func (s *redPacketServer) GetRefund(ctx context.Context, req *pbredpacket.GetRefundReq) (*pbredpacket.GetRefundResp, error) {
	packetID := req.GetPacketID()
	if packetID == "" {
		return nil, errs.ErrArgs.WrapMsg("packet_id is required")
	}

	record, err := s.db.GetRefundByPacketID(ctx, packetID)
	if err != nil {
		return nil, err
	}

	resp := &pbredpacket.GetRefundResp{
		PacketID:  record.PacketID,
		RefundTo:  record.RefundTo,
		TxHash:    record.TxHash,
		Amount:    record.Amount,
		CreatedAt: record.CreatedAt.Unix(),
	}
	return resp, nil
}
// claimsModelToProto converts stored claim rows into protobuf claim records,
// preserving their order and dropping nil entries. The result is always a
// non-nil slice, even for nil or empty input. CreatedAt and UpdatedAt are
// exported as Unix seconds.
func claimsModelToProto(claims []*model.RedPacketClaim) []*pbredpacket.RedPacketClaimRecord {
	records := make([]*pbredpacket.RedPacketClaimRecord, 0, len(claims))
	for i := range claims {
		claim := claims[i]
		if claim != nil {
			rec := &pbredpacket.RedPacketClaimRecord{
				PacketID:      claim.PacketID,
				UserID:        claim.UserID,
				ClaimerWallet: claim.ClaimerWallet,
				AuthNonce:     claim.AuthNonce,
				ClaimTxHash:   claim.ClaimTxHash,
				ClaimedAmount: claim.ClaimedAmount,
				BlockNumber:   claim.BlockNumber,
				Status:        claim.Status,
				CreatedAt:     claim.CreatedAt.Unix(),
				UpdatedAt:     claim.UpdatedAt.Unix(),
			}
			records = append(records, rec)
		}
	}
	return records
}