fix: resolve cherry-pick conflicts and restore compilation

- Restore gateway_cache.go to upstream (no lua embeds)
- Restore payment_order.go to upstream (use out_trade_no lookup)
- Restore payment_fulfillment.go to upstream (same reason)
- Add FeaturesConfig field and IsWebSearchEmulationEnabled to Channel
- Add applyAccountStatsCost wrapper function
- Add SettingKeyWebSearchEmulationConfig constant
- Add WebSearchEmulationEnabled to SystemSettings
- Add notify code rate limiting methods to EmailCache interface
- Remove AllowUserRefund references (ent schema not present)
- Fix duplicate import in payment_handler.go
- Fix wire_gen.go argument mismatches
This commit is contained in:
erio
2026-04-14 10:18:39 +08:00
parent 9028d2085f
commit d6965b0676
12 changed files with 80 additions and 520 deletions

View File

@@ -143,7 +143,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache) antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oAuthRefreshAPI, tempUnschedCache)
internal500CounterCache := repository.NewInternal500CounterCache(redisClient) internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache, accountUsageService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig, tlsFingerprintProfileService)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
@@ -217,8 +217,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
} }
defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey)
paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService)
paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService)
paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService)
paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler)

View File

@@ -7,7 +7,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/response"
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"

View File

