diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index be267332..abe2a1e5 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e908eb9e..cf877182 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -80,7 +80,7 @@ func (h *GatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) // Claude Code only restriction: diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d200c17c..ff63bc7f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -121,7 +121,7 @@ func (h *GatewayHandler) GeminiV1BetaGetModel(c *gin.Context) { googleError(c, http.StatusBadGateway, err.Error()) return } - if shouldFallbackGeminiModel(modelName, res) { + if shouldFallbackGeminiModels(res) { c.JSON(http.StatusOK, gemini.FallbackModel(modelName)) return } @@ -184,7 +184,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { setOpsRequestContext(c, modelName, stream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName) reqModel := modelName // 保存映射前的原始模型名 if channelMapping.Mapped { @@ -682,16 +682,6 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool { return false } -func shouldFallbackGeminiModel(modelName string, res *service.UpstreamHTTPResult) bool { - if shouldFallbackGeminiModels(res) { - return true - } - if res == nil || res.StatusCode != http.StatusNotFound { - return false - } - return gemini.HasFallbackModel(modelName) -} - // extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。 // 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。 // diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 991cbb91..ada401c9 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -79,7 +79,7 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMapping, _ := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel) if h.errorPassthroughService != nil { diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 5319b55d..2b081617 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -47,13 +47,6 @@ func resolveOpenAIForwardDefaultMappedModel(apiKey *service.APIKey, fallbackMode return strings.TrimSpace(apiKey.Group.DefaultMappedModel) } -func resolveOpenAIMessagesDispatchMappedModel(apiKey *service.APIKey, requestedModel string) string { - if apiKey == nil || apiKey.Group == nil { - return "" - } - return strings.TrimSpace(apiKey.Group.ResolveMessagesDispatchModel(requestedModel)) -} - // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler func NewOpenAIGatewayHandler( gatewayService *service.OpenAIGatewayService, @@ -557,8 +550,6 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { return } reqModel := modelResult.String() - routingModel := service.NormalizeOpenAICompatRequestedModel(reqModel) - preferredMappedModel := resolveOpenAIMessagesDispatchMappedModel(apiKey, reqModel) reqStream := gjson.GetBytes(body, "stream").Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) @@ -617,20 +608,17 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { failedAccountIDs := make(map[int64]struct{}) sameAccountRetryCount := make(map[int64]int) var lastFailoverErr *service.UpstreamFailoverError - effectiveMappedModel := preferredMappedModel for { - currentRoutingModel := routingModel - if effectiveMappedModel != "" { - currentRoutingModel = effectiveMappedModel - } + // 清除上一次迭代的降级模型标记,避免残留影响本次迭代 + c.Set("openai_messages_fallback_model", "") reqLog.Debug("openai_messages.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( c.Request.Context(), apiKey.GroupID, "", // no previous_response_id sessionHash, - currentRoutingModel, + reqModel, failedAccountIDs, service.OpenAIUpstreamTransportAny, ) @@ -639,7 +627,29 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs)), ) + // 首次调度失败 + 有默认映射模型 → 用默认模型重试 if len(failedAccountIDs) == 0 { + defaultModel := "" + if apiKey.Group != nil { + defaultModel = apiKey.Group.DefaultMappedModel + } + if defaultModel != "" && defaultModel != reqModel { + reqLog.Info("openai_messages.fallback_to_default_model", + zap.String("default_mapped_model", defaultModel), + ) + selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + "", + sessionHash, + defaultModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) + if err == nil && selection != nil { + c.Set("openai_messages_fallback_model", defaultModel) + } + } if err != nil { h.anthropicStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return @@ -671,7 +681,9 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) forwardStart := time.Now() - defaultMappedModel := strings.TrimSpace(effectiveMappedModel) + // Forward 层需要始终拿到 group 默认映射模型,这样未命中账号级映射的 + // Claude 兼容模型才不会在后续 Codex 规范化中意外退化到 gpt-5.1。 + defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_messages_fallback_model")) // 应用渠道模型映射到请求体 forwardBody := body if channelMappingMsg.Mapped { @@ -1106,7 +1118,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) - // 解析渠道级模型映射 + // 解析渠道级模型映射 + 限制检查 channelMappingWS, _ := h.gatewayService.ResolveChannelMappingAndRestrict(ctx, apiKey.GroupID, reqModel) var currentUserRelease func() diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8b0bdc2a..33ab38f2 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -12,7 +12,6 @@ import ( "log/slog" mathrand "math/rand" "net/http" - "net/url" "os" "path/filepath" "regexp" @@ -42,7 +41,8 @@ import ( const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" - stickySessionTTL = time.Hour // 粘性会话TTL + stickySessionTTL = time.Hour // 粘性会话TTL + ClientAffinityTTL = 24 * time.Hour // 客户端亲和TTL defaultMaxLineSize = 500 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual @@ -60,14 +60,28 @@ const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) +// MediaType 媒体类型常量 +const ( + MediaTypeImage = "image" + MediaTypeVideo = "video" + MediaTypePrompt = "prompt" +) + +const ( + claudeMaxMessageOverheadTokens = 3 + claudeMaxBlockOverheadTokens = 1 + claudeMaxUnknownContentTokens = 4 +) + // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} // accountWithLoad 账号与负载信息的组合,用于负载感知调度 type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo + account *Account + loadInfo *AccountLoadInfo + affinityCount int64 // 亲和客户端数量(反向索引),越少越优先 } var ForceCacheBillingContextKey = forceCacheBillingKeyType{} @@ -331,6 +345,10 @@ var ( sseDataRe = regexp.MustCompile(`^data:\s*`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) + // clientIDFromMetadataRegex 从 metadata.user_id 中提取客户端 ID(64位 hex) + // 格式: user_{64位hex}_account_... + clientIDFromMetadataRegex = regexp.MustCompile(`^user_([a-f0-9]{64})_account_`) + // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 @@ -348,6 +366,12 @@ var ErrNoAvailableAccounts = errors.New("no available accounts") // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") +// ErrAffinityNoSwitch 表示亲和账号不可用且不允许切换到其他账号 +var ErrAffinityNoSwitch = errors.New("affinity account unavailable and switching is disabled") + +// ErrAffinityLimitExceeded 表示亲和客户端限制已达上限 +var ErrAffinityLimitExceeded = errors.New("affinity client limit exceeded") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -369,8 +393,6 @@ var allowedHeaders = map[string]bool{ "user-agent": true, "content-type": true, "accept-encoding": true, - "x-claude-code-session-id": true, - "x-client-request-id": true, } // GatewayCache 定义网关服务的缓存操作接口。 @@ -391,6 +413,39 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error + + // GetAffinityAccounts 获取亲和账号列表(按最近使用降序),同时清理过期成员 + GetAffinityAccounts(ctx context.Context, groupID int64, userID int64, clientID string, ttl time.Duration) ([]int64, error) + // UpdateAffinity 添加/更新亲和关系(更新 score 为当前时间戳,刷新 key TTL) + UpdateAffinity(ctx context.Context, groupID int64, userID int64, clientID string, accountID int64, ttl time.Duration) error + // GetAccountAffinityCountBatch 批量获取账号的亲和成员数量(惰性清理过期成员) + GetAccountAffinityCountBatch(ctx context.Context, groupID int64, accountIDs []int64, ttl time.Duration) (map[int64]int64, error) + // GetAccountAffinityClientsBatch 批量获取每个账号跨所有分组的亲和成员列表(去重) + // accountGroups: map[accountID][]groupID + // 返回值成员格式为 {userID}/{clientID} + GetAccountAffinityClientsBatch(ctx context.Context, accountGroups map[int64][]int64, ttl time.Duration) (map[int64][]string, error) + // GetAccountAffinityClientsWithScores 获取单个账号跨所有分组的亲和客户端列表(含最后活跃时间) + GetAccountAffinityClientsWithScores(ctx context.Context, accountID int64, groupIDs []int64, ttl time.Duration) ([]AffinityClient, error) + // ClearAccountAffinity 清除指定账号在所有分组的亲和记录(正向+反向索引) + // 用于账号关闭亲和时立即清理旧绑定 + ClearAccountAffinity(ctx context.Context, accountID int64, groupIDs []int64) error + // GetAffinityMultiCount 获取账号的多维度亲和计数 + // 返回: uniqueUsers, uniqueClients, perUserClients + GetAffinityMultiCount(ctx context.Context, groupID int64, accountID int64, targetUserID int64, ttl time.Duration) (users, clients, perUser int64, err error) +} + +// AffinityClient 亲和客户端信息(含用户 ID 和最后活跃时间) +type AffinityClient struct { + UserID int64 `json:"user_id"` + ClientID string `json:"client_id"` + LastActive time.Time `json:"last_active"` +} + +// SortAffinityClients 按最后活跃时间降序排序 +func SortAffinityClients(clients []AffinityClient) { + sort.Slice(clients, func(i, j int) bool { + return clients[i].LastActive.After(clients[j].LastActive) + }) } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -461,6 +516,20 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { return false } +// extractClientIDFromMetadata 从 metadata.user_id 中提取客户端 ID(64位 hex)。 +// 格式: user_{64位hex}_account_..._session_... +// 返回空字符串表示无法提取(非 Claude Code/Console 客户端)。 +func extractClientIDFromMetadata(metadataUserID string) string { + if metadataUserID == "" { + return "" + } + matches := clientIDFromMetadataRegex.FindStringSubmatch(metadataUserID) + if matches == nil { + return "" + } + return matches[1] +} + type AccountWaitPlan struct { AccountID int64 MaxConcurrency int @@ -504,6 +573,9 @@ type ForwardResult struct { ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. @@ -1162,6 +1234,11 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) @@ -1180,32 +1257,15 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context platform = PlatformAnthropic } - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { - account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err != nil { - return nil, err - } - return s.hydrateSelectedAccount(ctx, account) + return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 - account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) - if err != nil { - return nil, err - } - return s.hydrateSelectedAccount(ctx, account) + return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. @@ -1213,6 +1273,11 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context // metadataUserID: 用于客户端亲和调度,从中提取客户端 ID // sub2apiUserID: 系统用户 ID,用于二维亲和调度 func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string, sub2apiUserID int64) (*AccountSelectionResult, error) { + // 渠道定价限制预检查(requested / channel_mapped 基准) + if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { + return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) + } + // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { @@ -1233,15 +1298,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) - // Claude Code 限制可能已将 groupID 解析为 fallback group, - // 渠道限制预检查必须使用解析后的分组。 - if s.checkChannelPricingRestriction(ctx, groupID, requestedModel) { - slog.Warn("channel pricing restriction blocked request", - "group_id", derefGroupID(groupID), - "model", requestedModel) - return nil, fmt.Errorf("%w supporting model: %s (channel pricing restriction)", ErrNoAvailableAccounts, requestedModel) - } - var stickyAccountID int64 if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { stickyAccountID = prefetch @@ -1251,6 +1307,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } + // 提取客户端 ID(用于客户端亲和调度) + affinityClientID := extractClientIDFromMetadata(metadataUserID) + affinityUserID := sub2apiUserID + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { @@ -1272,6 +1332,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } + if shouldFilterAccountWithoutClientID(account, affinityClientID) { + localExcluded[account.ID] = struct{}{} + continue + } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { @@ -1281,7 +1345,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } // 对于等待计划的情况,也需要先检查会话限制 @@ -1293,20 +1361,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } } @@ -1323,12 +1397,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if err != nil { return nil, err } + accounts = filterAccountsWithoutClientID(accounts, affinityClientID) if len(accounts) == 0 { return nil, ErrNoAvailableAccounts } ctx = s.withWindowCostPrefetch(ctx, accounts) ctx = s.withRPMPrefetch(ctx, accounts) + // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) + accountByID := make(map[int64]*Account, len(accounts)) + for i := range accounts { + accountByID[accounts[i].ID] = &accounts[i] + } isExcluded := func(accountID int64) bool { if excludedIDs == nil { return false @@ -1336,12 +1416,19 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _, excluded := excludedIDs[accountID] return excluded } - - // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) - accountByID := make(map[int64]*Account, len(accounts)) - for i := range accounts { - accountByID[accounts[i].ID] = &accounts[i] - } + affinityFlow := newGatewayAffinityFlow( + s, + ctx, + groupID, + sessionHash, + requestedModel, + affinityClientID, + affinityUserID, + platform, + useMixed, + accountByID, + isExcluded, + ) // 获取模型路由配置(仅 anthropic 平台) var routingAccountIDs []int64 @@ -1430,76 +1517,53 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - var stickyCacheMissReason string - - gatePass := s.isAccountSchedulableForSelection(stickyAccount) && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForQuota(stickyAccount) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && - rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true) - - if rpmPass { // 粘性会话窗口费用+RPM 检查 + s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 - stickyCacheMissReason = "session_limit" // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } - return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: stickyAccount, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } - if stickyCacheMissReason == "" { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) - if waitingCount < cfg.StickySessionMaxWaiting { - // 会话数量限制检查(等待计划也需要占用会话配额) - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { - stickyCacheMissReason = "session_limit" - // 会话限制已满,继续到负载感知选择 - } else { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil - } + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) + if waitingCount < cfg.StickySessionMaxWaiting { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + // 会话限制已满,继续到负载感知选择 } else { - stickyCacheMissReason = "wait_queue_full" + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 - } else if !gatePass { - stickyCacheMissReason = "gate_check" - } else { - stickyCacheMissReason = "rpm_red" - } - - // 记录粘性缓存未命中的结构化日志 - if stickyCacheMissReason != "" { - baseRPM := stickyAccount.GetBaseRPM() - var currentRPM int - if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok { - currentRPM = count - } - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d", - stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM) } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0", - stickyAccountID, shortSessionHash(sessionHash)) } } } @@ -1527,7 +1591,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if len(routingAvailable) > 0 { - // 排序:优先级 > 负载率 > 最后使用时间 + // 批量获取亲和客户端数量 + s.populateAffinityCounts(ctx, routingAvailable, derefGroupID(groupID)) + + // 排序:优先级 > 负载率 > 亲和客户端数 > 最后使用时间 sort.SliceStable(routingAvailable, func(i, j int) bool { a, b := routingAvailable[i], routingAvailable[j] if a.account.Priority != b.account.Priority { @@ -1536,6 +1603,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return a.loadInfo.LoadRate < b.loadInfo.LoadRate } + if a.affinityCount != b.affinityCount { + return a.affinityCount < b.affinityCount + } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true @@ -1561,10 +1631,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && item.account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, item.account.ID, ClientAffinityTTL) + } if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: item.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } @@ -1577,12 +1654,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } - return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{ - AccountID: item.account.ID, - MaxConcurrency: item.account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: item.account, + WaitPlan: &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } @@ -1591,14 +1671,27 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + // ============ Layer 1.3: 用户亲和预处理(pinned_users 自动注入) ============ + affinityFlow.preprocessPinnedUsers(accounts) + + // ============ Layer 1.4: 客户端亲和调度(优先于粘性会话) ============ + affinityHit := false + if affinityResult, hit, err := affinityFlow.trySelectAffinityAccount(); err != nil { + return nil, err + } else { + affinityHit = hit + if affinityResult != nil { + return affinityResult, nil + } + } + + // ============ Layer 1.5: 粘性会话(仅在无模型路由配置 且 亲和未命中时生效) ============ + if !affinityHit && len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { accountID := stickyAccountID if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 - // Check if the account needs sticky session cleanup clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) @@ -1614,31 +1707,32 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - // Session count limit check if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - if s.cache != nil { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) - } - return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { // 会话数量限制检查(等待计划也需要占用会话配额) - // Session count limit check (wait plan also requires session quota) if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 - // Session limit full, continue to Layer 2 } else { - return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }) + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } } } @@ -1697,9 +1791,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil { - return nil, legacyErr - } else if ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && result.Account != nil && result.Account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, result.Account.ID, ClientAffinityTTL) + } return result, nil } } else { @@ -1717,13 +1812,37 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 分层过滤选择:优先级 → 负载率 → LRU + // 批量获取亲和客户端数量(用于均衡分配新客户端) + s.populateAffinityCounts(ctx, available, derefGroupID(groupID)) + + // 分层过滤选择:优先级 → 亲和三区 → 负载率 → 亲和客户端数 → LRU for len(available) > 0 { // 1. 取优先级最小的集合 candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 + // 2. 按亲和三区过滤:绿区优先 → 黄区降级 → 红区移除(在同优先级内) + candidates = classifyByAffinityZone(candidates) + if len(candidates) == 0 { + // 当前优先级组全部在红区,移除后回退到下一优先级组 + minPri := available[0].account.Priority + for _, a := range available[1:] { + if a.account.Priority < minPri { + minPri = a.account.Priority + } + } + newAvailable := make([]accountWithLoad, 0, len(available)) + for _, a := range available { + if a.account.Priority != minPri { + newAvailable = append(newAvailable, a) + } + } + available = newAvailable + continue + } + // 3. 取负载率最低的集合 candidates = filterByMinLoadRate(candidates) - // 3. LRU 选择最久未用的账号 + // 3. 取亲和客户端数最少的集合 + candidates = filterByMinAffinityCount(candidates) + // 4. LRU 选择最久未用的账号 selected := selectByLRU(candidates, preferOAuth) if selected == nil { break @@ -1738,7 +1857,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } - return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil) + // 更新亲和关系 + if affinityClientID != "" && affinityUserID > 0 && s.cache != nil && selected.account.IsAffinityEnabled() { + _ = s.cache.UpdateAffinity(ctx, derefGroupID(groupID), affinityUserID, affinityClientID, selected.account.ID, ClientAffinityTTL) + } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } } @@ -1761,17 +1888,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } - return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }) + return &AccountSelectionResult{ + Account: acc, + WaitPlan: &AccountWaitPlan{ + AccountID: acc.ID, + MaxConcurrency: acc.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } return nil, ErrNoAvailableAccounts } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1786,15 +1916,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } - selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil) - if err != nil { - return nil, false, err - } - return selection, true, nil + return &AccountSelectionResult{ + Account: acc, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, true } } - return nil, false, nil + return nil, false } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { @@ -1939,6 +2069,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -2035,6 +2168,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -2059,10 +2239,33 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { if account == nil { return false } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } return account.IsSchedulable() } @@ -2070,6 +2273,12 @@ func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Conte if account == nil { return false } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } return account.IsSchedulableForModelWithContext(ctx, requestedModel) } @@ -2409,31 +2618,34 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in return s.accountRepo.GetByID(ctx, accountID) } -func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) { - if account == nil || s.schedulerSnapshot == nil { - return account, nil +// populateAffinityCounts 批量获取账号的亲和客户端数量并填入 accountWithLoad 切片。 +// 仅当存在开启了客户端亲和的账号时才查询 Redis,否则跳过。 +func (s *GatewayService) populateAffinityCounts(ctx context.Context, accounts []accountWithLoad, groupID int64) { + if s.cache == nil || len(accounts) == 0 { + return } - hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID) + // 快速检查:是否有任何账号开启了亲和 + hasAffinity := false + for _, acc := range accounts { + if acc.account.IsAffinityEnabled() { + hasAffinity = true + break + } + } + if !hasAffinity { + return + } + accountIDs := make([]int64, len(accounts)) + for i, acc := range accounts { + accountIDs[i] = acc.account.ID + } + countMap, err := s.cache.GetAccountAffinityCountBatch(ctx, groupID, accountIDs, ClientAffinityTTL) if err != nil { - return nil, err + return // 查询失败不影响调度,affinityCount 保持 0 } - if hydrated == nil { - return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID) + for i := range accounts { + accounts[i].affinityCount = countMap[accounts[i].account.ID] } - return hydrated, nil -} - -func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) { - hydrated, err := s.hydrateSelectedAccount(ctx, account) - if err != nil { - return nil, err - } - return &AccountSelectionResult{ - Account: hydrated, - Acquired: acquired, - ReleaseFunc: release, - WaitPlan: waitPlan, - }, nil } // filterByMinPriority 过滤出优先级最小的账号集合 @@ -2476,6 +2688,64 @@ func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { return result } +// filterByMinAffinityCount 过滤出亲和客户端数最少的账号集合 +func filterByMinAffinityCount(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + minCount := accounts[0].affinityCount + for _, acc := range accounts[1:] { + if acc.affinityCount < minCount { + minCount = acc.affinityCount + } + } + result := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + if acc.affinityCount == minCount { + result = append(result, acc) + } + } + return result +} + +// classifyByAffinityZone 按亲和分区对候选账号进行分类。 +// 返回值:仅绿区账号(有绿区时),否则返回黄区账号。红区账号被移除。 +// 如果没有任何账号开启了亲和三区配置(即 affinity_base <= 0),则原样返回所有账号。 +func classifyByAffinityZone(accounts []accountWithLoad) []accountWithLoad { + if len(accounts) == 0 { + return accounts + } + // 快速检查:是否有任何账号配置了 affinity_base + hasZoneConfig := false + for _, acc := range accounts { + if acc.account.IsAffinityEnabled() && acc.account.GetAffinityBase() > 0 { + hasZoneConfig = true + break + } + } + if !hasZoneConfig { + return accounts + } + + greens := make([]accountWithLoad, 0, len(accounts)) + yellows := make([]accountWithLoad, 0, len(accounts)) + for _, acc := range accounts { + zone := acc.account.GetAffinityZone(acc.affinityCount) + switch zone { + case AffinityZoneGreen: + greens = append(greens, acc) + case AffinityZoneYellow: + yellows = append(yellows, acc) + case AffinityZoneRed: + // 红区:移除,不参与调度 + } + } + if len(greens) > 0 { + return greens + } + return yellows +} + // selectByLRU 从集合中选择最久未用的账号 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { @@ -2711,12 +2981,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) - // require_privacy_set: 获取分组信息 - var schedGroup *Group - if groupID != nil && s.groupRepo != nil { - schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) - } - var accounts []Account accountsLoaded := false @@ -2788,12 +3052,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2885,8 +3143,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查, - // 因为粘性会话优先保持连接一致性,且 upstream 计费基准极少使用。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -2899,12 +3155,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } @@ -2971,12 +3221,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) - // require_privacy_set: 获取分组信息 - var schedGroup *Group - if groupID != nil && s.groupRepo != nil { - schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID) - } - var accounts []Account accountsLoaded := false @@ -3044,12 +3288,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3143,7 +3381,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g ctx = s.withRPMPrefetch(ctx, accounts) // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) - // needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。 needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID) var selected *Account for i := range accounts { @@ -3156,12 +3393,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForSelection(acc) { continue } - // require_privacy_set: 跳过 privacy 未设置的账号并标记异常 - if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() { - _ = s.accountRepo.SetError(ctx, acc.ID, - fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name)) - continue - } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue @@ -3273,6 +3504,9 @@ func (s *GatewayService) logDetailedSelectionFailure( stats.SampleMappingIDs, stats.SampleRateLimitIDs, ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } return stats } @@ -3329,7 +3563,11 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "excluded"} } if !s.isAccountSchedulableForSelection(acc) { - return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"} + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} } if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { return selectionFailureDiagnosis{ @@ -3353,6 +3591,57 @@ func (s *GatewayService) diagnoseSelectionFailure( return selectionFailureDiagnosis{Category: "eligible"} } +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == "eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { if acc == nil { return true @@ -3431,10 +3720,17 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } if account.IsBedrock() { _, ok := ResolveBedrockModelID(account, requestedModel) return ok } + // OpenAI 透传模式:仅替换认证,允许所有模型 + if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() { + return true + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -3443,6 +3739,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -3719,86 +4152,6 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { return result } -// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages, -// system 字段仅保留 Claude Code 标识提示词。 -// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词 -// 无法通过检测,因为后续内容仍为非 Claude Code 格式。 -// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。 -func rewriteSystemForNonClaudeCode(body []byte, system any) []byte { - system = normalizeSystemParam(system) - - // 1. 提取原始 system prompt 文本 - var originalSystemText string - switch v := system.(type) { - case string: - originalSystemText = strings.TrimSpace(v) - case []any: - var parts []string - for _, item := range v { - if m, ok := item.(map[string]any); ok { - if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" { - parts = append(parts, text) - } - } - } - originalSystemText = strings.Join(parts, "\n\n") - } - - // 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致) - // 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。 - // 使用 string 格式会被 Anthropic 检测为第三方应用。 - claudeCodeSystemBlock := []map[string]any{ - { - "type": "text", - "text": claudeCodeSystemPrompt, - "cache_control": map[string]string{"type": "ephemeral"}, - }, - } - out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock) - if !ok { - logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt") - return body - } - - // 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头 - // 模型仍通过 messages 接收完整指令,保留客户端功能 - ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt) - if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) { - instrMsg, err1 := json.Marshal(map[string]any{ - "role": "user", - "content": []map[string]any{ - {"type": "text", "text": "[System Instructions]\n" + originalSystemText}, - }, - }) - ackMsg, err2 := json.Marshal(map[string]any{ - "role": "assistant", - "content": []map[string]any{ - {"type": "text", "text": "Understood. I will follow these instructions."}, - }, - }) - if err1 != nil || err2 != nil { - logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection") - return out - } - - // 重建 messages 数组:[instruction, ack, ...originalMessages] - items := [][]byte{instrMsg, ackMsg} - messagesResult := gjson.GetBytes(out, "messages") - if messagesResult.IsArray() { - messagesResult.ForEach(func(_, msg gjson.Result) bool { - items = append(items, []byte(msg.Raw)) - return true - }) - } - - if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk { - out = next - } - } - - return out -} - type cacheControlPath struct { path string log string @@ -3960,7 +4313,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest. // Always overwrite the cache to prevent stale values from a previous retry with a different account. if account.Platform == PlatformAnthropic && c != nil { - policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model) + policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account) if policy.blockErr != nil { return nil, policy.blockErr } @@ -3990,24 +4343,19 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { - // 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages + // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 - systemRewritten := false if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { - body = rewriteSystemForNonClaudeCode(body, parsed.System) - systemRewritten = true + body = injectClaudeCodePrompt(body, parsed.System) } - // system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为); - // 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。 - // 两种情况下 enforceCacheControlLimit 都会兜底处理上限。 - normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten} + normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { // metadata 透传开启时跳过 metadata 注入 - _, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx) + _, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx) if !mimicMPT { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true @@ -4054,12 +4402,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, err } - // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) + // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { - proxyURL = account.Proxy.URL() - } + proxyURL = account.Proxy.URL() } // 解析 TLS 指纹 profile(同一请求生命周期内不变,避免重试循环中重复解析) @@ -4468,6 +4814,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 处理正常响应 + ctx = withClaudeMaxResponseRewriteContext(ctx, c, parsed) // 触发上游接受回调(提前释放串行锁,不等流完成) if parsed.OnUpstreamAccepted != nil { @@ -5534,16 +5881,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } targetURL = validatedURL + "/v1/messages?beta=true" } - } else if account.IsCustomBaseURLEnabled() { - customURL := account.GetCustomBaseURL() - if customURL == "" { - return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) - } - validatedURL, err := s.validateUpstreamBaseURL(customURL) - if err != nil { - return nil, err - } - targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages", account) } clientHeaders := http.Header{} @@ -5553,9 +5890,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // OAuth账号:应用统一指纹和metadata重写(受设置开关控制) var fingerprint *Fingerprint - enableFP, enableMPT, enableCCH := true, false, false + enableFP, enableMPT := true, false if s.settingService != nil { - enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx) } if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) @@ -5582,15 +5919,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 billing header cc_version 与实际发送的 User-Agent 版本 - if fingerprint != nil { - body = syncBillingHeaderVersion(body, fingerprint.UserAgent) - } - // CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后) - if enableCCH { - body = signBillingHeaderCCH(body) - } - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -5631,8 +5959,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } // Build effective drop set: merge static defaults with dynamic beta policy filter rules - policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID) + policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account) effectiveDropSet := mergeDropSets(policyFilterSet) + effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode) // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { @@ -5643,16 +5972,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := getHeaderRaw(req.Header, "anthropic-beta") - // Claude Code OAuth credentials are scoped to Claude Code. - // Non-haiku models MUST include claude-code beta for Anthropic to recognize - // this as a legitimate Claude Code request; without it, the request is - // rejected as third-party ("out of extra usage"). - // Haiku models are exempt from third-party detection and don't need it. + // Match real Claude CLI traffic (per mitmproxy reports): + // messages requests typically use only oauth + interleaved-thinking. + // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - if !strings.Contains(strings.ToLower(modelID), "haiku") { - requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking} - } - setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet)) + setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta") @@ -5672,15 +5996,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - // === DEBUG: 打印上游转发请求(headers + body 摘要),与 CLIENT_ORIGINAL 对比 === s.debugLogGatewaySnapshot("UPSTREAM_FORWARD", req.Header, body, map[string]string{ "url": req.URL.String(), @@ -5875,7 +6190,7 @@ type betaPolicyResult struct { } // evaluateBetaPolicy loads settings once and evaluates all rules against the given request. -func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult { +func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult { if s.settingService == nil { return betaPolicyResult{} } @@ -5890,11 +6205,10 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } - effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) - switch effectiveAction { + switch rule.Action { case BetaPolicyActionBlock: if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) { - msg := effectiveErrMsg + msg := rule.ErrorMessage if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -5936,7 +6250,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet" // In the /v1/messages path, Forward() evaluates the policy first and caches the result; // buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this // evaluates on demand (one DB call). -func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} { +func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} { if c != nil { if v, ok := c.Get(betaPolicyFilterSetKey); ok { if fs, ok := v.(map[string]struct{}); ok { @@ -5944,7 +6258,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont } } } - return s.evaluateBetaPolicy(ctx, "", account, model).filterSet + return s.evaluateBetaPolicy(ctx, "", account).filterSet } // betaPolicyScopeMatches checks whether a rule's scope matches the current account type. @@ -5963,33 +6277,6 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool { } } -// matchModelWhitelist checks if a model matches any pattern in the whitelist. -// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching. -func matchModelWhitelist(model string, whitelist []string) bool { - for _, pattern := range whitelist { - if matchModelPattern(pattern, model) { - return true - } - } - return false -} - -// resolveRuleAction determines the effective action and error message for a rule given the request model. -// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally. -// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others. -func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) { - if len(rule.ModelWhitelist) == 0 { - return rule.Action, rule.ErrorMessage - } - if matchModelWhitelist(model, rule.ModelWhitelist) { - return rule.Action, rule.ErrorMessage - } - if rule.FallbackAction != "" { - return rule.FallbackAction, rule.FallbackErrorMessage - } - return BetaPolicyActionPass, "" // default fallback: pass (fail-open) -} - // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. func droppedBetaSet(extra ...string) map[string]struct{} { m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) @@ -6036,7 +6323,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( modelID string, ) ([]string, error) { // 1. 对原始 header 中的 beta token 做 block 检查(快速失败) - policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID) + policy := s.evaluateBetaPolicy(ctx, betaHeader, account) if policy.blockErr != nil { return nil, policy.blockErr } @@ -6048,7 +6335,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token, // 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 → // 如果不做此检查,block 规则会被绕过。 - if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil { + if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil { return nil, blockErr } @@ -6057,7 +6344,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest( // checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。 // 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。 -func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError { +func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError { if s.settingService == nil || len(tokens) == 0 { return nil } @@ -6069,15 +6356,14 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke isBedrock := account.IsBedrock() tokenSet := buildBetaTokenSet(tokens) for _, rule := range settings.Rules { - effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model) - if effectiveAction != BetaPolicyActionBlock { + if rule.Action != BetaPolicyActionBlock { continue } if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) { continue } if _, present := tokenSet[rule.BetaToken]; present { - msg := effectiveErrMsg + msg := rule.ErrorMessage if msg == "" { msg = "beta feature " + rule.BetaToken + " is not allowed" } @@ -6709,6 +6995,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage sawTerminalEvent := false + skipAccountTTLOverride := false pendingEventLines := make([]string, 0, 4) @@ -6770,17 +7057,25 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { eventChanged = reconcileCachedTokens(u) || eventChanged + claudeMaxOutcome := applyClaudeMaxSimulationToUsageJSONMap(ctx, u, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + skipAccountTTLOverride = true + } } } // Cache TTL Override: 重写 SSE 事件中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !skipAccountTTLOverride { overrideTarget := account.GetCacheTTLOverrideTarget() if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { @@ -7212,8 +7507,13 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } } + claudeMaxOutcome := applyClaudeMaxSimulationToUsage(ctx, &response.Usage, originalModel, account.ID) + if claudeMaxOutcome.Simulated { + body = rewriteClaudeUsageJSONBytes(body, response.Usage) + } + // Cache TTL Override: 重写 non-streaming 响应中的 cache_creation 分类 - if account.IsCacheTTLOverrideEnabled() { + if account.IsCacheTTLOverrideEnabled() && !claudeMaxOutcome.Simulated { overrideTarget := account.GetCacheTTLOverrideTarget() if applyCacheTTLOverride(&response.Usage, overrideTarget) { // 同步更新 body JSON 中的嵌套 cache_creation 对象 @@ -7279,6 +7579,7 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult + ParsedRequest *ParsedRequest APIKey *APIKey User *User Account *Account @@ -7437,6 +7738,9 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage cmd.CacheCreationTokens = usageLog.CacheCreationTokens cmd.CacheReadTokens = usageLog.CacheReadTokens cmd.ImageCount = usageLog.ImageCount + if usageLog.MediaType != nil { + cmd.MediaType = *usageLog.MediaType + } if usageLog.ServiceTier != nil { cmd.ServiceTier = *usageLog.ServiceTier } @@ -7592,6 +7896,8 @@ type recordUsageOpts struct { // EnableClaudePath 启用 Claude 路径特有逻辑: // - Claude Max 缓存计费策略 + // - Sora 媒体类型分支(image/video/prompt) + // - MediaType 字段写入使用日志 EnableClaudePath bool // 长上下文计费(仅 Gemini 路径需要) @@ -7616,6 +7922,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu APIKeyService: input.APIKeyService, ChannelUsageFields: input.ChannelUsageFields, }, &recordUsageOpts{ + ParsedRequest: input.ParsedRequest, EnableClaudePath: true, }) } @@ -7682,6 +7989,7 @@ type recordUsageCoreInput struct { // recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。 // opts 中的字段控制两者之间的差异行为: // - ParsedRequest != nil → 启用 Claude Max 缓存计费策略 +// - EnableSoraMedia → 启用 Sora MediaType 分支(image/video/prompt) // - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error { result := input.Result @@ -7699,9 +8007,21 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage result.Usage.InputTokens = 0 } - // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + // Claude Max cache billing policy(仅 Claude 路径启用) cacheTTLOverridden := false - if account.IsCacheTTLOverrideEnabled() { + simulatedClaudeMax := false + if opts.EnableClaudePath { + var apiKeyGroup *Group + if apiKey != nil { + apiKeyGroup = apiKey.Group + } + claudeMaxOutcome := applyClaudeMaxCacheBillingPolicyToUsage(&result.Usage, opts.ParsedRequest, apiKeyGroup, result.Model, account.ID) + simulatedClaudeMax = claudeMaxOutcome.Simulated || + (shouldApplyClaudeMaxBillingRulesForUsage(apiKeyGroup, result.Model, opts.ParsedRequest) && hasCacheCreationTokens(result.Usage)) + } + + // Cache TTL Override: 确保计费时 token 分类与账号设置一致 + if account.IsCacheTTLOverrideEnabled() && !simulatedClaudeMax { applyCacheTTLOverride(&result.Usage, account.GetCacheTTLOverrideTarget()) cacheTTLOverridden = (result.Usage.CacheCreation5mTokens + result.Usage.CacheCreation1hTokens) > 0 } @@ -7783,6 +8103,16 @@ func (s *GatewayService) calculateRecordUsageCost( multiplier float64, opts *recordUsageOpts, ) *CostBreakdown { + // Sora 媒体类型分支(仅 Claude 路径启用) + if opts.EnableClaudePath { + if result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo { + return s.calculateSoraMediaCost(result, apiKey, billingModel, multiplier) + } + if result.MediaType == MediaTypePrompt { + return &CostBreakdown{} + } + } + // 图片生成计费 if result.ImageCount > 0 { return s.calculateImageCost(ctx, result, apiKey, billingModel, multiplier) @@ -7792,6 +8122,28 @@ func (s *GatewayService) calculateRecordUsageCost( return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts) } +// calculateSoraMediaCost 计算 Sora 图片/视频的费用。 +func (s *GatewayService) calculateSoraMediaCost( + result *ForwardResult, + apiKey *APIKey, + billingModel string, + multiplier float64, +) *CostBreakdown { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == MediaTypeImage { + return s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } + return s.billingService.CalculateSoraVideoCost(billingModel, soraConfig, multiplier) +} + // resolveChannelPricing 检查指定模型是否存在渠道级别定价。 // 返回非 nil 的 ResolvedPricing 表示有渠道定价,nil 表示走默认定价路径。 func (s *GatewayService) resolveChannelPricing(ctx context.Context, billingModel string, apiKey *APIKey) *ResolvedPricing { @@ -7814,7 +8166,7 @@ func (s *GatewayService) calculateImageCost( billingModel string, multiplier float64, ) *CostBreakdown { - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, @@ -7829,7 +8181,6 @@ func (s *GatewayService) calculateImageCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, - Resolved: resolved, }) if err != nil { logger.LegacyPrintf("service.gateway", "Calculate image token cost failed: %v", err) @@ -7872,7 +8223,7 @@ func (s *GatewayService) calculateTokenCost( var err error // 优先尝试渠道定价 → CalculateCostUnified - if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if s.resolveChannelPricing(ctx, billingModel, apiKey) != nil { gid := apiKey.Group.ID cost, err = s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, @@ -7882,7 +8233,6 @@ func (s *GatewayService) calculateTokenCost( RequestCount: 1, RateMultiplier: multiplier, Resolver: s.resolver, - Resolved: resolved, }) } else if opts.LongContextThreshold > 0 { // 长上下文双倍计费(如 Gemini 200K 阈值) @@ -7940,12 +8290,13 @@ func (s *GatewayService) buildRecordUsageLog( RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, - BillingMode: resolveBillingMode(result, cost), + BillingMode: resolveBillingMode(opts, result, cost), Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: optionalTrimmedStringPtr(result.ImageSize), + MediaType: resolveMediaType(opts, result), CacheTTLOverridden: cacheTTLOverridden, ChannelID: optionalInt64Ptr(input.ChannelID), ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain), @@ -7969,7 +8320,13 @@ func (s *GatewayService) buildRecordUsageLog( } // resolveBillingMode 根据计费结果和请求类型确定计费模式。 -func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { +// Sora 媒体类型自身已确定计费模式(由上游处理),返回 nil 跳过。 +func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string { + isSoraMedia := opts.EnableClaudePath && + (result.MediaType == MediaTypeImage || result.MediaType == MediaTypeVideo || result.MediaType == MediaTypePrompt) + if isSoraMedia { + return nil + } var mode string switch { case cost != nil && cost.BillingMode != "": @@ -7982,6 +8339,13 @@ func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string { return &mode } +func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string { + if opts.EnableClaudePath && strings.TrimSpace(result.MediaType) != "" { + return &result.MediaType + } + return nil +} + func optionalSubscriptionID(subscription *UserSubscription) *int64 { if subscription != nil { return &subscription.ID @@ -8010,8 +8374,8 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m return s.channelService.IsModelRestricted(ctx, groupID, model) } -// ResolveChannelMappingAndRestrict 解析渠道映射。 -// 模型限制检查已移至调度阶段(checkChannelPricingRestriction),restricted 始终返回 false。 +// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。 +// 返回映射结果和是否被限制。 func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) { if s.channelService == nil { return ChannelMappingResult{MappedModel: model}, false @@ -8042,9 +8406,7 @@ func billingModelForRestriction(source, requestedModel, channelMappedModel strin return requestedModel case BillingModelSourceUpstream: return "" - case BillingModelSourceChannelMapped: - return channelMappedModel - default: + default: // channel_mapped return channelMappedModel } } @@ -8076,11 +8438,7 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex return false } ch, err := s.channelService.GetChannelForGroup(ctx, *groupID) - if err != nil { - slog.Warn("failed to check channel upstream restriction", "group_id", *groupID, "error", err) - return false - } - if ch == nil || !ch.RestrictModels { + if err != nil || ch == nil || !ch.RestrictModels { return false } return ch.BillingModelSource == BillingModelSourceUpstream @@ -8172,12 +8530,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return err } - // 获取代理URL(自定义 base URL 模式下,proxy 通过 buildCustomRelayURL 作为查询参数传递) + // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { - if !account.IsCustomBaseURLEnabled() || account.GetCustomBaseURL() == "" { - proxyURL = account.Proxy.URL() - } + proxyURL = account.Proxy.URL() } // 发送请求 @@ -8456,16 +8812,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } targetURL = validatedURL + "/v1/messages/count_tokens?beta=true" } - } else if account.IsCustomBaseURLEnabled() { - customURL := account.GetCustomBaseURL() - if customURL == "" { - return nil, fmt.Errorf("custom_base_url is enabled but not configured for account %d", account.ID) - } - validatedURL, err := s.validateUpstreamBaseURL(customURL) - if err != nil { - return nil, err - } - targetURL = s.buildCustomRelayURL(validatedURL, "/v1/messages/count_tokens", account) } clientHeaders := http.Header{} @@ -8475,9 +8821,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con // OAuth 账号:应用统一指纹和重写 userID(受设置开关控制) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 - ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false + ctEnableFP, ctEnableMPT := true, false if s.settingService != nil { - ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx) + ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx) } var ctFingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { @@ -8495,14 +8841,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 billing header cc_version 与实际发送的 User-Agent 版本 - if ctFingerprint != nil && ctEnableFP { - body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent) - } - if ctEnableCCH { - body = signBillingHeaderCCH(body) - } - req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err @@ -8543,7 +8881,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules - ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID)) + ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account)) // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { @@ -8579,15 +8917,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } } - // 同步 X-Claude-Code-Session-Id 头:取 body 中已处理的 metadata.user_id 的 session_id 覆盖 - if sessionHeader := getHeaderRaw(req.Header, "X-Claude-Code-Session-Id"); sessionHeader != "" { - if uid := gjson.GetBytes(body, "metadata.user_id").String(); uid != "" { - if parsed := ParseMetadataUserID(uid); parsed != nil { - setHeaderRaw(req.Header, "X-Claude-Code-Session-Id", parsed.SessionID) - } - } - } - if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } @@ -8609,19 +8938,6 @@ func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, m }) } -// buildCustomRelayURL 构建自定义中继转发 URL -// 在 path 后附加 beta=true 和可选的 proxy 查询参数 -func (s *GatewayService) buildCustomRelayURL(baseURL, path string, account *Account) string { - u := strings.TrimRight(baseURL, "/") + path + "?beta=true" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL := account.Proxy.URL() - if proxyURL != "" { - u += "&proxy=" + url.QueryEscape(proxyURL) - } - } - return u -} - func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)