diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index c6a249ef..ec8310f6 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -134,7 +134,7 @@ func (r ChannelMappingResult) ToUsageFields(reqModel, upstreamModel string) Chan const ( channelCacheTTL = 10 * time.Minute - channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 + channelErrorTTL = 5 * time.Second // DB 错误时的短缓存 channelCacheDBTimeout = 10 * time.Second ) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 31137fb4..5d285fb6 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,6 +12,7 @@ import ( "log/slog" mathrand "math/rand" "net/http" + "net/url" "os" "path/filepath" "regexp" @@ -41,8 +42,7 @@ import ( const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionTTL = time.Hour // 粘性会话TTL - ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL + stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. 
When we need a visual @@ -60,28 +60,14 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) -// MediaType 媒体类型常量 -const ( - MediaTypeImage = "image" - MediaTypeVideo = "video" - MediaTypePrompt = "prompt" -) - -const ( - claudeMaxMessageOverheadTokens = 3 - claudeMaxBlockOverheadTokens = 1 - claudeMaxUnknownContentTokens = 4 -) - // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} // accountWithLoad 账号与负载信息的组合,用于负载感知调度 type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - affinityCount int64 // 亲和客户端数量(反向索引),越少越优先 + account *Account + loadInfo *AccountLoadInfo } var ForceCacheBillingContextKey = forceCacheBillingKeyType{} @@ -345,10 +331,6 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) - // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex) - // 格式: user_{64位hex}_account_... - clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`) - // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 @@ -366,12 +348,6 @@ var ErrNoAvailableAccounts = errors.New("no available accounts") // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") -// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号 -var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled") - -// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限 -var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -393,6 +369,8 @@ var allowedHeaders = map[string]bool{ "user-agent": true, "content-type": true, "accept-encoding": true, + "x-claude-code-session-id": true, + "x-client-request-id": true, } // 
GatewayCache 定义网关服务的缓存操作接口。 @@ -413,39 +391,6 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error - - // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员 - GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error) - // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL) - UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error - // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员) - GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) - // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重) - // accountGroups: map[accountID][]groupID - // 返回值成员格式为 {userID}/{clientID} - GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) - // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间) - GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error) - // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引) - // 用于账号关闭亲和时立即清理旧绑定 - ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error - // GetAffinityMultiCount 获取账号的多维度亲和计数 - // 返回: uniqueUsers, uniqueClients, perUserClients - GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error) -} - -// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间) -type AffinityClient struct { - UserID int64 `json:"user_id"` - ClientID string `json:"client_id"` - LastActive time.Time `json:"last_active"` -} - -// SortAffinityClients 按最后活跃时间降序排序 -func 
SortAffinityClients(clients []AffinityClient) { - sort.Slice(clients, func(i, j int) bool { - return clients[i].LastActive.After(clients[j].LastActive) - }) } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -516,20 +461,6 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { return false } -// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。 -// 格式: user_{64位hex}_account_..._session_... -// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。 -func extractClientIDFromMetadata(metadataUserID string) string { - if metadataUserID == "" { - return "" - } - matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID) - if matches == nil { - return "" - } - return matches[1] -} - type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -572,10 +503,6 @@ type ForwardResult struct { // 图片生成计费字段(图片生成模型使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" - - // Sora 媒体字段 - MediaType string // image / video / prompt - MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. 
@@ -1315,10 +1242,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 提取客户端 ID(用于客户端亲和调度) - affinityClientID := extractClientIDFromMetadata(metadataUserID) - affinityUserID := sub2apiUserID - if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1340,10 +1263,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } - if shouldFilterAccountWithoutClientID(account, affinityClientID) { - localExcluded[account.ID] = struct{}{} - continue - } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { @@ -1405,7 +1324,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } - accounts = filterAccountsWithoutClientID(accounts, affinityClientID) if len(accounts) == 0 { return nil, ErrNoAvailableAccounts } @@ -1424,19 +1342,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _, excluded := excludedIDs[accountID] return excluded } - affinityFlow := newGatewayAffinityFlow( - s, - ctx, - groupID, - sessionHash, - requestedModel, - affinityClientID, - affinityUserID, - platform, - useMixed, - accountByID, - isExcluded, - ) // 获取模型路由配置(仅 anthropic 平台) var routingAccountIDs []int64 @@ -1599,10 +1504,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if len(routingAvailable) > 0 { - // 批量获取亲和客户端数量 - s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID)) - - // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间 + // 排序:优先级 > 负载率 > 最后使用时间 sort.SliceStable(routingAvailable, func(i, j int) bool { a, b := routingAvailable[i], routingAvailable[j] if a.account.Priority != b.account.Priority { @@ -1611,9 +1513,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return 
a.loadInfo.LoadRate < b.loadInfo.LoadRate } - if a.affinityCount != b.affinityCount { - return a.affinityCount < b.affinityCount - } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true @@ -1639,9 +1538,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL) - } if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } @@ -1679,22 +1575,8 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============ - affinityFlow.preprocessPinnedUsers(accounts) - - // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============ - affinityHit := false - if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil { - return nil, err - } else { - affinityHit = hit - if affinityResult != nil { - return affinityResult, nil - } - } - - // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============ - if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { accountID := stickyAccountID if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] @@ -1800,9 +1682,6 @@ func (s *GatewayService) 
SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL) - } return result, nil } } else { @@ -1820,37 +1699,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 批量获取亲和客户端数量(用于均衡分配新客户端) - s.populateAffinityCounts(ctx, available, derefGroupID(groupID)) - - // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU + // 分层过滤选择:优先级 → 负载率 → LRU for len(available) > 0 { // 1. 取优先级最小的集合 candidates := filterByMinPriority(available) - // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内) - candidates = classifyByAffinityZone(candidates) - if len(candidates) == 0 { - // 当前优先级组全部在红区,移除后回退到下一优先级组 - minPri := available[0].account.Priority - for _, a := range available[1:] { - if a.account.Priority < minPri { - minPri = a.account.Priority - } - } - newAvailable := make([]accountWithLoad, 0, len(available)) - for _, a := range available { - if a.account.Priority != minPri { - newAvailable = append(newAvailable, a) - } - } - available = newAvailable - continue - } - // 3. 取负载率最低的集合 + // 2. 取负载率最低的集合 candidates = filterByMinLoadRate(candidates) - // 3. 取亲和客户端数最少的集合 - candidates = filterByMinAffinityCount(candidates) - // 4. LRU 选择最久未用的账号 + // 3. 
LRU 选择最久未用的账号 selected := selectByLRU(candidates, preferOAuth) if selected == nil { break @@ -1865,10 +1720,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - // 更新亲和关系 - if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() { - _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL) - } return &AccountSelectionResult{ Account: selected.account, Acquired: true, @@ -2077,9 +1928,6 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { - if platform == PlatformSora { - return s.listSoraSchedulableAccounts(ctx, groupID) - } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2176,53 +2024,6 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } -func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { - const useMixed = false - - var accounts []Account - var err error - if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } else if groupID != nil { - accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) - } else { - accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) - } - if err != nil { - slog.Debug("account_scheduling_list_failed", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "error", err) - return nil, useMixed, err - } - - filtered 
:= make([]Account, 0, len(accounts)) - for _, acc := range accounts { - if acc.Platform != PlatformSora { - continue - } - if !s.isSoraAccountSchedulable(&acc) { - continue - } - filtered = append(filtered, acc) - } - slog.Debug("account_scheduling_list_sora", - "group_id", derefGroupID(groupID), - "platform", PlatformSora, - "raw_count", len(accounts), - "filtered_count", len(filtered)) - for _, acc := range filtered { - slog.Debug("account_scheduling_account_detail", - "account_id", acc.ID, - "name", acc.Name, - "platform", acc.Platform, - "type", acc.Type, - "status", acc.Status, - "tls_fingerprint", acc.IsTLSFingerprintEnabled()) - } - return filtered, useMixed, nil -} - // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2247,33 +2048,10 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } -func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { - return s.soraUnschedulableReason(account) == "" -} - -func (s *GatewayService) soraUnschedulableReason(account *Account) string { - if account == nil { - return "account_nil" - } - if account.Status != StatusActive { - return fmt.Sprintf("status=%s", account.Status) - } - if !account.Schedulable { - return "schedulable=false" - } - if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { - return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) - } - return "" -} - func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } - if account.Platform == PlatformSora { - return s.isSoraAccountSchedulable(account) - } return account.IsSchedulable() } @@ -2281,12 +2059,6 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { 
return false } - if account.Platform == PlatformSora { - if !s.isSoraAccountSchedulable(account) { - return false - } - return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 - } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -2626,36 +2398,6 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } -// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。 -// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。 -func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) { - if s.cache == nil || len(accounts) == 0 { - return - } - // 快速检查:是否有任何账号开启了亲和 - hasAffinity := false - for _, acc := range accounts { - if acc.account.IsAffinityEnabled() { - hasAffinity = true - break - } - } - if !hasAffinity { - return - } - accountIDs := make([]int64, len(accounts)) - for i, acc := range accounts { - accountIDs[i] = acc.account.ID - } - countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL) - if err != nil { - return // 查询失败不影响调度,affinityCount 保持 0 - } - for i := range accounts { - accounts[i].affinityCount = countMap[accounts[i].account.ID] - } -} - // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { @@ -2696,64 +2438,6 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { return result } -// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合 -func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - minCount := accounts[0].affinityCount - for _, acc := range accounts[1:] { - if acc.affinityCount < minCount { - minCount = acc.affinityCount - } - } - result := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - if acc.affinityCount == minCount { - result = append(result, acc) - } - } - 
return result -} - -// classifyByAffinityZone 按亲和分区对候选账号进行分类。 -// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。 -// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。 -func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad { - if len(accounts) == 0 { - return accounts - } - // 快速检查:是否有任何账号配置了 affinity_base - hasZoneConfig := false - for _, acc := range accounts { - if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 { - hasZoneConfig = true - break - } - } - if !hasZoneConfig { - return accounts - } - - greens := make([]accountWithLoad, 0, len(accounts)) - yellows := make([]accountWithLoad, 0, len(accounts)) - for _, acc := range accounts { - zone := acc.account.GetAffinityZone(acc.affinityCount) - switch zone { - case AffinityZoneGreen: - greens = append(greens, acc) - case AffinityZoneYellow: - yellows = append(yellows, acc) - case AffinityZoneRed: - // 红区:移除,不参与调度 - } - } - if len(greens) > 0 { - return greens - } - return yellows -} - // selectByLRU 从集合中选择最久未用的账号 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { @@ -3514,9 +3198,6 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) - if platform == PlatformSora { - s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) - } return stats } @@ -3574,9 +3255,6 @@ func (s *GatewayService) diagnoseSelectionFailure( } if !s.isAccountSchedulableForSelection(acc) { detail := "generic_unschedulable" - if acc.Platform == PlatformSora { - detail = s.soraUnschedulableReason(acc) - } return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { @@ -3601,57 +3279,7 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } -func (s *GatewayService) 
logSoraSelectionFailureDetails( - ctx context.Context, - groupID *int64, - sessionHash string, - requestedModel string, - accounts []Account, - excludedIDs map[int64]struct{}, - allowMixedScheduling bool, -) { - const maxLines = 30 - logged := 0 - - for i := range accounts { - if logged >= maxLines { - break - } - acc := &accounts[i] - diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) - if diagnosis.Category == "eligible" { - continue - } - detail := diagnosis.Detail - if detail == "" { - detail = "-" - } - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - acc.ID, - acc.Platform, - diagnosis.Category, - detail, - ) - logged++ - } - if len(accounts) > maxLines { - logger.LegacyPrintf( - "service.gateway", - "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", - derefGroupID(groupID), - requestedModel, - shortSessionHash(sessionHash), - len(accounts), - logged, - ) - } -} - +// isPlatformFilteredForSelection reports whether acc is filtered out of selection for platform; a nil account is always filtered. func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3730,9 +3358,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } - if account.Platform == PlatformSora { - return s.isSoraModelSupportedByAccount(account, requestedModel) - } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok @@ -3749,143 +3374,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } -func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { - if account == nil { - return 
false - } - if strings.TrimSpace(requestedModel) == "" { - return true - } - - // 先走原始精确/通配符匹配。 - mapping := account.GetModelMapping() - if len(mapping) == 0 || account.IsModelSupported(requestedModel) { - return true - } - - aliases := buildSoraModelAliases(requestedModel) - if len(aliases) == 0 { - return false - } - - hasSoraSelector := false - for pattern := range mapping { - if !isSoraModelSelector(pattern) { - continue - } - hasSoraSelector = true - if matchPatternAnyAlias(pattern, aliases) { - return true - } - } - - // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), - // 此时不应误拦截 Sora 模型请求。 - if !hasSoraSelector { - return true - } - - return false -} - -func matchPatternAnyAlias(pattern string, aliases []string) bool { - normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) - if normalizedPattern == "" { - return false - } - for _, alias := range aliases { - if matchWildcard(normalizedPattern, alias) { - return true - } - } - return false -} - -func isSoraModelSelector(pattern string) bool { - p := strings.ToLower(strings.TrimSpace(pattern)) - if p == "" { - return false - } - - switch { - case strings.HasPrefix(p, "sora"), - strings.HasPrefix(p, "gpt-image"), - strings.HasPrefix(p, "prompt-enhance"), - strings.HasPrefix(p, "sy_"): - return true - } - - return p == "video" || p == "image" -} - -func buildSoraModelAliases(requestedModel string) []string { - modelID := strings.ToLower(strings.TrimSpace(requestedModel)) - if modelID == "" { - return nil - } - - aliases := make([]string, 0, 8) - addAlias := func(value string) { - v := strings.ToLower(strings.TrimSpace(value)) - if v == "" { - return - } - for _, existing := range aliases { - if existing == v { - return - } - } - aliases = append(aliases, v) - } - - addAlias(modelID) - cfg, ok := GetSoraModelConfig(modelID) - if ok { - addAlias(cfg.Model) - switch cfg.Type { - case "video": - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case "image": - 
addAlias("image") - addAlias("gpt-image") - case "prompt_enhance": - addAlias("prompt-enhance") - } - return aliases - } - - switch { - case strings.HasPrefix(modelID, "sora"): - addAlias("video") - addAlias("sora") - addAlias(soraVideoFamilyAlias(modelID)) - case strings.HasPrefix(modelID, "gpt-image"): - addAlias("image") - addAlias("gpt-image") - case strings.HasPrefix(modelID, "prompt-enhance"): - addAlias("prompt-enhance") - default: - return nil - } - - return aliases -} - -func soraVideoFamilyAlias(modelID string) string { - switch { - case strings.HasPrefix(modelID, "sora2pro-hd"): - return "sora2pro-hd" - case strings.HasPrefix(modelID, "sora2pro"): - return "sora2pro" - case strings.HasPrefix(modelID, "sora2"): - return "sora2" - default: - return "" - } -} - // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -4412,10 +3900,12 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) @@ -4824,7 +4314,6 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 - ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) // 触发上游接受回调(提前释放串行锁,不等流完成) if parsed.OnUpstreamAccepted != nil { @@ -5891,6 +5380,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", 
account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -6006,6 +5505,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -7005,7 +6513,6 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage sawTerminalEvent := false - skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -7067,25 +6574,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged - claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - skipAccountTTLOverride = true - } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged - claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - skipAccountTTLOverride = true - } } } // Cache TTL Override: 重写 SSE 事件中的 
cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { + if account.IsCacheTTLOverrideEnabled() { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -7517,13 +7016,8 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } - claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) - if claudeMaxOutcome.Simulated { - body = rewriteClaudeUsageJSONBytes(body, response.Usage) - } - // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { + if account.IsCacheTTLOverrideEnabled() { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -7901,12 +7395,10 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage // recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。 type recordUsageOpts struct { - // Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入) + // ParsedRequest(可选,仅 Claude 路径传入) ParsedRequest *ParsedRequest // EnableClaudePath 启用 Claude 路径特有逻辑: - // - Claude Max 缓存计费策略 - // - Sora 媒体类型分支(image/video/prompt) // - MediaType 字段写入使用日志 EnableClaudePath bool @@ -7998,8 +7490,6 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: -// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 -// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -8017,21 +7507,9 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage 
result.Usage.InputTokens = 0 } - // Claude Max cache billing policy(仅 Claude 路径启用) - cacheTTLOverridden := false - simulatedClaudeMax := false - if opts.EnableClaudePath { - var apiKeyGroup *Group - if apiKey != nil { - apiKeyGroup = apiKey.Group - } - claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID) - simulatedClaudeMax = claudeMaxOutcome.Simulated || - (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage)) - } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 - if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { + cacheTTLOverridden := false + if account.IsCacheTTLOverrideEnabled() { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -8113,16 +7591,6 @@ func (s *GatewayService) calculateRecordUsageCost( multiplier float64, opts *recordUsageOpts, ) *CostBreakdown { - // Sora 媒体类型分支(仅 Claude 路径启用) - if opts.EnableClaudePath { - if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { - return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) - } - if result.MediaType == MediaTypePrompt { - return &CostBreakdown{} - } - } - // 图片生成计费 if result.ImageCount > 0 { return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) @@ -8132,28 +7600,6 @@ func (s *GatewayService) calculateRecordUsageCost( return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } -// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 -func (s *GatewayService) calculateSoraMediaCost( - result *ForwardResult, - apiKey *APIKey, - billingModel string, - multiplier float64, -) *CostBreakdown { - var soraConfig *SoraPriceConfig - if apiKey.Group != nil { - soraConfig = &SoraPriceConfig{ - ImagePrice360: apiKey.Group.SoraImagePrice360, - 
ImagePrice540: apiKey.Group.SoraImagePrice540, - VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, - VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, - } - } - if result.MediaType == MediaTypeImage { - return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) - } - return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) -} - // resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -8176,7 +7622,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { - if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -8191,6 +7637,7 @@ func (s *GatewayService) calculateImageCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, + Resolved: resolved, }) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) @@ -8233,7 +7680,7 @@ func (s *GatewayService) calculateTokenCost( var err error // 优先尝试渠道定价 → CalculateCostUnified - if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { + if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { gid := apiKey.Group.ID cost, err = s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, @@ -8243,6 +7690,7 @@ func (s *GatewayService) calculateTokenCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, + Resolved: resolved, }) } else if opts.LongContextThreshold > 0 { // 长上下文双倍计费(如 Gemini 200K 阈值) @@ -8330,13 +7778,7 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 
-// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { - isSoraMedia := opts.EnableClaudePath && - (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) - if isSoraMedia { - return nil - } var mode string switch { case cost != nil && cost.BillingMode != "": @@ -8350,9 +7792,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost } func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { - if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { - return &result.MediaType - } return nil } @@ -8559,10 +7998,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL + // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() + if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { + proxyURL = account.Proxy.URL() + } } // 发送请求 @@ -8841,6 +8282,16 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } + } else if account.IsCustomBaseURLEnabled() { + customURL := account.GetCustomBaseURL() + if customURL == "" { + return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) + } + validatedURL, err := s.validateUpstreamBaseURL(customURL) + if err != nil { + return nil, err + } + targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8946,6 +8397,15 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } + // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 + if sessionHeader := getHeaderRaw(req.Header, 
"X-Claude-Code-Session-Id"); sessionHeader != "" { + if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { + if parsed := ParseMetadataUserID(uid); parsed != nil { + setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) + } + } + } + if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8967,6 +8427,19 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } +// buildCustomRelayURL 构建自定义中继转发 URL +// 在 path 后附加 beta=true 和可选的 proxy 查询参数 +func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { + u := strings.TrimRight(baseURL, "/") + path + "?beta=true" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL := account.Proxy.URL() + if proxyURL != "" { + u += "&proxy=" + url.QueryEscape(proxyURL) + } + } + return u +} + func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)