@@ -2,42 +2,14 @@ package repository
import ( import (
"context" "context"
_ "embed"
"fmt" "fmt"
"strconv"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
) )
const ( const stickySessionPrefix = "sticky_session:"
stickySessionPrefix = "sticky_session:"
clientAffinityPrefix = "client_affinity:"
clientAffinityReversePrefix = "client_affinity_rev:"
)
var (
//go:embed lua/get_affinity.lua
getAffinityLua string
//go:embed lua/update_affinity.lua
updateAffinityLua string
//go:embed lua/get_affinity_count.lua
getAffinityCountLua string
//go:embed lua/get_affinity_clients.lua
getAffinityClientsLua string
//go:embed lua/get_affinity_clients_with_scores.lua
getAffinityClientsWithScoresLua string
//go:embed lua/clear_account_affinity.lua
clearAccountAffinityLua string
getAffinityScript = redis.NewScript(getAffinityLua)
updateAffinityScript = redis.NewScript(updateAffinityLua)
getAffinityCountScript = redis.NewScript(getAffinityCountLua)
getAffinityClientsScript = redis.NewScript(getAffinityClientsLua)
getAffinityClientsWithScoresScript = redis.NewScript(getAffinityClientsWithScoresLua)
clearAccountAffinityScript = redis.NewScript(clearAccountAffinityLua)
)
type gatewayCache struct { type gatewayCache struct {
rdb *redis.Client rdb *redis.Client
@@ -47,16 +19,6 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
return &gatewayCache{rdb: rdb} return &gatewayCache{rdb: rdb}
} }
// ensureScriptLoaded 确保 Lua 脚本已加载到 Redis 服务器的脚本缓存中。
// Pipeline 中的 Script.Run 只发送 EVALSHA如果 Redis 重启过导致脚本缓存丢失,
// EVALSHA 会返回 NOSCRIPT 错误。此方法提前加载脚本以避免该问题。
func ensureScriptLoaded(ctx context.Context, rdb *redis.Client, script *redis.Script) {
exists, err := script.Exists(ctx, rdb).Result()
if err != nil || len(exists) == 0 || !exists[0] {
_ = script.Load(ctx, rdb).Err()
}
}
// buildSessionKey 构建 session key包含 groupID 实现分组隔离 // buildSessionKey 构建 session key包含 groupID 实现分组隔离
// 格式: sticky_session:{groupID}:{sessionHash} // 格式: sticky_session:{groupID}:{sessionHash}
func buildSessionKey(groupID int64, sessionHash string) string { func buildSessionKey(groupID int64, sessionHash string) string {
@@ -79,218 +41,13 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses
} }
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
// 以便下次请求能够重新选择可用账号。
//
// DeleteSessionAccountID removes the sticky session binding for the given session.
// Called when the bound account becomes unavailable (e.g., error status, disabled,
// or unschedulable), allowing subsequent requests to select a new available account.
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
key := buildSessionKey(groupID, sessionHash) key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
// buildAffinityKey 构建正向亲和 keyclient → accounts
// 格式: client_affinity:{groupID}:{clientID}
func buildAffinityKey(groupID int64, clientID string) string {
return fmt.Sprintf("%s%d:%s", clientAffinityPrefix, groupID, clientID)
}
// buildAffinityReverseKey 构建反向亲和 keyaccount → clients
// 格式: client_affinity_rev:{groupID}:{accountID}
func buildAffinityReverseKey(groupID int64, accountID int64) string {
return fmt.Sprintf("%s%d:%d", clientAffinityReversePrefix, groupID, accountID)
}
func (c *gatewayCache) GetClientAffinityAccounts(ctx context.Context, groupID int64, clientID string, ttl time.Duration) ([]int64, error) {
key := buildAffinityKey(groupID, clientID)
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
result, err := getAffinityScript.Run(ctx, c.rdb, []string{key}, expireThreshold).StringSlice()
if err != nil {
if err == redis.Nil {
return nil, nil
}
return nil, err
}
accountIDs := make([]int64, 0, len(result))
for _, s := range result {
id, err := strconv.ParseInt(s, 10, 64)
if err != nil {
continue
}
accountIDs = append(accountIDs, id)
}
return accountIDs, nil
}
func (c *gatewayCache) UpdateClientAffinity(ctx context.Context, groupID int64, clientID string, accountID int64, ttl time.Duration) error {
fwdKey := buildAffinityKey(groupID, clientID)
revKey := buildAffinityReverseKey(groupID, accountID)
now := time.Now().Unix()
ttlSeconds := int64(ttl.Seconds())
expireThreshold := now - ttlSeconds
return updateAffinityScript.Run(ctx, c.rdb, []string{fwdKey, revKey},
now, ttlSeconds, accountID, expireThreshold, clientID,
).Err()
}
// GetAccountAffinityCountBatch 批量获取账号的亲和客户端数量(惰性清理过期成员)
func (c *gatewayCache) GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) {
if len(accountIDs) == 0 {
return map[int64]int64{}, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
ensureScriptLoaded(ctx, c.rdb, getAffinityCountScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(accountIDs))
for i, accID := range accountIDs {
key := buildAffinityReverseKey(groupID, accID)
cmds[i] = getAffinityCountScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
result := make(map[int64]int64, len(accountIDs))
for i, accID := range accountIDs {
count, _ := cmds[i].Int64()
result[accID] = count
}
return result, nil
}
// GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和客户端列表(去重)。
// accountGroups: map[accountID][]groupID对每个 (groupID, accountID) 组合查询反向索引。
func (c *gatewayCache) GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) {
if len(accountGroups) == 0 {
return map[int64][]string{}, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
// 构建所有 (accountID, groupID) 组合的查询
type queryItem struct {
accountID int64
groupID int64
}
var queries []queryItem
for accID, groupIDs := range accountGroups {
for _, gID := range groupIDs {
queries = append(queries, queryItem{accountID: accID, groupID: gID})
}
}
ensureScriptLoaded(ctx, c.rdb, getAffinityClientsScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(queries))
for i, q := range queries {
key := buildAffinityReverseKey(q.groupID, q.accountID)
cmds[i] = getAffinityClientsScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
// 合并结果:同一个 accountID 跨多个 group 的 clientID 去重
result := make(map[int64][]string, len(accountGroups))
seen := make(map[int64]map[string]struct{}, len(accountGroups))
for i, q := range queries {
clients, _ := cmds[i].StringSlice()
if len(clients) == 0 {
continue
}
if seen[q.accountID] == nil {
seen[q.accountID] = make(map[string]struct{})
}
for _, clientID := range clients {
if _, exists := seen[q.accountID][clientID]; !exists {
seen[q.accountID][clientID] = struct{}{}
result[q.accountID] = append(result[q.accountID], clientID)
}
}
}
return result, nil
}
// GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间戳,去重取最近)。
func (c *gatewayCache) GetAccountAffinityClientsWithScores(
ctx context.Context,
accountID int64,
groupIDs []int64,
ttl time.Duration,
) ([]service.AffinityClient, error) {
if len(groupIDs) == 0 {
return nil, nil
}
now := time.Now().Unix()
expireThreshold := now - int64(ttl.Seconds())
ensureScriptLoaded(ctx, c.rdb, getAffinityClientsWithScoresScript)
pipe := c.rdb.Pipeline()
cmds := make([]*redis.Cmd, len(groupIDs))
for i, gID := range groupIDs {
key := buildAffinityReverseKey(gID, accountID)
cmds[i] = getAffinityClientsWithScoresScript.Run(ctx, pipe, []string{key}, expireThreshold)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return nil, err
}
// 合并跨组结果,同一 clientID 取最近的 lastActive
seen := make(map[string]int64) // clientID → max timestamp
for _, cmd := range cmds {
vals, _ := cmd.StringSlice()
// vals 格式: [clientID1, score1, clientID2, score2, ...]
for j := 0; j+1 < len(vals); j += 2 {
clientID := vals[j]
ts, _ := strconv.ParseInt(vals[j+1], 10, 64)
if existing, ok := seen[clientID]; !ok || ts > existing {
seen[clientID] = ts
}
}
}
result := make([]service.AffinityClient, 0, len(seen))
for clientID, ts := range seen {
result = append(result, service.AffinityClient{
ClientID: clientID,
LastActive: time.Unix(ts, 0),
})
}
// 按最后活跃时间降序排序
service.SortAffinityClients(result)
return result, nil
}
// ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引)。
// 对每个 groupID 执行 Lua 脚本:读取反向索引获取所有客户端,
// 从每个客户端的正向索引中移除该账号,然后删除反向索引。
func (c *gatewayCache) ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
ensureScriptLoaded(ctx, c.rdb, clearAccountAffinityScript)
pipe := c.rdb.Pipeline()
for _, gID := range groupIDs {
revKey := buildAffinityReverseKey(gID, accountID)
clearAccountAffinityScript.Run(ctx, pipe, []string{revKey}, gID, accountID)
}
_, err := pipe.Exec(ctx)
if err != nil && err != redis.Nil {
return err
}
return nil
}

View File

@@ -227,3 +227,24 @@ func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *
} }
return &cost return &cost
} }
// applyAccountStatsCost resolves the account stats cost for a usage log entry.
// It resolves the upstream model (falling back to the requested model) and calls
// the 4-level priority chain via resolveAccountStatsCost.
func applyAccountStatsCost(
ctx context.Context,
usageLog *UsageLog,
cs *ChannelService, bs *BillingService,
accountID int64, groupID int64,
upstreamModel, requestedModel string,
tokens UsageTokens,
totalCost float64,
) {
model := upstreamModel
if model == "" {
model = requestedModel
}
usageLog.AccountStatsCost = resolveAccountStatsCost(
ctx, cs, bs, accountID, groupID, model, tokens, 1, totalCost,
)
}

