fix: resolve cherry-pick conflicts and restore compilation

- Restore gateway_cache.go to upstream (no lua embeds)
- Restore payment_order.go to upstream (use out_trade_no lookup)
- Restore payment_fulfillment.go to upstream (same reason)
- Add FeaturesConfig field and IsWebSearchEmulationEnabled to Channel
- Add applyAccountStatsCost wrapper function
- Add SettingKeyWebSearchEmulationConfig constant
- Add WebSearchEmulationEnabled to SystemSettings
- Add notify code rate limiting methods to EmailCache interface
- Remove AllowUserRefund references (ent schema not present)
- Fix duplicate import in payment_handler.go
- Fix wire_gen.go argument mismatches
This commit is contained in:
erio
2026-04-14 10:18:39 +08:00
parent 9028d2085f
commit d6965b0676
12 changed files with 80 additions and 520 deletions

View File

@@ -7,7 +7,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"

View File

@@ -2,42 +2,14 @@ package repository
import (
"context"
_ "embed"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
stickySessionPrefix = "sticky_session:"
clientAffinityPrefix = "client_affinity:"
clientAffinityReversePrefix = "client_affinity_rev:"
)
var (
//go:embed lua/get_affinity.lua
getAffinityLua string
//go:embed lua/update_affinity.lua
updateAffinityLua string
//go:embed lua/get_affinity_count.lua
getAffinityCountLua string
//go:embed lua/get_affinity_clients.lua
getAffinityClientsLua string
//go:embed lua/get_affinity_clients_with_scores.lua
getAffinityClientsWithScoresLua string
//go:embed lua/clear_account_affinity.lua
clearAccountAffinityLua string
getAffinityScript = redis.NewScript(getAffinityLua)
updateAffinityScript = redis.NewScript(updateAffinityLua)
getAffinityCountScript = redis.NewScript(getAffinityCountLua)
getAffinityClientsScript = redis.NewScript(getAffinityClientsLua)
getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua)
clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua)
)
const stickySessionPrefix = "sticky_session:"
type gatewayCache struct {
rdb *redis.Client
@@ -47,16 +19,6 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb}
}
// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。
// Pipeline 中的 Script.Run 只发送 EVALSHA如果 Redis 重启过导致脚本缓存丢失,
// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。
func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) {
exists, err := script.Exists(ctx, rdb).Result()
if err != nil || len(exists) == 0 || !exists[0] {
_ = script.Load(ctx, rdb).Err()
}
}
// buildSessionKey 构建 session key包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string {
@@ -79,218 +41,13 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
}
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
// buildAffinityKey 构建正向亲和 keyclient → accounts
// 格式: client_affinity:{groupID}:{clientID}
func buildAffinityKey(groupID int64, clientID string) string {
return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID)
}
// buildAffinityReverseKey 构建反向亲和 keyaccount → clients
// 格式: client_affinity_rev:{groupID}:{accountID}
func buildAffinityReverseKey(groupID int64, accountID int64) string {
return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID)
}
func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) {
key := buildAffinityKey(groupID, clientID)
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
accountIDs := make([]int64, 0, len(result))
for _, s := range result {
id, err := strconv.ParseInt(s, 10, 64)
if err != nil {
continue
}
accountIDs = append(accountIDs, id)
}
return accountIDs, nil
}
func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error {
fwdKey := buildAffinityKey(groupID, clientID)
revKey := buildAffinityReverseKey(groupID, accountID)
now := time.Now().Unix()
ttlSeconds := int64(ttl.Seconds())
expireThreshold := now - ttlSeconds
return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey},
now, ttlSeconds, accountID, expireThreshold, clientID,
).Err()
}
// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员)
func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) {
if len(accountIDs) == 0 {
return map[int64]int64{}, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(accountIDs))
for i, accID := range accountIDs {
key := buildAffinityReverseKey(groupID, accID)
cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
result := make(map[int64]int64, len(accountIDs))
for i, accID := range accountIDs {
count, _ := cmds[i].Int64()
result[accID] = count
}
return result, nil
}
// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。
// accountGroups: map[accountID][]groupID对每个 (groupID, accountID) 组合查询反向索引。
func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) {
if len(accountGroups) == 0 {
return map[int64][]string{}, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
// 构建所有 (accountID, groupID) 组合的查询
type queryItem struct {
accountID int64
groupID int64
}
var queries []queryItem
for accID, groupIDs := range accountGroups {
for _, gID := range groupIDs {
queries = append(queries, queryItem{accountID: accID, groupID: gID})
}
}
ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(queries))
for i, q := range queries {
key := buildAffinityReverseKey(q.groupID, q.accountID)
cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
// 合并结果:同一个 accountID 跨多个 group 的 clientID 去重
result := make(map[int64][]string, len(accountGroups))
seen := make(map[int64]map[string]struct{}, len(accountGroups))
for i, q := range queries {
clients, _ := cmds[i].StringSlice()
if len(clients) == 0 {
continue
}
if seen[q.accountID] == nil {
seen[q.accountID] = make(map[string]struct{})
}
for _, clientID := range clients {
if _, exists := seen[q.accountID][clientID]; !exists {
seen[q.accountID][clientID] = struct{}{}
result[q.accountID] = append(result[q.accountID], clientID)
}
}
}
return result, nil
}
// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。
func (c *gatewayCache) GetAccountAffinityClientsWithScores(
ctx context.Context,
accountID int64,
groupIDs []int64,
ttl time.Duration,
) ([]service.AffinityClient, error) {
if len(groupIDs) == 0 {
return nil, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(groupIDs))
for i, gID := range groupIDs {
key := buildAffinityReverseKey(gID, accountID)
cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
// 合并跨组结果,同一 clientID 取最近的 lastActive
seen := make(map[string]int64) // clientID → max timestamp
for _, cmd := range cmds {
vals, _ := cmd.StringSlice()
// vals 格式: [clientID1, score1, clientID2, score2, ...]
for j := 0; j+1 < len(vals); j += 2 {
clientID := vals[j]
ts, _ := strconv.ParseInt(vals[j+1], 10, 64)
if existing, ok := seen[clientID]; !ok || ts > existing {
seen[clientID] = ts
}
}
}
result := make([]service.AffinityClient, 0, len(seen))
for clientID, ts := range seen {
result = append(result, service.AffinityClient{
ClientID: clientID,
LastActive: time.Unix(ts, 0),
})
}
// 按最后活跃时间降序排序
service.SortAffinityClients(result)
return result, nil
}
// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。
// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端,
// 从每个客户端的正向索引中移除该账号,然后删除反向索引。
func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript)
pipe := c.rdb.Pipeline()
for _, gID := range groupIDs {
revKey := buildAffinityReverseKey(gID, accountID)
clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return err
}
return nil
}

