diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 69daeecf..a0e84f4c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -143,7 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) internal500CounterCache := repository.NewInternal500CounterCache(redisClient) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) @@ -217,8 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) diff --git a/backend/internal/handler/payment_handler.go b/backend/internal/handler/payment_handler.go index e01a2af1..0425fc49 100644 --- a/backend/internal/handler/payment_handler.go +++ b/backend/internal/handler/payment_handler.go @@ -7,7 +7,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" - "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index ec4bf40e..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,42 +2,14 @@ package repository import ( "context" - _ "embed" "fmt" - "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/redis/go-redis/v9" ) -const ( - stickySessionPrefix = "sticky_session:" - clientAffinityPrefix = "client_affinity:" - clientAffinityReversePrefix = "client_affinity_rev:" -) - -var ( - //go:embed lua/get_affinity.lua - getAffinityLua string - //go:embed lua/update_affinity.lua - updateAffinityLua string - //go:embed lua/get_affinity_count.lua - getAffinityCountLua string - //go:embed lua/get_affinity_clients.lua - getAffinityClientsLua string - //go:embed lua/get_affinity_clients_with_scores.lua - getAffinityClientsWithScoresLua string - //go:embed lua/clear_account_affinity.lua - clearAccountAffinityLua string - - getAffinityScript = redis.NewScript(getAffinityLua) - updateAffinityScript = redis.NewScript(updateAffinityLua) - getAffinityCountScript = redis.NewScript(getAffinityCountLua) - getAffinityClientsScript = redis.NewScript(getAffinityClientsLua) - getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua) - clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua) -) +const stickySessionPrefix = "sticky_session:" type gatewayCache struct { rdb *redis.Client @@ -47,16 +19,6 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } -// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。 -// Pipeline 中的 Script.Run 只发送 EVALSHA,如果 Redis 重启过导致脚本缓存丢失, -// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。 -func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) { - exists, err := script.Exists(ctx, rdb).Result() - if err != nil || len(exists) == 0 || !exists[0] { - _ = script.Load(ctx, rdb).Err() - } -} - // buildSessionKey 构建 session key,包含 groupID 实现分组隔离 // 格式: sticky_session:{groupID}:{sessionHash} func buildSessionKey(groupID int64, sessionHash string) string { @@ -79,218 +41,13 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 +// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, +// 以便下次请求能够重新选择可用账号。 +// +// DeleteSessionAccountID removes the sticky session binding for the given session. +// Called when the bound account becomes unavailable (e.g., error status, disabled, +// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// buildAffinityKey 构建正向亲和 key(client → accounts) -// 格式: client_affinity:{groupID}:{clientID} -func buildAffinityKey(groupID int64, clientID string) string { - return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID) -} - -// buildAffinityReverseKey 构建反向亲和 key(account → clients) -// 格式: client_affinity_rev:{groupID}:{accountID} -func buildAffinityReverseKey(groupID int64, accountID int64) string { - return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID) -} - -func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) { - key := buildAffinityKey(groupID, clientID) - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice() - if err != nil { - if err == redis.Nil { - return nil, nil - } - return nil, err - } - - accountIDs := make([]int64, 0, len(result)) - for _, s := range result { - id, err := strconv.ParseInt(s, 10, 64) - if err != nil { - continue - } - accountIDs = append(accountIDs, id) - } - return accountIDs, nil -} - -func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error { - fwdKey := buildAffinityKey(groupID, clientID) - revKey := buildAffinityReverseKey(groupID, accountID) - now := time.Now().Unix() - ttlSeconds := int64(ttl.Seconds()) - expireThreshold := now - ttlSeconds - - return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey}, - now, ttlSeconds, accountID, expireThreshold, clientID, - ).Err() -} - -// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员) -func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) { - if len(accountIDs) == 0 { - return map[int64]int64{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(accountIDs)) - for i, accID := range accountIDs { - key := buildAffinityReverseKey(groupID, accID) - cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - result := make(map[int64]int64, len(accountIDs)) - for i, accID := range accountIDs { - count, _ := cmds[i].Int64() - result[accID] = count - } - return result, nil -} - -// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。 -// accountGroups: map[accountID][]groupID,对每个 (groupID, accountID) 组合查询反向索引。 -func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) { - if len(accountGroups) == 0 { - return map[int64][]string{}, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - // 构建所有 (accountID, groupID) 组合的查询 - type queryItem struct { - accountID int64 - groupID int64 - } - var queries []queryItem - for accID, groupIDs := range accountGroups { - for _, gID := range groupIDs { - queries = append(queries, queryItem{accountID: accID, groupID: gID}) - } - } - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(queries)) - for i, q := range queries { - key := buildAffinityReverseKey(q.groupID, q.accountID) - cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并结果:同一个 accountID 跨多个 group 的 clientID 去重 - result := make(map[int64][]string, len(accountGroups)) - seen := make(map[int64]map[string]struct{}, len(accountGroups)) - for i, q := range queries { - clients, _ := cmds[i].StringSlice() - if len(clients) == 0 { - continue - } - if seen[q.accountID] == nil { - seen[q.accountID] = make(map[string]struct{}) - } - for _, clientID := range clients { - if _, exists := seen[q.accountID][clientID]; !exists { - seen[q.accountID][clientID] = struct{}{} - result[q.accountID] = append(result[q.accountID], clientID) - } - } - } - return result, nil -} - -// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。 -func (c *gatewayCache) GetAccountAffinityClientsWithScores( - ctx context.Context, - accountID int64, - groupIDs []int64, - ttl time.Duration, -) ([]service.AffinityClient, error) { - if len(groupIDs) == 0 { - return nil, nil - } - - now := time.Now().Unix() - expireThreshold := now - int64(ttl.Seconds()) - - ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript) - - pipe := c.rdb.Pipeline() - cmds := make([]*redis.Cmd, len(groupIDs)) - for i, gID := range groupIDs { - key := buildAffinityReverseKey(gID, accountID) - cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return nil, err - } - - // 合并跨组结果,同一 clientID 取最近的 lastActive - seen := make(map[string]int64) // clientID → max timestamp - for _, cmd := range cmds { - vals, _ := cmd.StringSlice() - // vals 格式: [clientID1, score1, clientID2, score2, ...] - for j := 0; j+1 < len(vals); j += 2 { - clientID := vals[j] - ts, _ := strconv.ParseInt(vals[j+1], 10, 64) - if existing, ok := seen[clientID]; !ok || ts > existing { - seen[clientID] = ts - } - } - } - - result := make([]service.AffinityClient, 0, len(seen)) - for clientID, ts := range seen { - result = append(result, service.AffinityClient{ - ClientID: clientID, - LastActive: time.Unix(ts, 0), - }) - } - - // 按最后活跃时间降序排序 - service.SortAffinityClients(result) - - return result, nil -} - -// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。 -// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端, -// 从每个客户端的正向索引中移除该账号,然后删除反向索引。 -func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error { - if len(groupIDs) == 0 { - return nil - } - - ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript) - - pipe := c.rdb.Pipeline() - for _, gID := range groupIDs { - revKey := buildAffinityReverseKey(gID, accountID) - clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID) - } - _, err := pipe.Exec(ctx) - if err != nil && err != redis.Nil { - return err - } - return nil -} diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go index 61c318d9..47b7496f 100644 --- a/backend/internal/service/account_stats_pricing.go +++ b/backend/internal/service/account_stats_pricing.go @@ -227,3 +227,24 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) * } return &cost } + +// applyAccountStatsCost resolves the account stats cost for a usage log entry. +// It resolves the upstream model (falling back to the requested model) and calls +// the 4-level priority chain via resolveAccountStatsCost. +func applyAccountStatsCost( + ctx context.Context, + usageLog *UsageLog, + cs *ChannelService, bs *BillingService, + accountID int64, groupID int64, + upstreamModel, requestedModel string, + tokens UsageTokens, + totalCost float64, +) { + model := upstreamModel + if model == "" { + model = requestedModel + } + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost, + ) +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index b034fda0..b3fb2eac 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -39,7 +39,8 @@ type Channel struct { Status string BillingModelSource string // "requested", "upstream", or "channel_mapped" RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) - Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + Features string // 渠道特性描述(JSON 数组),用于支付页面展示 + FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation) CreatedAt time.Time UpdatedAt time.Time @@ -222,6 +223,19 @@ func (c *Channel) Clone() *Channel { return &cp } +// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。 +func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool { + if c == nil || c.FeaturesConfig == nil { + return false + } + wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any) + if !ok { + return false + } + enabled, ok := wse[platform].(bool) + return ok && enabled +} + // deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. func deepCopyFeaturesConfig(src map[string]any) map[string]any { dst := make(map[string]any, len(src)) diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index bdced29a..cb452efb 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -258,6 +258,9 @@ const ( // Account Quota Notification SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表(JSON 数组) + + // Web Search Emulation + SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置 ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 9cfd3bbd..9a03ea30 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -49,6 +49,10 @@ type EmailCache interface { // Returns true if in cooldown period (email was sent recently) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error + + // Notify code rate limiting per user + IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) + GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) } // VerificationCodeData represents verification code data diff --git a/backend/internal/service/payment_config_providers.go b/backend/internal/service/payment_config_providers.go index 0c71ab29..072ed002 100644 --- a/backend/internal/service/payment_config_providers.go +++ b/backend/internal/service/payment_config_providers.go @@ -30,7 +30,6 @@ type ProviderInstanceResponse struct { Limits string `json:"limits"` Enabled bool `json:"enabled"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` SortOrder int `json:"sort_order"` PaymentMode string `json:"payment_mode"` } @@ -47,7 +46,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte resp := ProviderInstanceResponse{ ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, - Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, + Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, } resp.Config, err = s.decryptAndMaskConfig(inst.Config) @@ -111,12 +110,10 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C if err != nil { return nil, err } - allowUserRefund := req.AllowUserRefund && req.RefundEnabled return s.entClient.PaymentProviderInstance.Create(). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). - SetAllowUserRefund(allowUserRefund). Save(ctx) } @@ -224,21 +221,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in } if req.RefundEnabled != nil { u.SetRefundEnabled(*req.RefundEnabled) - // Cascade: turning off refund_enabled also disables allow_user_refund - if !*req.RefundEnabled { - u.SetAllowUserRefund(false) - } - } - if req.AllowUserRefund != nil { - // Only allow enabling when refund_enabled is true - if *req.AllowUserRefund { - inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id) - if err == nil && inst.RefundEnabled { - u.SetAllowUserRefund(true) - } - } else { - u.SetAllowUserRefund(false) - } } if req.PaymentMode != nil { u.SetPaymentMode(*req.PaymentMode) @@ -250,7 +232,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { instances, err := s.entClient.PaymentProviderInstance.Query(). Where( - paymentproviderinstance.AllowUserRefundEQ(true), paymentproviderinstance.RefundEnabledEQ(true), ).Select(paymentproviderinstance.FieldID).All(ctx) if err != nil { diff --git a/backend/internal/service/payment_config_service.go b/backend/internal/service/payment_config_service.go index cce31f4d..6d470342 100644 --- a/backend/internal/service/payment_config_service.go +++ b/backend/internal/service/payment_config_service.go @@ -114,7 +114,6 @@ type CreateProviderInstanceRequest struct { SortOrder int `json:"sort_order"` Limits string `json:"limits"` RefundEnabled bool `json:"refund_enabled"` - AllowUserRefund bool `json:"allow_user_refund"` } type UpdateProviderInstanceRequest struct { @@ -126,7 +125,6 @@ type UpdateProviderInstanceRequest struct { SortOrder *int `json:"sort_order"` Limits *string `json:"limits"` RefundEnabled *bool `json:"refund_enabled"` - AllowUserRefund *bool `json:"allow_user_refund"` } type CreatePlanRequest struct { GroupID int64 `json:"group_id"` diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 51307849..de41d742 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -5,6 +5,8 @@ import ( "fmt" "log/slog" "math" + "strconv" + "strings" "time" dbent "github.com/Wei-Shaw/sub2api/ent" @@ -20,11 +22,17 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme if n.Status != payment.NotificationStatusSuccess { return nil } - oid, err := parseOrderID(n.OrderID) + // Look up order by out_trade_no (the external order ID we sent to the provider) + order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx) if err != nil { - return fmt.Errorf("invalid order ID: %s", n.OrderID) + // Fallback: try legacy format (sub2_N where N is DB ID) + trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix) + if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil { + return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + } + return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID) } - return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) + return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk) } func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { diff --git a/backend/internal/service/payment_order.go b/backend/internal/service/payment_order.go index e81af3f5..ff4dfaa8 100644 --- a/backend/internal/service/payment_order.go +++ b/backend/internal/service/payment_order.go @@ -10,7 +10,6 @@ import ( "time" dbent "github.com/Wei-Shaw/sub2api/ent" - "github.com/Wei-Shaw/sub2api/ent/paymentauditlog" "github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment/provider" @@ -170,68 +169,6 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us return nil } -func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error { - if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 { - return nil - } - windowStart := cancelRateLimitWindowStart(cfg) - operator := fmt.Sprintf("user:%d", userID) - count, err := s.entClient.PaymentAuditLog.Query(). - Where( - paymentauditlog.ActionEQ("ORDER_CANCELLED"), - paymentauditlog.OperatorEQ(operator), - paymentauditlog.CreatedAtGTE(windowStart), - ).Count(ctx) - if err != nil { - slog.Error("check cancel rate limit failed", "userID", userID, "error", err) - return nil // fail open - } - if count >= cfg.CancelRateLimitMax { - return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited"). - WithMetadata(map[string]string{ - "max": strconv.Itoa(cfg.CancelRateLimitMax), - "window": strconv.Itoa(cfg.CancelRateLimitWindow), - "unit": cfg.CancelRateLimitUnit, - }) - } - return nil -} - -func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time { - now := time.Now() - w := cfg.CancelRateLimitWindow - if w <= 0 { - w = 1 - } - unit := cfg.CancelRateLimitUnit - if unit == "" { - unit = "day" - } - if cfg.CancelRateLimitMode == "fixed" { - switch unit { - case "minute": - t := now.Truncate(time.Minute) - return t.Add(-time.Duration(w-1) * time.Minute) - case "day": - y, m, d := now.Date() - t := time.Date(y, m, d, 0, 0, 0, 0, now.Location()) - return t.AddDate(0, 0, -(w - 1)) - default: // hour - t := now.Truncate(time.Hour) - return t.Add(-time.Duration(w-1) * time.Hour) - } - } - // rolling window - switch unit { - case "minute": - return now.Add(-time.Duration(w) * time.Minute) - case "day": - return now.AddDate(0, 0, -w) - default: // hour - return now.Add(-time.Duration(w) * time.Hour) - } -} - func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { if limit <= 0 { return nil @@ -252,19 +189,16 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user } func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { - s.EnsureProviders(ctx) - providerKey := s.registry.GetProviderKey(req.PaymentType) - if providerKey == "" { - return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) - } - sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) + // Select an instance across all providers that support the requested payment type. + // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay"). + sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount) if err != nil { - return nil, fmt.Errorf("select provider instance: %w", err) + return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType)) } if sel == nil { return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") } - prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config) + prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config) if err != nil { return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") } @@ -272,7 +206,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen outTradeNo := order.OutTradeNo pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) if err != nil { - slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err) + slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) } _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) @@ -357,6 +291,13 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or if p.PaymentType != "" { q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType)) } + if p.Keyword != "" { + q = q.Where(paymentorder.Or( + paymentorder.OutTradeNoContainsFold(p.Keyword), + paymentorder.UserEmailContainsFold(p.Keyword), + paymentorder.UserNameContainsFold(p.Keyword), + )) + } total, err := q.Clone().Count(ctx) if err != nil { return nil, 0, fmt.Errorf("count admin orders: %w", err) @@ -368,172 +309,3 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or } return orders, total, nil } - -// --- Cancel & Expire --- - -func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order") -} - -func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) { - o, err := s.entClient.PaymentOrder.Get(ctx, orderID) - if err != nil { - return "", infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status != OrderStatusPending { - return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status") - } - return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order") -} - -func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) { - if o.PaymentTradeNo != "" || o.PaymentType != "" { - if s.checkPaid(ctx, o) == "already_paid" { - return "already_paid", nil - } - } - c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx) - if err != nil { - return "", fmt.Errorf("update order status: %w", err) - } - if c > 0 { - auditAction := "ORDER_CANCELLED" - if fs == OrderStatusExpired { - auditAction = "ORDER_EXPIRED" - } - s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad}) - } - return "cancelled", nil -} - -func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string { - prov, err := s.getOrderProvider(ctx, o) - if err != nil { - return "" - } - // Use OutTradeNo as fallback when PaymentTradeNo is empty - // (e.g. EasyPay popup mode where trade_no arrives only via notify callback) - tradeNo := o.PaymentTradeNo - if tradeNo == "" { - tradeNo = o.OutTradeNo - } - resp, err := prov.QueryOrder(ctx, tradeNo) - if err != nil { - slog.Warn("query upstream failed", "orderID", o.ID, "error", err) - return "" - } - if resp.Status == payment.ProviderStatusPaid { - if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil { - slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err) - // Still return already_paid — order was paid, fulfillment can be retried - } - return "already_paid" - } - if cp, ok := prov.(payment.CancelableProvider); ok { - _ = cp.CancelPayment(ctx, tradeNo) - } - return "" -} - -// VerifyOrderByOutTradeNo actively queries the upstream provider to check -// if a payment was made, and processes it if so. This handles the case where -// the provider's notify callback was missed (e.g. EasyPay popup mode). -func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.UserID != userID { - return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order") - } - // Only verify orders that are still pending or recently expired - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - // Reload order to get updated status - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -// VerifyOrderPublic verifies payment status without user authentication. -// Used by the payment result page when the user's session has expired. -func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) { - o, err := s.entClient.PaymentOrder.Query(). - Where(paymentorder.OutTradeNo(outTradeNo)). - Only(ctx) - if err != nil { - return nil, infraerrors.NotFound("NOT_FOUND", "order not found") - } - if o.Status == OrderStatusPending || o.Status == OrderStatusExpired { - result := s.checkPaid(ctx, o) - if result == "already_paid" { - o, err = s.entClient.PaymentOrder.Get(ctx, o.ID) - if err != nil { - return nil, fmt.Errorf("reload order: %w", err) - } - } - } - return o, nil -} - -func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) { - now := time.Now() - orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx) - if err != nil { - return 0, fmt.Errorf("query expired: %w", err) - } - n := 0 - for _, o := range orders { - // Check upstream payment status before expiring — the user may have - // paid just before timeout and the webhook hasn't arrived yet. - outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired") - if outcome == "already_paid" { - slog.Info("order was paid during expiry", "orderID", o.ID) - continue - } - if outcome != "" { - n++ - } - } - return n, nil -} - -// getOrderProvider creates a provider using the order's original instance config. -// Falls back to registry lookup if instance ID is missing (legacy orders). -func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) { - if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" { - instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64) - if err == nil { - cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID) - if err == nil { - providerKey := s.registry.GetProviderKey(o.PaymentType) - if providerKey == "" { - providerKey = o.PaymentType - } - p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg) - if err == nil { - return p, nil - } - } - } - } - s.EnsureProviders(ctx) - return s.registry.GetProvider(o.PaymentType) -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index ec20fe0a..ab2eb274 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -107,6 +107,9 @@ type SystemSettings struct { EnableMetadataPassthrough bool // 是否透传客户端原始 metadata(默认 false) EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false) + // Web Search Emulation + WebSearchEmulationEnabled bool // 是否启用 web search 模拟 + // Balance low notification BalanceLowNotifyEnabled bool BalanceLowNotifyThreshold float64