View File

@@ -40,6 +40,7 @@ type Channel struct {
BillingModelSource string // "requested", "upstream", or "channel_mapped" BillingModelSource string // "requested", "upstream", or "channel_mapped"
RestrictModels bool // 是否限制模型(仅允许定价列表中的模型) RestrictModels bool // 是否限制模型(仅允许定价列表中的模型)
Features string // 渠道特性描述JSON 数组),用于支付页面展示 Features string // 渠道特性描述JSON 数组),用于支付页面展示
FeaturesConfig map[string]any // 渠道功能配置(如 web search emulation
CreatedAt time.Time CreatedAt time.Time
UpdatedAt time.Time UpdatedAt time.Time
@@ -222,6 +223,19 @@ func (c *Channel) Clone() *Channel {
return &cp return &cp
} }
// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。
func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool {
if c == nil || c.FeaturesConfig == nil {
return false
}
wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any)
if !ok {
return false
}
enabled, ok := wse[platform].(bool)
return ok && enabled
}
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. // deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any { func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src)) dst := make(map[string]any, len(src))

View File

@@ -258,6 +258,9 @@ const (
// Account Quota Notification // Account Quota Notification
SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关 SettingKeyAccountQuotaNotifyEnabled = "account_quota_notify_enabled" // 全局开关
SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组) SettingKeyAccountQuotaNotifyEmails = "account_quota_notify_emails" // 管理员通知邮箱列表JSON 数组)
// Web Search Emulation
SettingKeyWebSearchEmulationConfig = "web_search_emulation_config" // JSON 配置
) )
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -49,6 +49,10 @@ type EmailCache interface {
// Returns true if in cooldown period (email was sent recently) // Returns true if in cooldown period (email was sent recently)
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
// Notify code rate limiting per user
IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error)
GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error)
} }
// VerificationCodeData represents verification code data // VerificationCodeData represents verification code data