View File

@@ -227,3 +227,24 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *
}
return &cost
}
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
// It resolves the upstream model (falling back to the requested model) and calls
// the 4-level priority chain via resolveAccountStatsCost.
func applyAccountStatsCost(
ctx context.Context,
usageLog *UsageLog,
cs *ChannelService, bs *BillingService,
accountID int64, groupID int64,
upstreamModel, requestedModel string,
tokens UsageTokens,
totalCost float64,
) {
model := upstreamModel
if model == "" {
model = requestedModel
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
)
}

View File

@@ -39,7 +39,8 @@ type Channel struct {
Status string
BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述JSON 数组),用于支付页面展示
Features string // 渠道特性描述JSON 数组),用于支付页面展示
FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation
CreatedAt time.Time
UpdatedAt time.Time
@@ -222,6 +223,19 @@ func (c *Channel) Clone() *Channel {
return &cp
}
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))

View File

@@ -258,6 +258,9 @@ const (
// Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -49,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
// Notify code rate limiting per user
IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error)
GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error)
}
// VerificationCodeData represents verification code data

View File

@@ -30,7 +30,6 @@ type ProviderInstanceResponse struct {
Limits string `json:"limits"`
Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"`
}
@@ -47,7 +46,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
}
resp.Config, err = s.decryptAndMaskConfig(inst.Config)
@@ -111,12 +110,10 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil {
return nil, err
}
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx)
}
@@ -224,21 +221,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
}
if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is true
if *req.AllowUserRefund {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil && inst.RefundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
}
if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode)
@@ -250,7 +232,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) {
instances, err := s.entClient.PaymentProviderInstance.Query().
Where(
paymentproviderinstance.AllowUserRefundEQ(true),
paymentproviderinstance.RefundEnabledEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil {

View File

@@ -114,7 +114,6 @@ type CreateProviderInstanceRequest struct {
SortOrder int `json:"sort_order"`
Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
}
type UpdateProviderInstanceRequest struct {
@@ -126,7 +125,6 @@ type UpdateProviderInstanceRequest struct {
SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
}
type CreatePlanRequest struct {
GroupID int64 `json:"group_id"`

View File

@@ -5,6 +5,8 @@ import (
"fmt"
"log/slog"
"math"
"strconv"
"strings"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -20,11 +22,17 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
if n.Status != payment.NotificationStatusSuccess {
return nil
}
oid, err := parseOrderID(n.OrderID)
// Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil {
return fmt.Errorf("invalid order ID: %s", n.OrderID)
// Fallback: try legacy format (sub2_N where N is DB ID)
trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
}
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {

View File

@@ -10,7 +10,6 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -170,68 +169,6 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil
}
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
return nil
}
windowStart := cancelRateLimitWindowStart(cfg)
operator := fmt.Sprintf("user:%d", userID)
count, err := s.entClient.PaymentAuditLog.Query().
Where(
paymentauditlog.ActionEQ("ORDER_CANCELLED"),
paymentauditlog.OperatorEQ(operator),
paymentauditlog.CreatedAtGTE(windowStart),
).Count(ctx)
if err != nil {
slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
return nil // fail open
}
if count >= cfg.CancelRateLimitMax {
return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
WithMetadata(map[string]string{
"max": strconv.Itoa(cfg.CancelRateLimitMax),
"window": strconv.Itoa(cfg.CancelRateLimitWindow),
"unit": cfg.CancelRateLimitUnit,
})
}
return nil
}
func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
now := time.Now()
w := cfg.CancelRateLimitWindow
if w <= 0 {
w = 1
}
unit := cfg.CancelRateLimitUnit
if unit == "" {
unit = "day"
}
if cfg.CancelRateLimitMode == "fixed" {
switch unit {
case "minute":
t := now.Truncate(time.Minute)
return t.Add(-time.Duration(w-1) * time.Minute)
case "day":
y, m, d := now.Date()
t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
return t.AddDate(0, 0, -(w - 1))
default: // hour
t := now.Truncate(time.Hour)
return t.Add(-time.Duration(w-1) * time.Hour)
}
}
// rolling window
switch unit {
case "minute":
return now.Add(-time.Duration(w) * time.Minute)
case "day":
return now.AddDate(0, 0, -w)
default: // hour
return now.Add(-time.Duration(w) * time.Hour)
}
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 {
return nil
@@ -252,19 +189,16 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
}
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
s.EnsureProviders(ctx)
providerKey := s.registry.GetProviderKey(req.PaymentType)
if providerKey == "" {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
// Select an instance across all providers that support the requested payment type.
// This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil {
return nil, fmt.Errorf("select provider instance: %w", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
}
prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config)
prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
}
@@ -272,7 +206,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err)
slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
}
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
@@ -357,6 +291,13 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
}
if p.Keyword != "" {
q = q.Where(paymentorder.Or(
paymentorder.OutTradeNoContainsFold(p.Keyword),
paymentorder.UserEmailContainsFold(p.Keyword),
paymentorder.UserNameContainsFold(p.Keyword),
))
}
total, err := q.Clone().Count(ctx)
if err != nil {
return nil, 0, fmt.Errorf("count admin orders: %w", err)
@@ -368,172 +309,3 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
}
return orders, total, nil
}
// --- Cancel & Expire ---
func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
}
func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
}
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
if o.PaymentTradeNo != "" || o.PaymentType != "" {
if s.checkPaid(ctx, o) == "already_paid" {
return "already_paid", nil
}
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
if err != nil {
return "", fmt.Errorf("update order status: %w", err)
}
if c > 0 {
auditAction := "ORDER_CANCELLED"
if fs == OrderStatusExpired {
auditAction = "ORDER_EXPIRED"
}
s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
}
return "cancelled", nil
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
}
resp, err := prov.QueryOrder(ctx, tradeNo)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return "already_paid"
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
}
return ""
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
// Only verify orders that are still pending or recently expired
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
// Reload order to get updated status
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
if err != nil {
return 0, fmt.Errorf("query expired: %w", err)
}
n := 0
for _, o := range orders {
// Check upstream payment status before expiring — the user may have
// paid just before timeout and the webhook hasn't arrived yet.
outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
if outcome == "already_paid" {
slog.Info("order was paid during expiry", "orderID", o.ID)
continue
}
if outcome != "" {
n++
}
}
return n, nil
}
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
}
}
}
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}

View File

@@ -107,6 +107,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata默认 false
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
// Balance low notification
BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64