View File

@@ -30,7 +30,6 @@ type ProviderInstanceResponse struct {
Limits string `json:"limits"` Limits string `json:"limits"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
PaymentMode string `json:"payment_mode"` PaymentMode string `json:"payment_mode"`
} }
@@ -47,7 +46,7 @@ func (s *PaymentConfigService) ListProviderInstancesWithConfig(ctx context.Conte
resp := ProviderInstanceResponse{ resp := ProviderInstanceResponse{
ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name, ID: int64(inst.ID), ProviderKey: inst.ProviderKey, Name: inst.Name,
SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits, SupportedTypes: splitTypes(inst.SupportedTypes), Limits: inst.Limits,
Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled, AllowUserRefund: inst.AllowUserRefund, Enabled: inst.Enabled, RefundEnabled: inst.RefundEnabled,
SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode, SortOrder: inst.SortOrder, PaymentMode: inst.PaymentMode,
} }
resp.Config, err = s.decryptAndMaskConfig(inst.Config) resp.Config, err = s.decryptAndMaskConfig(inst.Config)
@@ -111,12 +110,10 @@ func (s *PaymentConfigService) CreateProviderInstance(ctx context.Context, req C
if err != nil { if err != nil {
return nil, err return nil, err
} }
allowUserRefund := req.AllowUserRefund && req.RefundEnabled
return s.entClient.PaymentProviderInstance.Create(). return s.entClient.PaymentProviderInstance.Create().
SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc). SetProviderKey(req.ProviderKey).SetName(req.Name).SetConfig(enc).
SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode). SetSupportedTypes(typesStr).SetEnabled(req.Enabled).SetPaymentMode(req.PaymentMode).
SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled). SetSortOrder(req.SortOrder).SetLimits(req.Limits).SetRefundEnabled(req.RefundEnabled).
SetAllowUserRefund(allowUserRefund).
Save(ctx) Save(ctx)
} }
@@ -224,21 +221,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
} }
if req.RefundEnabled != nil { if req.RefundEnabled != nil {
u.SetRefundEnabled(*req.RefundEnabled) u.SetRefundEnabled(*req.RefundEnabled)
// Cascade: turning off refund_enabled also disables allow_user_refund
if !*req.RefundEnabled {
u.SetAllowUserRefund(false)
}
}
if req.AllowUserRefund != nil {
// Only allow enabling when refund_enabled is true
if *req.AllowUserRefund {
inst, err := s.entClient.PaymentProviderInstance.Get(ctx, id)
if err == nil && inst.RefundEnabled {
u.SetAllowUserRefund(true)
}
} else {
u.SetAllowUserRefund(false)
}
} }
if req.PaymentMode != nil { if req.PaymentMode != nil {
u.SetPaymentMode(*req.PaymentMode) u.SetPaymentMode(*req.PaymentMode)
@@ -250,7 +232,6 @@ func (s *PaymentConfigService) UpdateProviderInstance(ctx context.Context, id in
func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) { func (s *PaymentConfigService) GetUserRefundEligibleInstanceIDs(ctx context.Context) ([]string, error) {
instances, err := s.entClient.PaymentProviderInstance.Query(). instances, err := s.entClient.PaymentProviderInstance.Query().
Where( Where(
paymentproviderinstance.AllowUserRefundEQ(true),
paymentproviderinstance.RefundEnabledEQ(true), paymentproviderinstance.RefundEnabledEQ(true),
).Select(paymentproviderinstance.FieldID).All(ctx) ).Select(paymentproviderinstance.FieldID).All(ctx)
if err != nil { if err != nil {

View File

@@ -114,7 +114,6 @@ type CreateProviderInstanceRequest struct {
SortOrder int `json:"sort_order"` SortOrder int `json:"sort_order"`
Limits string `json:"limits"` Limits string `json:"limits"`
RefundEnabled bool `json:"refund_enabled"` RefundEnabled bool `json:"refund_enabled"`
AllowUserRefund bool `json:"allow_user_refund"`
} }
type UpdateProviderInstanceRequest struct { type UpdateProviderInstanceRequest struct {
@@ -126,7 +125,6 @@ type UpdateProviderInstanceRequest struct {
SortOrder *int `json:"sort_order"` SortOrder *int `json:"sort_order"`
Limits *string `json:"limits"` Limits *string `json:"limits"`
RefundEnabled *bool `json:"refund_enabled"` RefundEnabled *bool `json:"refund_enabled"`
AllowUserRefund *bool `json:"allow_user_refund"`
} }
type CreatePlanRequest struct { type CreatePlanRequest struct {
GroupID int64 `json:"group_id"` GroupID int64 `json:"group_id"`

View File

@@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"strconv"
"strings"
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -20,12 +22,18 @@ func (s *PaymentService) HandlePaymentNotification(ctx context.Context, n *payme
if n.Status != payment.NotificationStatusSuccess { if n.Status != payment.NotificationStatusSuccess {
return nil return nil
} }
oid, err := parseOrderID(n.OrderID) // Look up order by out_trade_no (the external order ID we sent to the provider)
order, err := s.entClient.PaymentOrder.Query().Where(paymentorder.OutTradeNo(n.OrderID)).Only(ctx)
if err != nil { if err != nil {
return fmt.Errorf("invalid order ID: %s", n.OrderID) // Fallback: try legacy format (sub2_N where N is DB ID)
} trimmed := strings.TrimPrefix(n.OrderID, orderIDPrefix)
if oid, parseErr := strconv.ParseInt(trimmed, 10, 64); parseErr == nil {
return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk) return s.confirmPayment(ctx, oid, n.TradeNo, n.Amount, pk)
} }
return fmt.Errorf("order not found for out_trade_no: %s", n.OrderID)
}
return s.confirmPayment(ctx, order.ID, n.TradeNo, n.Amount, pk)
}
func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error { func (s *PaymentService) confirmPayment(ctx context.Context, oid int64, tradeNo string, paid float64, pk string) error {
o, err := s.entClient.PaymentOrder.Get(ctx, oid) o, err := s.entClient.PaymentOrder.Get(ctx, oid)

View File

@@ -10,7 +10,6 @@ import (
"time" "time"
dbent "github.com/Wei-Shaw/sub2api/ent" dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/paymentauditlog"
"github.com/Wei-Shaw/sub2api/ent/paymentorder" "github.com/Wei-Shaw/sub2api/ent/paymentorder"
"github.com/Wei-Shaw/sub2api/internal/payment" "github.com/Wei-Shaw/sub2api/internal/payment"
"github.com/Wei-Shaw/sub2api/internal/payment/provider" "github.com/Wei-Shaw/sub2api/internal/payment/provider"
@@ -170,68 +169,6 @@ func (s *PaymentService) checkPendingLimit(ctx context.Context, tx *dbent.Tx, us
return nil return nil
} }
func (s *PaymentService) checkCancelRateLimit(ctx context.Context, userID int64, cfg *PaymentConfig) error {
if !cfg.CancelRateLimitEnabled || cfg.CancelRateLimitMax <= 0 {
return nil
}
windowStart := cancelRateLimitWindowStart(cfg)
operator := fmt.Sprintf("user:%d", userID)
count, err := s.entClient.PaymentAuditLog.Query().
Where(
paymentauditlog.ActionEQ("ORDER_CANCELLED"),
paymentauditlog.OperatorEQ(operator),
paymentauditlog.CreatedAtGTE(windowStart),
).Count(ctx)
if err != nil {
slog.Error("check cancel rate limit failed", "userID", userID, "error", err)
return nil // fail open
}
if count >= cfg.CancelRateLimitMax {
return infraerrors.TooManyRequests("CANCEL_RATE_LIMITED", "cancel rate limited").
WithMetadata(map[string]string{
"max": strconv.Itoa(cfg.CancelRateLimitMax),
"window": strconv.Itoa(cfg.CancelRateLimitWindow),
"unit": cfg.CancelRateLimitUnit,
})
}
return nil
}
func cancelRateLimitWindowStart(cfg *PaymentConfig) time.Time {
now := time.Now()
w := cfg.CancelRateLimitWindow
if w <= 0 {
w = 1
}
unit := cfg.CancelRateLimitUnit
if unit == "" {
unit = "day"
}
if cfg.CancelRateLimitMode == "fixed" {
switch unit {
case "minute":
t := now.Truncate(time.Minute)
return t.Add(-time.Duration(w-1) * time.Minute)
case "day":
y, m, d := now.Date()
t := time.Date(y, m, d, 0, 0, 0, 0, now.Location())
return t.AddDate(0, 0, -(w - 1))
default: // hour
t := now.Truncate(time.Hour)
return t.Add(-time.Duration(w-1) * time.Hour)
}
}
// rolling window
switch unit {
case "minute":
return now.Add(-time.Duration(w) * time.Minute)
case "day":
return now.AddDate(0, 0, -w)
default: // hour
return now.Add(-time.Duration(w) * time.Hour)
}
}
func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error { func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, userID int64, amount, limit float64) error {
if limit <= 0 { if limit <= 0 {
return nil return nil
@@ -252,19 +189,16 @@ func (s *PaymentService) checkDailyLimit(ctx context.Context, tx *dbent.Tx, user
} }
func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) { func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.PaymentOrder, req CreateOrderRequest, cfg *PaymentConfig, payAmountStr string, payAmount float64, plan *dbent.SubscriptionPlan) (*CreateOrderResponse, error) {
s.EnsureProviders(ctx) // Select an instance across all providers that support the requested payment type.
providerKey := s.registry.GetProviderKey(req.PaymentType) // This enables cross-provider load balancing (e.g. EasyPay + Alipay direct for "alipay").
if providerKey == "" { sel, err := s.loadBalancer.SelectInstance(ctx, "", req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
}
sel, err := s.loadBalancer.SelectInstance(ctx, providerKey, req.PaymentType, payment.Strategy(cfg.LoadBalanceStrategy), payAmount)
if err != nil { if err != nil {
return nil, fmt.Errorf("select provider instance: %w", err) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment method (%s) is not configured", req.PaymentType))
} }
if sel == nil { if sel == nil {
return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance") return nil, infraerrors.TooManyRequests("NO_AVAILABLE_INSTANCE", "no available payment instance")
} }
prov, err := provider.CreateProvider(providerKey, sel.InstanceID, sel.Config) prov, err := provider.CreateProvider(sel.ProviderKey, sel.InstanceID, sel.Config)
if err != nil { if err != nil {
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable") return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", "payment method is temporarily unavailable")
} }
@@ -272,7 +206,7 @@ func (s *PaymentService) invokeProvider(ctx context.Context, order *dbent.Paymen
outTradeNo := order.OutTradeNo outTradeNo := order.OutTradeNo
pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes}) pr, err := prov.CreatePayment(ctx, payment.CreatePaymentRequest{OrderID: outTradeNo, Amount: payAmountStr, PaymentType: req.PaymentType, Subject: subject, ClientIP: req.ClientIP, IsMobile: req.IsMobile, InstanceSubMethods: sel.SupportedTypes})
if err != nil { if err != nil {
slog.Error("[PaymentService] CreatePayment failed", "provider", providerKey, "instance", sel.InstanceID, "error", err) slog.Error("[PaymentService] CreatePayment failed", "provider", sel.ProviderKey, "instance", sel.InstanceID, "error", err)
return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error())) return nil, infraerrors.ServiceUnavailable("PAYMENT_GATEWAY_ERROR", fmt.Sprintf("payment gateway error: %s", err.Error()))
} }
_, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx) _, err = s.entClient.PaymentOrder.UpdateOneID(order.ID).SetNillablePaymentTradeNo(psNilIfEmpty(pr.TradeNo)).SetNillablePayURL(psNilIfEmpty(pr.PayURL)).SetNillableQrCode(psNilIfEmpty(pr.QRCode)).SetNillableProviderInstanceID(psNilIfEmpty(sel.InstanceID)).Save(ctx)
@@ -357,6 +291,13 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
if p.PaymentType != "" { if p.PaymentType != "" {
q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType)) q = q.Where(paymentorder.PaymentTypeEQ(p.PaymentType))
} }
if p.Keyword != "" {
q = q.Where(paymentorder.Or(
paymentorder.OutTradeNoContainsFold(p.Keyword),
paymentorder.UserEmailContainsFold(p.Keyword),
paymentorder.UserNameContainsFold(p.Keyword),
))
}
total, err := q.Clone().Count(ctx) total, err := q.Clone().Count(ctx)
if err != nil { if err != nil {
return nil, 0, fmt.Errorf("count admin orders: %w", err) return nil, 0, fmt.Errorf("count admin orders: %w", err)
@@ -368,172 +309,3 @@ func (s *PaymentService) AdminListOrders(ctx context.Context, userID int64, p Or
} }
return orders, total, nil return orders, total, nil
} }
// --- Cancel & Expire ---
func (s *PaymentService) CancelOrder(ctx context.Context, orderID, userID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return "", infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, fmt.Sprintf("user:%d", userID), "user cancelled order")
}
func (s *PaymentService) AdminCancelOrder(ctx context.Context, orderID int64) (string, error) {
o, err := s.entClient.PaymentOrder.Get(ctx, orderID)
if err != nil {
return "", infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status != OrderStatusPending {
return "", infraerrors.BadRequest("INVALID_STATUS", "order cannot be cancelled in current status")
}
return s.cancelCore(ctx, o, OrderStatusCancelled, "admin", "admin cancelled order")
}
func (s *PaymentService) cancelCore(ctx context.Context, o *dbent.PaymentOrder, fs, op, ad string) (string, error) {
if o.PaymentTradeNo != "" || o.PaymentType != "" {
if s.checkPaid(ctx, o) == "already_paid" {
return "already_paid", nil
}
}
c, err := s.entClient.PaymentOrder.Update().Where(paymentorder.IDEQ(o.ID), paymentorder.StatusEQ(OrderStatusPending)).SetStatus(fs).Save(ctx)
if err != nil {
return "", fmt.Errorf("update order status: %w", err)
}
if c > 0 {
auditAction := "ORDER_CANCELLED"
if fs == OrderStatusExpired {
auditAction = "ORDER_EXPIRED"
}
s.writeAuditLog(ctx, o.ID, auditAction, op, map[string]any{"detail": ad})
}
return "cancelled", nil
}
func (s *PaymentService) checkPaid(ctx context.Context, o *dbent.PaymentOrder) string {
prov, err := s.getOrderProvider(ctx, o)
if err != nil {
return ""
}
// Use OutTradeNo as fallback when PaymentTradeNo is empty
// (e.g. EasyPay popup mode where trade_no arrives only via notify callback)
tradeNo := o.PaymentTradeNo
if tradeNo == "" {
tradeNo = o.OutTradeNo
}
resp, err := prov.QueryOrder(ctx, tradeNo)
if err != nil {
slog.Warn("query upstream failed", "orderID", o.ID, "error", err)
return ""
}
if resp.Status == payment.ProviderStatusPaid {
if err := s.HandlePaymentNotification(ctx, &payment.PaymentNotification{TradeNo: o.PaymentTradeNo, OrderID: o.OutTradeNo, Amount: resp.Amount, Status: payment.ProviderStatusSuccess}, prov.ProviderKey()); err != nil {
slog.Error("fulfillment failed during checkPaid", "orderID", o.ID, "error", err)
// Still return already_paid — order was paid, fulfillment can be retried
}
return "already_paid"
}
if cp, ok := prov.(payment.CancelableProvider); ok {
_ = cp.CancelPayment(ctx, tradeNo)
}
return ""
}
// VerifyOrderByOutTradeNo actively queries the upstream provider to check
// if a payment was made, and processes it if so. This handles the case where
// the provider's notify callback was missed (e.g. EasyPay popup mode).
func (s *PaymentService) VerifyOrderByOutTradeNo(ctx context.Context, outTradeNo string, userID int64) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.UserID != userID {
return nil, infraerrors.Forbidden("FORBIDDEN", "no permission for this order")
}
// Only verify orders that are still pending or recently expired
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
// Reload order to get updated status
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
// VerifyOrderPublic verifies payment status without user authentication.
// Used by the payment result page when the user's session has expired.
func (s *PaymentService) VerifyOrderPublic(ctx context.Context, outTradeNo string) (*dbent.PaymentOrder, error) {
o, err := s.entClient.PaymentOrder.Query().
Where(paymentorder.OutTradeNo(outTradeNo)).
Only(ctx)
if err != nil {
return nil, infraerrors.NotFound("NOT_FOUND", "order not found")
}
if o.Status == OrderStatusPending || o.Status == OrderStatusExpired {
result := s.checkPaid(ctx, o)
if result == "already_paid" {
o, err = s.entClient.PaymentOrder.Get(ctx, o.ID)
if err != nil {
return nil, fmt.Errorf("reload order: %w", err)
}
}
}
return o, nil
}
func (s *PaymentService) ExpireTimedOutOrders(ctx context.Context) (int, error) {
now := time.Now()
orders, err := s.entClient.PaymentOrder.Query().Where(paymentorder.StatusEQ(OrderStatusPending), paymentorder.ExpiresAtLTE(now)).All(ctx)
if err != nil {
return 0, fmt.Errorf("query expired: %w", err)
}
n := 0
for _, o := range orders {
// Check upstream payment status before expiring — the user may have
// paid just before timeout and the webhook hasn't arrived yet.
outcome, _ := s.cancelCore(ctx, o, OrderStatusExpired, "system", "order expired")
if outcome == "already_paid" {
slog.Info("order was paid during expiry", "orderID", o.ID)
continue
}
if outcome != "" {
n++
}
}
return n, nil
}
// getOrderProvider creates a provider using the order's original instance config.
// Falls back to registry lookup if instance ID is missing (legacy orders).
func (s *PaymentService) getOrderProvider(ctx context.Context, o *dbent.PaymentOrder) (payment.Provider, error) {
if o.ProviderInstanceID != nil && *o.ProviderInstanceID != "" {
instID, err := strconv.ParseInt(*o.ProviderInstanceID, 10, 64)
if err == nil {
cfg, err := s.loadBalancer.GetInstanceConfig(ctx, instID)
if err == nil {
providerKey := s.registry.GetProviderKey(o.PaymentType)
if providerKey == "" {
providerKey = o.PaymentType
}
p, err := provider.CreateProvider(providerKey, *o.ProviderInstanceID, cfg)
if err == nil {
return p, nil
}
}
}
}
s.EnsureProviders(ctx)
return s.registry.GetProvider(o.PaymentType)
}

View File

@@ -107,6 +107,9 @@ type SystemSettings struct {
EnableMetadataPassthrough bool // 是否透传客户端原始 metadata默认 false EnableMetadataPassthrough bool // 是否透传客户端原始 metadata默认 false
EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false EnableCCHSigning bool // 是否对 billing header cch 进行签名(默认 false
// Web Search Emulation
WebSearchEmulationEnabled bool // 是否启用 web search 模拟
// Balance low notification // Balance low notification
BalanceLowNotifyEnabled bool BalanceLowNotifyEnabled bool
BalanceLowNotifyThreshold float64 BalanceLowNotifyThreshold float64