package service import ( "bufio" "bytes" "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "log" "log/slog" mathrand "math/rand" "net/http" "os" "regexp" "sort" "strings" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" "github.com/gin-gonic/gin" ) const ( claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true" claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true" stickySessionTTL = time.Hour // 粘性会话TTL defaultMaxLineSize = 40 * 1024 * 1024 // Canonical Claude Code banner. Keep it EXACT (no trailing whitespace/newlines) // to match real Claude CLI traffic as closely as possible. When we need a visual // separator between system blocks, we add "\n\n" at concatenation time. claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 ) const ( claudeMimicDebugInfoKey = "claude_mimic_debug_info" ) // ForceCacheBillingContextKey 强制缓存计费上下文键 // 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费 type forceCacheBillingKeyType struct{} // accountWithLoad 账号与负载信息的组合,用于负载感知调度 type accountWithLoad struct { account *Account loadInfo *AccountLoadInfo } var ForceCacheBillingContextKey = forceCacheBillingKeyType{} // IsForceCacheBilling 检查是否启用强制缓存计费 func IsForceCacheBilling(ctx context.Context) bool { v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) return v } // WithForceCacheBilling 返回带有强制缓存计费标记的上下文 func WithForceCacheBilling(ctx context.Context) context.Context { return context.WithValue(ctx, ForceCacheBillingContextKey, true) } func (s *GatewayService) debugModelRoutingEnabled() bool { v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) return v == "1" || v == "true" || v == "yes" || v == "on" } func (s *GatewayService) debugClaudeMimicEnabled() bool { v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) return v == "1" || v == "true" || v == "yes" || v == "on" } func shortSessionHash(sessionHash string) string { if sessionHash == "" { return "" } if len(sessionHash) <= 8 { return sessionHash } return sessionHash[:8] } func redactAuthHeaderValue(v string) string { v = strings.TrimSpace(v) if v == "" { return "" } // Keep scheme for debugging, redact secret. if strings.HasPrefix(strings.ToLower(v), "bearer ") { return "Bearer [redacted]" } return "[redacted]" } func safeHeaderValueForLog(key string, v string) string { key = strings.ToLower(strings.TrimSpace(key)) switch key { case "authorization", "x-api-key": return redactAuthHeaderValue(v) default: return strings.TrimSpace(v) } } func extractSystemPreviewFromBody(body []byte) string { if len(body) == 0 { return "" } sys := gjson.GetBytes(body, "system") if !sys.Exists() { return "" } switch { case sys.IsArray(): for _, item := range sys.Array() { if !item.IsObject() { continue } if strings.EqualFold(item.Get("type").String(), "text") { if t := item.Get("text").String(); strings.TrimSpace(t) != "" { return t } } } return "" case sys.Type == gjson.String: return sys.String() default: return "" } } func buildClaudeMimicDebugLine(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) string { if req == nil { return "" } // Only log a minimal fingerprint to avoid leaking user content. interesting := []string{ "user-agent", "x-app", "anthropic-dangerous-direct-browser-access", "anthropic-version", "anthropic-beta", "x-stainless-lang", "x-stainless-package-version", "x-stainless-os", "x-stainless-arch", "x-stainless-runtime", "x-stainless-runtime-version", "x-stainless-retry-count", "x-stainless-timeout", "authorization", "x-api-key", "content-type", "accept", "x-stainless-helper-method", } h := make([]string, 0, len(interesting)) for _, k := range interesting { if v := req.Header.Get(k); v != "" { h = append(h, fmt.Sprintf("%s=%q", k, safeHeaderValueForLog(k, v))) } } metaUserID := strings.TrimSpace(gjson.GetBytes(body, "metadata.user_id").String()) sysPreview := strings.TrimSpace(extractSystemPreviewFromBody(body)) // Truncate preview to keep logs sane. if len(sysPreview) > 300 { sysPreview = sysPreview[:300] + "..." } sysPreview = strings.ReplaceAll(sysPreview, "\n", "\\n") sysPreview = strings.ReplaceAll(sysPreview, "\r", "\\r") aid := int64(0) aname := "" if account != nil { aid = account.ID aname = account.Name } return fmt.Sprintf( "url=%s account=%d(%s) tokenType=%s mimic=%t meta.user_id=%q system.preview=%q headers={%s}", req.URL.String(), aid, aname, tokenType, mimicClaudeCode, metaUserID, sysPreview, strings.Join(h, " "), ) } func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, tokenType string, mimicClaudeCode bool) { line := buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode) if line == "" { return } log.Printf("[ClaudeMimicDebug] %s", line) } func isClaudeCodeCredentialScopeError(msg string) bool { m := strings.ToLower(strings.TrimSpace(msg)) if m == "" { return false } return strings.Contains(m, "only authorized for use with claude code") && strings.Contains(m, "cannot be used for other api requests") } // sseDataRe matches SSE data lines with optional whitespace after colon. // Some upstream APIs return non-standard "data:" without space (should be "data: "). var ( sseDataRe = regexp.MustCompile(`^data:\s*`) sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`) claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`) // claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表 // 支持多种变体:标准版、Agent SDK 版、Explore Agent 版、Compact 版等 // 注意:前缀之间不应存在包含关系,否则会导致冗余匹配 claudeCodePromptPrefixes = []string{ "You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...) "You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体 "You are a file search specialist for Claude Code", // Explore Agent 版 "You are a helpful AI assistant tasked with summarizing conversations", // Compact 版 } ) // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") // ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, "x-stainless-retry-count": true, "x-stainless-timeout": true, "x-stainless-lang": true, "x-stainless-package-version": true, "x-stainless-os": true, "x-stainless-arch": true, "x-stainless-runtime": true, "x-stainless-runtime-version": true, "x-stainless-helper-method": true, "anthropic-dangerous-direct-browser-access": true, "anthropic-version": true, "x-app": true, "anthropic-beta": true, "accept-language": true, "sec-fetch-mode": true, "user-agent": true, "content-type": true, } // GatewayCache 定义网关服务的缓存操作接口。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // // ModelLoadInfo 模型负载信息(用于 Antigravity 调度) // Model load info for Antigravity scheduling type ModelLoadInfo struct { CallCount int64 // 当前分钟调用次数 / Call count in current minute LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled) } // GatewayCache defines cache operations for gateway service. // Provides sticky session storage, retrieval, refresh and deletion capabilities. type GatewayCache interface { // GetSessionAccountID 获取粘性会话绑定的账号 ID // Get the account ID bound to a sticky session GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) // SetSessionAccountID 设置粘性会话与账号的绑定关系 // Set the binding between sticky session and account SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error // RefreshSessionTTL 刷新粘性会话的过期时间 // Refresh the expiration time of a sticky session RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error // IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用) // Increment model call count and update last scheduling time (Antigravity only) // 返回更新后的调用次数 IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) // Batch get model load info for accounts (Antigravity only) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) // Find Gemini session using MGET reverse order matching // 返回最长匹配的会话信息(uuid, accountID) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) // SaveGeminiSession 保存 Gemini 会话 // Save Gemini session binding SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil func derefGroupID(groupID *int64) int64 { if groupID == nil { return 0 } return *groupID } // stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。 // 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。 // 低于此阈值时保持粘性会话,等待短暂限流结束。 const stickySessionRateLimitThreshold = 10 * time.Second // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 // 这确保后续请求不会继续使用不可用的账号。 // // shouldClearStickySession checks if an account is in an unschedulable state // and the sticky session binding should be cleared. // Returns true when account status is error/disabled, schedulable is false, // within temporary unschedulable period, or model rate limit remaining time // exceeds stickySessionRateLimitThreshold. // This ensures subsequent requests won't continue using unavailable accounts. func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { return false } if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable { return true } if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { return true } // 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { return true } return false } type AccountWaitPlan struct { AccountID int64 MaxConcurrency int Timeout time.Duration MaxWaiting int } type AccountSelectionResult struct { Account *Account Acquired bool ReleaseFunc func() WaitPlan *AccountWaitPlan // nil means no wait allowed } // ClaudeUsage 表示Claude API返回的usage信息 type ClaudeUsage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` } // ForwardResult 转发结果 type ForwardResult struct { RequestID string Usage ClaudeUsage Model string Stream bool Duration time.Duration FirstTokenMs *int // 首字时间(流式请求) ClientDisconnect bool // 客户端是否在流式传输过程中断开 // 图片生成计费字段(仅 gemini-3-pro-image 使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" } // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { StatusCode int ResponseBody []byte // 上游响应体,用于错误透传规则匹配 ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true } func (e *UpstreamFailoverError) Error() string { return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode) } // GatewayService handles API gateway operations type GatewayService struct { accountRepo AccountRepository groupRepo GroupRepository usageLogRepo UsageLogRepository userRepo UserRepository userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository cache GatewayCache cfg *config.Config schedulerSnapshot *SchedulerSnapshotService billingService *BillingService rateLimitService *RateLimitService billingCacheService *BillingCacheService identityService *IdentityService httpUpstream HTTPUpstream deferredService *DeferredService concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) } // NewGatewayService creates a new GatewayService func NewGatewayService( accountRepo AccountRepository, groupRepo GroupRepository, usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, concurrencyService *ConcurrencyService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService, httpUpstream HTTPUpstream, deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, ) *GatewayService { return &GatewayService{ accountRepo: accountRepo, groupRepo: groupRepo, usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, cache: cache, cfg: cfg, schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, billingService: billingService, rateLimitService: rateLimitService, billingCacheService: billingCacheService, identityService: identityService, httpUpstream: httpUpstream, deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, } } // GenerateSessionHash 从预解析请求计算粘性会话 hash func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { if parsed == nil { return "" } // 1. 最高优先级:从 metadata.user_id 提取 session_xxx if parsed.MetadataUserID != "" { if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 { return match[1] } } // 2. 提取带 cache_control: {type: "ephemeral"} 的内容 cacheableContent := s.extractCacheableContent(parsed) if cacheableContent != "" { return s.hashContent(cacheableContent) } // 3. Fallback: 使用 system 内容 if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { return s.hashContent(systemText) } } // 4. 最后 fallback: 使用第一条消息 if len(parsed.Messages) > 0 { if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { msgText := s.extractTextFromContent(firstMsg["content"]) if msgText != "" { return s.hashContent(msgText) } } } return "" } // BindStickySession sets session -> account binding with standard TTL. func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 || s.cache == nil { return nil } return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) } // GetCachedSessionAccountID retrieves the account ID bound to a sticky session. // Returns 0 if no binding exists or on error. func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { if sessionHash == "" || s.cache == nil { return 0, nil } accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err != nil { return 0, err } return accountID, nil } // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // 返回最长匹配的会话信息(uuid, accountID) func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { if digestChain == "" || s.cache == nil { return "", 0, false } return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) } // SaveGeminiSession 保存 Gemini 会话 func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { if digestChain == "" || s.cache == nil { return nil } return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { if parsed == nil { return "" } var builder strings.Builder // 检查 system 中的 cacheable 内容 if system, ok := parsed.System.([]any); ok { for _, part := range system { if partMap, ok := part.(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc["type"] == "ephemeral" { if text, ok := partMap["text"].(string); ok { _, _ = builder.WriteString(text) } } } } } } systemText := builder.String() // 检查 messages 中的 cacheable 内容 for _, msg := range parsed.Messages { if msgMap, ok := msg.(map[string]any); ok { if msgContent, ok := msgMap["content"].([]any); ok { for _, part := range msgContent { if partMap, ok := part.(map[string]any); ok { if cc, ok := partMap["cache_control"].(map[string]any); ok { if cc["type"] == "ephemeral" { return s.extractTextFromContent(msgMap["content"]) } } } } } } } return systemText } func (s *GatewayService) extractTextFromSystem(system any) string { switch v := system.(type) { case string: return v case []any: var texts []string for _, part := range v { if partMap, ok := part.(map[string]any); ok { if text, ok := partMap["text"].(string); ok { texts = append(texts, text) } } } return strings.Join(texts, "") } return "" } func (s *GatewayService) extractTextFromContent(content any) string { switch v := content.(type) { case string: return v case []any: var texts []string for _, part := range v { if partMap, ok := part.(map[string]any); ok { if partMap["type"] == "text" { if text, ok := partMap["text"].(string); ok { texts = append(texts, text) } } } } return strings.Join(texts, "") } return "" } func (s *GatewayService) hashContent(content string) string { hash := sha256.Sum256([]byte(content)) return hex.EncodeToString(hash[:16]) // 32字符 } // replaceModelInBody 替换请求体中的model字段 // 使用 json.RawMessage 保留其他字段的原始字节,避免 thinking 块等内容被修改 func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte { var req map[string]json.RawMessage if err := json.Unmarshal(body, &req); err != nil { return body } // 只序列化 model 字段 modelBytes, err := json.Marshal(newModel) if err != nil { return body } req["model"] = modelBytes newBody, err := json.Marshal(req) if err != nil { return body } return newBody } type claudeOAuthNormalizeOptions struct { injectMetadata bool metadataUserID string stripSystemCacheControl bool } // sanitizeSystemText rewrites only the fixed OpenCode identity sentence (if present). // We intentionally avoid broad keyword replacement in system prompts to prevent // accidentally changing user-provided instructions. func sanitizeSystemText(text string) string { if text == "" { return text } // Some clients include a fixed OpenCode identity sentence. Anthropic may treat // this as a non-Claude-Code fingerprint, so rewrite it to the canonical // Claude Code banner before generic "OpenCode"/"opencode" replacements. text = strings.ReplaceAll( text, "You are OpenCode, the best coding agent on the planet.", strings.TrimSpace(claudeCodeSystemPrompt), ) return text } func stripCacheControlFromSystemBlocks(system any) bool { blocks, ok := system.([]any) if !ok { return false } changed := false for _, item := range blocks { block, ok := item.(map[string]any) if !ok { continue } if _, exists := block["cache_control"]; !exists { continue } delete(block, "cache_control") changed = true } return changed } func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string) { if len(body) == 0 { return body, modelID } // 解析为 map[string]any 用于修改字段 var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return body, modelID } modified := false if system, ok := req["system"]; ok { switch v := system.(type) { case string: sanitized := sanitizeSystemText(v) if sanitized != v { req["system"] = sanitized modified = true } case []any: for _, item := range v { block, ok := item.(map[string]any) if !ok { continue } if blockType, _ := block["type"].(string); blockType != "text" { continue } text, ok := block["text"].(string) if !ok || text == "" { continue } sanitized := sanitizeSystemText(text) if sanitized != text { block["text"] = sanitized modified = true } } } } if rawModel, ok := req["model"].(string); ok { normalized := claude.NormalizeModelID(rawModel) if normalized != rawModel { req["model"] = normalized modelID = normalized modified = true } } // 确保 tools 字段存在(即使为空数组) if _, exists := req["tools"]; !exists { req["tools"] = []any{} modified = true } if opts.stripSystemCacheControl { if system, ok := req["system"]; ok { _ = stripCacheControlFromSystemBlocks(system) modified = true } } if opts.injectMetadata && opts.metadataUserID != "" { metadata, ok := req["metadata"].(map[string]any) if !ok { metadata = map[string]any{} req["metadata"] = metadata } if existing, ok := metadata["user_id"].(string); !ok || existing == "" { metadata["user_id"] = opts.metadataUserID modified = true } } if _, hasTemp := req["temperature"]; hasTemp { delete(req, "temperature") modified = true } if _, hasChoice := req["tool_choice"]; hasChoice { delete(req, "tool_choice") modified = true } if !modified { return body, modelID } newBody, err := json.Marshal(req) if err != nil { return body, modelID } return newBody, modelID } func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string { if parsed == nil || account == nil { return "" } if parsed.MetadataUserID != "" { return "" } userID := strings.TrimSpace(account.GetClaudeUserID()) if userID == "" && fp != nil { userID = fp.ClientID } if userID == "" { // Fall back to a random, well-formed client id so we can still satisfy // Claude Code OAuth requirements when account metadata is incomplete. userID = generateClientID() } sessionHash := s.GenerateSessionHash(parsed) sessionID := uuid.NewString() if sessionHash != "" { seed := fmt.Sprintf("%d::%s", account.ID, sessionHash) sessionID = generateSessionUUID(seed) } // Prefer the newer format that includes account_uuid (if present), // otherwise fall back to the legacy Claude Code format. accountUUID := strings.TrimSpace(account.GetExtraString("account_uuid")) if accountUUID != "" { return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID) } return fmt.Sprintf("user_%s_account__session_%s", userID, sessionID) } func generateSessionUUID(seed string) string { if seed == "" { return uuid.NewString() } hash := sha256.Sum256([]byte(seed)) bytes := hash[:16] bytes[6] = (bytes[6] & 0x0f) | 0x40 bytes[8] = (bytes[8] & 0x3f) | 0x80 return fmt.Sprintf("%x-%x-%x-%x-%x", bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16]) } // SelectAccount 选择账号(粘性会话+优先级) func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) { return s.SelectAccountForModel(ctx, groupID, sessionHash, "") } // SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射) func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) { return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil) } // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 优先检查 context 中的强制平台(/antigravity 路由) var platform string forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { platform = forcePlatform } else if groupID != nil { group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID) if err != nil { return nil, err } groupID = resolvedGroupID ctx = s.withGroupContext(ctx, group) platform = group.Platform } else { // 无分组时只使用原生 anthropic 平台 platform = PlatformAnthropic } // anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户) // 注意:强制平台模式不走混合调度 if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform { return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // antigravity 分组、强制平台模式或无分组使用单平台选择 // 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询 return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform) } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. // metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { // 调试日志:记录调度入口参数 excludedIDsList := make([]int64, 0, len(excludedIDs)) for id := range excludedIDs { excludedIDsList = append(excludedIDsList, id) } slog.Debug("account_scheduling_starting", "group_id", derefGroupID(groupID), "model", requestedModel, "session", shortSessionHash(sessionHash), "excluded_ids", excludedIDsList) cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID } } // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { return nil, err } ctx = s.withGroupContext(ctx, group) if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { groupPlatform = group.Platform } log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { // 复制排除列表,用于会话限制拒绝时的重试 localExcluded := make(map[int64]struct{}) for k, v := range excludedIDs { localExcluded[k] = v } for { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded) if err != nil { return nil, err } result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) if err == nil && result.Acquired { // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符) if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位 localExcluded[account.ID] = struct{}{} // 排除此账号 continue // 重新选择 } return &AccountSelectionResult{ Account: account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } // 对于等待计划的情况,也需要先检查会话限制 if !s.checkAndRegisterSession(ctx, account, sessionHash) { localExcluded[account.ID] = struct{}{} continue } if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) if waitingCount < cfg.StickySessionMaxWaiting { return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: account.ID, MaxConcurrency: account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } } return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: account.ID, MaxConcurrency: account.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, }, nil } } platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) if err != nil { return nil, err } preferOAuth := platform == PlatformGemini if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { return nil, err } } accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, err } if len(accounts) == 0 { return nil, errors.New("no available accounts") } isExcluded := func(accountID int64) bool { if excludedIDs == nil { return false } _, excluded := excludedIDs[accountID] return excluded } // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用) accountByID := make(map[int64]*Account, len(accounts)) for i := range accounts { accountByID[accounts[i].ID] = &accounts[i] } // 获取模型路由配置(仅 anthropic 平台) var routingAccountIDs []int64 if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { keys := make([]string, 0, len(group.ModelRouting)) for k := range group.ModelRouting { keys = append(keys, k) } sort.Strings(keys) const maxKeys = 20 if len(keys) > maxKeys { keys = keys[:maxKeys] } log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) } } } // ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============ if len(routingAccountIDs) > 0 && s.concurrencyService != nil { // 1. 过滤出路由列表中可调度的账号 var routingCandidates []*Account var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int var modelScopeSkippedIDs []int64 // 记录因模型限流被跳过的账号 ID for _, routingAccountID := range routingAccountIDs { if isExcluded(routingAccountID) { filteredExcluded++ continue } account, ok := accountByID[routingAccountID] if !ok || !account.IsSchedulable() { if !ok { filteredMissing++ } else { filteredUnsched++ } continue } if !s.isAccountAllowedForPlatform(account, platform, useMixed) { filteredPlatform++ continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, account, requestedModel) { filteredModelMapping++ continue } if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { filteredModelScope++ modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, account, false) { filteredWindowCost++ continue } routingCandidates = append(routingCandidates, account) } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) if len(modelScopeSkippedIDs) > 0 { log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) } } if len(routingCandidates) > 0 { // 1.5. 在路由账号范围内检查粘性会话 if sessionHash != "" && s.cache != nil { stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount.IsSchedulable() && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } return &AccountSelectionResult{ Account: stickyAccount, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) if waitingCount < cfg.StickySessionMaxWaiting { // 会话数量限制检查(等待计划也需要占用会话配额) if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { // 会话限制已满,继续到负载感知选择 } else { return &AccountSelectionResult{ Account: stickyAccount, WaitPlan: &AccountWaitPlan{ AccountID: stickyAccountID, MaxConcurrency: stickyAccount.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 } } else { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } } } // 2. 批量获取负载信息 routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates)) for _, acc := range routingCandidates { routingLoads = append(routingLoads, AccountWithConcurrency{ ID: acc.ID, MaxConcurrency: acc.Concurrency, }) } routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads) // 3. 按负载感知排序 var routingAvailable []accountWithLoad for _, acc := range routingCandidates { loadInfo := routingLoadMap[acc.ID] if loadInfo == nil { loadInfo = &AccountLoadInfo{AccountID: acc.ID} } if loadInfo.LoadRate < 100 { routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo}) } } if len(routingAvailable) > 0 { // 排序:优先级 > 负载率 > 最后使用时间 sort.SliceStable(routingAvailable, func(i, j int) bool { a, b := routingAvailable[i], routingAvailable[j] if a.account.Priority != b.account.Priority { return a.account.Priority < b.account.Priority } if a.loadInfo.LoadRate != b.loadInfo.LoadRate { return a.loadInfo.LoadRate < b.loadInfo.LoadRate } switch { case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil: return true case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil: return false case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil: return false default: return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) // 4. 尝试获取槽位 for _, item := range routingAvailable { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } // 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的) // 遍历找到第一个满足会话限制的账号 for _, item := range routingAvailable { if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { continue // 会话限制已满,尝试下一个 } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, WaitPlan: &AccountWaitPlan{ AccountID: item.account.ID, MaxConcurrency: item.account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) } } // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 // Check if the account needs sticky session cleanup clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 // Session count limit check if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { // 会话数量限制检查(等待计划也需要占用会话配额) // Session count limit check (wait plan also requires session quota) if !s.checkAndRegisterSession(ctx, account, sessionHash) { // 会话限制已满,继续到 Layer 2 // Session limit full, continue to Layer 2 } else { return &AccountSelectionResult{ Account: account, WaitPlan: &AccountWaitPlan{ AccountID: accountID, MaxConcurrency: account.Concurrency, Timeout: cfg.StickySessionWaitTimeout, MaxWaiting: cfg.StickySessionMaxWaiting, }, }, nil } } } } } } // ============ Layer 2: 负载感知选择 ============ candidates := make([]*Account, 0, len(accounts)) for i := range accounts { acc := &accounts[i] if isExcluded(acc.ID) { continue } // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. if !acc.IsSchedulable() { continue } if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } candidates = append(candidates, acc) } if len(candidates) == 0 { return nil, errors.New("no available accounts") } accountLoads := make([]AccountWithConcurrency, 0, len(candidates)) for _, acc := range candidates { accountLoads = append(accountLoads, AccountWithConcurrency{ ID: acc.ID, MaxConcurrency: acc.Concurrency, }) } loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil } } else { // Antigravity 平台:获取模型负载信息 var modelLoadMap map[int64]*ModelLoadInfo isAntigravity := platform == PlatformAntigravity var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] if loadInfo == nil { loadInfo = &AccountLoadInfo{AccountID: acc.ID} } if loadInfo.LoadRate < 100 { available = append(available, accountWithLoad{ account: acc, loadInfo: loadInfo, }) } } // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) modelToAccountIDs := make(map[string][]int64) for _, item := range available { mappedModel := mapAntigravityModel(item.account, requestedModel) if mappedModel == "" { continue } modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) } for model, ids := range modelToAccountIDs { batch, err := s.cache.GetModelLoadBatch(ctx, ids, model) if err != nil { continue } for id, info := range batch { modelLoadMap[id] = info } } if len(modelLoadMap) == 0 { modelLoadMap = nil } } // Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值) // 其他平台:分层过滤选择:优先级 → 负载率 → LRU if isAntigravity { for len(available) > 0 { // 1. 取优先级最小的集合(硬过滤) candidates := filterByMinPriority(available) // 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值) selected := selectByCallCount(candidates, modelLoadMap, preferOAuth) if selected == nil { break } result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 } else { if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: selected.account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } // 移除已尝试的账号,重新选择 selectedID := selected.account.ID newAvailable := make([]accountWithLoad, 0, len(available)-1) for _, acc := range available { if acc.account.ID != selectedID { newAvailable = append(newAvailable, acc) } } available = newAvailable } } else { for len(available) > 0 { // 1. 取优先级最小的集合 candidates := filterByMinPriority(available) // 2. 取负载率最低的集合 candidates = filterByMinLoadRate(candidates) // 3. LRU 选择最久未用的账号 selected := selectByLRU(candidates, preferOAuth) if selected == nil { break } result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 } else { if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: selected.account, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, nil } } // 移除已尝试的账号,重新进行分层过滤 selectedID := selected.account.ID newAvailable := make([]accountWithLoad, 0, len(available)-1) for _, acc := range available { if acc.account.ID != selectedID { newAvailable = append(newAvailable, acc) } } available = newAvailable } } } // ============ Layer 3: 兜底排队 ============ s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode) for _, acc := range candidates { // 会话数量限制检查(等待计划也需要占用会话配额) if !s.checkAndRegisterSession(ctx, acc, sessionHash) { continue // 会话限制已满,尝试下一个账号 } return &AccountSelectionResult{ Account: acc, WaitPlan: &AccountWaitPlan{ AccountID: acc.ID, MaxConcurrency: acc.Concurrency, Timeout: cfg.FallbackWaitTimeout, MaxWaiting: cfg.FallbackMaxWaiting, }, }, nil } return nil, errors.New("no available accounts") } func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) for _, acc := range ordered { result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 if !s.checkAndRegisterSession(ctx, acc, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: acc, Acquired: true, ReleaseFunc: result.ReleaseFunc, }, true } } return nil, false } func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { if s.cfg != nil { return s.cfg.Gateway.Scheduling } return config.GatewaySchedulingConfig{ StickySessionMaxWaiting: 3, StickySessionWaitTimeout: 45 * time.Second, FallbackWaitTimeout: 30 * time.Second, FallbackMaxWaiting: 100, LoadBatchEnabled: true, SlotCleanupInterval: 30 * time.Second, } } func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context { if !IsGroupContextValid(group) { return ctx } if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) { return ctx } return context.WithValue(ctx, ctxkey.Group, group) } func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group { if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID { return group } return nil } func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { if group := s.groupFromContext(ctx, groupID); group != nil { return group, nil } group, err := s.groupRepo.GetByIDLite(ctx, groupID) if err != nil { return nil, fmt.Errorf("get group failed: %w", err) } return group, nil } func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) { return s.resolveGroupByID(ctx, groupID) } func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 { if groupID == nil || requestedModel == "" || platform != PlatformAnthropic { return nil } group, err := s.resolveGroupByID(ctx, *groupID) if err != nil || group == nil { if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) } return nil } // Preserve existing behavior: model routing only applies to anthropic groups. if group.Platform != PlatformAnthropic { if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) } return nil } ids := group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) } return ids } func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) { if groupID == nil { return nil, nil, nil } currentID := *groupID visited := map[int64]struct{}{} for { if _, seen := visited[currentID]; seen { return nil, nil, fmt.Errorf("fallback group cycle detected") } visited[currentID] = struct{}{} group, err := s.resolveGroupByID(ctx, currentID) if err != nil { return nil, nil, err } if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) { return group, ¤tID, nil } if group.FallbackGroupID == nil { return nil, nil, ErrClaudeCodeOnly } currentID = *group.FallbackGroupID } } // checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 // 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: // - 有降级分组:返回降级分组的 ID // - 无降级分组:返回 ErrClaudeCodeOnly 错误 func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) { if groupID == nil { return nil, groupID, nil } // 强制平台模式不检查 Claude Code 限制 if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" { return nil, groupID, nil } group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID) if err != nil { return nil, nil, err } return group, resolvedID, nil } func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { return forcePlatform, true, nil } if group != nil { return group.Platform, false, nil } if groupID != nil { group, err := s.resolveGroupByID(ctx, *groupID) if err != nil { return "", false, err } return group.Platform, false, nil } return PlatformAnthropic, false, nil } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { slog.Debug("account_scheduling_list_snapshot", "group_id", derefGroupID(groupID), "platform", platform, "use_mixed", useMixed, "count", len(accounts)) for _, acc := range accounts { slog.Debug("account_scheduling_account_detail", "account_id", acc.ID, "name", acc.Name, "platform", acc.Platform, "type", acc.Type, "status", acc.Status, "tls_fingerprint", acc.IsTLSFingerprintEnabled()) } } return accounts, useMixed, err } useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform if useMixed { platforms := []string{platform, PlatformAntigravity} var accounts []Account var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms) } else { accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) } if err != nil { slog.Debug("account_scheduling_list_failed", "group_id", derefGroupID(groupID), "platform", platform, "error", err) return nil, useMixed, err } filtered := make([]Account, 0, len(accounts)) for _, acc := range accounts { if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } filtered = append(filtered, acc) } slog.Debug("account_scheduling_list_mixed", "group_id", derefGroupID(groupID), "platform", platform, "raw_count", len(accounts), "filtered_count", len(filtered)) for _, acc := range filtered { slog.Debug("account_scheduling_account_detail", "account_id", acc.ID, "name", acc.Name, "platform", acc.Platform, "type", acc.Type, "status", acc.Status, "tls_fingerprint", acc.IsTLSFingerprintEnabled()) } return filtered, useMixed, nil } var accounts []Account var err error if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) } else if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform) // 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询 } else { accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform) } if err != nil { slog.Debug("account_scheduling_list_failed", "group_id", derefGroupID(groupID), "platform", platform, "error", err) return nil, useMixed, err } slog.Debug("account_scheduling_list_single", "group_id", derefGroupID(groupID), "platform", platform, "count", len(accounts)) for _, acc := range accounts { slog.Debug("account_scheduling_account_detail", "account_id", acc.ID, "name", acc.Name, "platform", acc.Platform, "type", acc.Type, "status", acc.Status, "tls_fingerprint", acc.IsTLSFingerprintEnabled()) } return accounts, useMixed, nil } func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool { if account == nil { return false } if useMixed { if account.Platform == platform { return true } return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled() } return account.Platform == platform } // isAccountInGroup checks if the account belongs to the specified group. // Returns true if groupID is nil (no group restriction) or account belongs to the group. func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { if groupID == nil { return true // 无分组限制 } if account == nil { return false } for _, ag := range account.AccountGroups { if ag.GroupID == *groupID { return true } } return false } func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) { if s.concurrencyService == nil { return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil } return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool { // 只检查 Anthropic OAuth/SetupToken 账号 if !account.IsAnthropicOAuthOrSetupToken() { return true } limit := account.GetWindowCostLimit() if limit <= 0 { return true // 未启用窗口费用限制 } // 尝试从缓存获取窗口费用 var currentCost float64 if s.sessionLimitCache != nil { if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { currentCost = cost goto checkSchedulability } } // 缓存未命中,从数据库查询 { // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) startTime := account.GetCurrentWindowStartTime() stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) if err != nil { // 失败开放:查询失败时允许调度 return true } // 使用标准费用(不含账号倍率) currentCost = stats.StandardCost // 设置缓存(忽略错误) if s.sessionLimitCache != nil { _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost) } } checkSchedulability: schedulability := account.CheckWindowCostSchedulability(currentCost) switch schedulability { case WindowCostSchedulable: return true case WindowCostStickyOnly: return isSticky case WindowCostNotSchedulable: return false } return true } // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) // 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { // 只检查 Anthropic OAuth/SetupToken 账号 if !account.IsAnthropicOAuthOrSetupToken() { return true } maxSessions := account.GetMaxSessions() if maxSessions <= 0 || sessionID == "" { return true // 未启用会话限制或无会话ID } if s.sessionLimitCache == nil { return true // 缓存不可用时允许通过 } idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) if err != nil { // 失败开放:缓存错误时允许通过 return true } return allowed } func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { if s.schedulerSnapshot != nil { return s.schedulerSnapshot.GetAccount(ctx, accountID) } return s.accountRepo.GetByID(ctx, accountID) } // filterByMinPriority 过滤出优先级最小的账号集合 func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { return accounts } minPriority := accounts[0].account.Priority for _, acc := range accounts[1:] { if acc.account.Priority < minPriority { minPriority = acc.account.Priority } } result := make([]accountWithLoad, 0, len(accounts)) for _, acc := range accounts { if acc.account.Priority == minPriority { result = append(result, acc) } } return result } // filterByMinLoadRate 过滤出负载率最低的账号集合 func filterByMinLoadRate(accounts []accountWithLoad) []accountWithLoad { if len(accounts) == 0 { return accounts } minLoadRate := accounts[0].loadInfo.LoadRate for _, acc := range accounts[1:] { if acc.loadInfo.LoadRate < minLoadRate { minLoadRate = acc.loadInfo.LoadRate } } result := make([]accountWithLoad, 0, len(accounts)) for _, acc := range accounts { if acc.loadInfo.LoadRate == minLoadRate { result = append(result, acc) } } return result } // selectByLRU 从集合中选择最久未用的账号 // 如果有多个账号具有相同的最小 LastUsedAt,则随机选择一个 func selectByLRU(accounts []accountWithLoad, preferOAuth bool) *accountWithLoad { if len(accounts) == 0 { return nil } if len(accounts) == 1 { return &accounts[0] } // 1. 找到最小的 LastUsedAt(nil 被视为最小) var minTime *time.Time hasNil := false for _, acc := range accounts { if acc.account.LastUsedAt == nil { hasNil = true break } if minTime == nil || acc.account.LastUsedAt.Before(*minTime) { minTime = acc.account.LastUsedAt } } // 2. 收集所有具有最小 LastUsedAt 的账号索引 var candidateIdxs []int for i, acc := range accounts { if hasNil { if acc.account.LastUsedAt == nil { candidateIdxs = append(candidateIdxs, i) } } else { if acc.account.LastUsedAt != nil && acc.account.LastUsedAt.Equal(*minTime) { candidateIdxs = append(candidateIdxs, i) } } } // 3. 如果只有一个候选,直接返回 if len(candidateIdxs) == 1 { return &accounts[candidateIdxs[0]] } // 4. 如果有多个候选且 preferOAuth,优先选择 OAuth 类型 if preferOAuth { var oauthIdxs []int for _, idx := range candidateIdxs { if accounts[idx].account.Type == AccountTypeOAuth { oauthIdxs = append(oauthIdxs, idx) } } if len(oauthIdxs) > 0 { candidateIdxs = oauthIdxs } } // 5. 随机选择一个 selectedIdx := candidateIdxs[mathrand.Intn(len(candidateIdxs))] return &accounts[selectedIdx] } func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { sort.SliceStable(accounts, func(i, j int) bool { a, b := accounts[i], accounts[j] if a.Priority != b.Priority { return a.Priority < b.Priority } switch { case a.LastUsedAt == nil && b.LastUsedAt != nil: return true case a.LastUsedAt != nil && b.LastUsedAt == nil: return false case a.LastUsedAt == nil && b.LastUsedAt == nil: if preferOAuth && a.Type != b.Type { return a.Type == AccountTypeOAuth } return false default: return a.LastUsedAt.Before(*b.LastUsedAt) } }) } // selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) // 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 // 如果有多个账号具有相同的最小调用次数,则随机选择一个 func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { if len(accounts) == 0 { return nil } if len(accounts) == 1 { return &accounts[0] } // 如果没有负载信息,回退到 LRU if modelLoadMap == nil { return selectByLRU(accounts, preferOAuth) } // 1. 计算平均调用次数(用于新账号冷启动) var totalCallCount int64 var countWithCalls int for _, acc := range accounts { if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { totalCallCount += info.CallCount countWithCalls++ } } var avgCallCount int64 if countWithCalls > 0 { avgCallCount = totalCallCount / int64(countWithCalls) } // 2. 获取每个账号的有效调用次数 getEffectiveCallCount := func(acc accountWithLoad) int64 { if acc.account == nil { return 0 } info := modelLoadMap[acc.account.ID] if info == nil || info.CallCount == 0 { return avgCallCount // 新账号使用平均值 } return info.CallCount } // 3. 找到最小调用次数 minCount := getEffectiveCallCount(accounts[0]) for _, acc := range accounts[1:] { if c := getEffectiveCallCount(acc); c < minCount { minCount = c } } // 4. 收集所有具有最小调用次数的账号 var candidateIdxs []int for i, acc := range accounts { if getEffectiveCallCount(acc) == minCount { candidateIdxs = append(candidateIdxs, i) } } // 5. 如果只有一个候选,直接返回 if len(candidateIdxs) == 1 { return &accounts[candidateIdxs[0]] } // 6. preferOAuth 处理 if preferOAuth { var oauthIdxs []int for _, idx := range candidateIdxs { if accounts[idx].account.Type == AccountTypeOAuth { oauthIdxs = append(oauthIdxs, idx) } } if len(oauthIdxs) > 0 { candidateIdxs = oauthIdxs } } // 7. 随机选择 return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]] } // sortCandidatesForFallback 根据配置选择排序策略 // mode: "last_used"(按最后使用时间) 或 "random"(随机) func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { if mode == "random" { // 先按优先级排序,然后在同优先级内随机打乱 sortAccountsByPriorityOnly(accounts, preferOAuth) shuffleWithinPriority(accounts) } else { // 默认按最后使用时间排序 sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) } } // sortAccountsByPriorityOnly 仅按优先级排序 func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { sort.SliceStable(accounts, func(i, j int) bool { a, b := accounts[i], accounts[j] if a.Priority != b.Priority { return a.Priority < b.Priority } if preferOAuth && a.Type != b.Type { return a.Type == AccountTypeOAuth } return false }) } // shuffleWithinPriority 在同优先级内随机打乱顺序 func shuffleWithinPriority(accounts []*Account) { if len(accounts) <= 1 { return } r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) start := 0 for start < len(accounts) { priority := accounts[start].Priority end := start + 1 for end < len(accounts) && accounts[end].Priority == priority { end++ } // 对 [start, end) 范围内的账户随机打乱 if end-start > 1 { r.Shuffle(end-start, func(i, j int) { accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] }) } start = end } } // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { return nil, err } } preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) var accounts []Account accountsLoaded := false // ============ Model Routing (legacy path): apply before sticky session ============ // When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing // so switching model can switch upstream account within the same sticky session. if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } } } } } // 2) Select an account from the routed candidates. forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform == "" { hasForcePlatform = false } var err error accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } accountsLoaded = true routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { routingSet[id] = struct{}{} } } var selected *Account for i := range accounts { acc := &accounts[i] if _, ok := routingSet[acc.ID]; !ok { continue } if _, excluded := excludedIDs[acc.ID]; excluded { continue } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. if !acc.IsSchedulable() { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { selected = acc continue } if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { switch { case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { selected = acc } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc } } } } if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil } } } } } // 2. 获取可调度账号列表(单平台) if !accountsLoaded { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform == "" { hasForcePlatform = false } var err error accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } } // 3. 按优先级+最久未用选择(考虑模型支持) var selected *Account for i := range accounts { acc := &accounts[i] if _, excluded := excludedIDs[acc.ID]; excluded { continue } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. if !acc.IsSchedulable() { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { selected = acc continue } if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { switch { case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { selected = acc } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc } } } } if selected == nil { if requestedModel != "" { return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) } return nil, errors.New("no available accounts") } // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } return selected, nil } // selectAccountWithMixedScheduling 选择账户(支持混合调度) // 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户 func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) { preferOAuth := nativePlatform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform) var accounts []Account accountsLoaded := false // ============ Model Routing (legacy path): apply before sticky session ============ if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } } } } } } // 2) Select an account from the routed candidates. var err error accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } accountsLoaded = true routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { routingSet[id] = struct{}{} } } var selected *Account for i := range accounts { acc := &accounts[i] if _, ok := routingSet[acc.ID]; !ok { continue } if _, excluded := excludedIDs[acc.ID]; excluded { continue } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. if !acc.IsSchedulable() { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { selected = acc continue } if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { switch { case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { selected = acc } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc } } } } if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.getSchedulableAccount(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil } } } } } } // 2. 获取可调度账号列表 if !accountsLoaded { var err error accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false) if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) } } // 3. 按优先级+最久未用选择(考虑模型支持和混合调度) var selected *Account for i := range accounts { acc := &accounts[i] if _, excluded := excludedIDs[acc.ID]; excluded { continue } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. if !acc.IsSchedulable() { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { continue } if selected == nil { selected = acc continue } if acc.Priority < selected.Priority { selected = acc } else if acc.Priority == selected.Priority { switch { case acc.LastUsedAt == nil && selected.LastUsedAt != nil: selected = acc case acc.LastUsedAt != nil && selected.LastUsedAt == nil: // keep selected (never used is preferred) case acc.LastUsedAt == nil && selected.LastUsedAt == nil: if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth { selected = acc } default: if acc.LastUsedAt.Before(*selected.LastUsedAt) { selected = acc } } } } if selected == nil { if requestedModel != "" { return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) } return nil, errors.New("no available accounts") } // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } return selected, nil } // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) // 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { if strings.TrimSpace(requestedModel) == "" { return true } // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 mapped := mapAntigravityModel(account, requestedModel) if mapped == "" { return false } // 应用 thinking 后缀后检查最终模型是否在账号映射中 if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { finalModel := applyThinkingModelSuffix(mapped, enabled) if finalModel == mapped { return true // thinking 后缀未改变模型名,映射已通过 } return account.IsModelSupported(finalModel) } return true } return s.isModelSupportedByAccount(account, requestedModel) } // isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { if strings.TrimSpace(requestedModel) == "" { return true } return mapAntigravityModel(account, requestedModel) != "" } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) } // Gemini API Key 账户直接透传,由上游判断模型是否支持 if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey { return true } // 其他平台使用账户的模型支持检查 return account.IsModelSupported(requestedModel) } // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { case AccountTypeOAuth, AccountTypeSetupToken: // Both oauth and setup-token use OAuth token flow return s.getOAuthToken(ctx, account) case AccountTypeAPIKey: apiKey := account.GetCredential("api_key") if apiKey == "" { return "", "", errors.New("api_key not found in credentials") } return apiKey, "apikey", nil default: return "", "", fmt.Errorf("unsupported account type: %s", account.Type) } } func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) { // 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil { accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account) if err != nil { return "", "", err } return accessToken, "oauth", nil } // 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取 accessToken := account.GetCredential("access_token") if accessToken == "" { return "", "", errors.New("access_token not found in credentials") } // Token刷新由后台 TokenRefreshService 处理,此处只返回当前token return accessToken, "oauth", nil } // 重试相关常量 const ( // 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。 maxRetryAttempts = 5 // 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。 retryBaseDelay = 300 * time.Millisecond retryMaxDelay = 3 * time.Second // 最大重试耗时(包含请求本身耗时 + 退避等待时间)。 // 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。 maxRetryElapsed = 10 * time.Second ) func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool { // OAuth/Setup Token 账号:仅 403 重试 if account.IsOAuth() { return statusCode == 403 } // API Key 账号:未配置的错误码重试 return !account.ShouldHandleErrorCode(statusCode) } // shouldFailoverUpstreamError determines whether an upstream error should trigger account failover. func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool { switch statusCode { case 401, 403, 429, 529: return true default: return statusCode >= 500 } } func retryBackoffDelay(attempt int) time.Duration { // attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。 if attempt <= 0 { return retryBaseDelay } delay := retryBaseDelay * time.Duration(1<<(attempt-1)) if delay > retryMaxDelay { return retryMaxDelay } return delay } func sleepWithContext(ctx context.Context, d time.Duration) error { if d <= 0 { return nil } timer := time.NewTimer(d) defer func() { if !timer.Stop() { select { case <-timer.C: default: } } }() select { case <-ctx.Done(): return ctx.Err() case <-timer.C: return nil } } // isClaudeCodeClient 判断请求是否来自 Claude Code 客户端 // 简化判断:User-Agent 匹配 + metadata.user_id 存在 func isClaudeCodeClient(userAgent string, metadataUserID string) bool { if metadataUserID == "" { return false } return claudeCliUserAgentRe.MatchString(userAgent) } func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool { if IsClaudeCodeClient(ctx) { return true } if parsed == nil || c == nil { return false } return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) } // systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词 // 使用前缀匹配支持多种变体(标准版、Agent SDK 版等) func systemIncludesClaudeCodePrompt(system any) bool { switch v := system.(type) { case string: return hasClaudeCodePrefix(v) case []any: for _, item := range v { if m, ok := item.(map[string]any); ok { if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) { return true } } } } return false } // hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头 func hasClaudeCodePrefix(text string) bool { for _, prefix := range claudeCodePromptPrefixes { if strings.HasPrefix(text, prefix) { return true } } return false } // injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词 // 处理 null、字符串、数组三种格式 func injectClaudeCodePrompt(body []byte, system any) []byte { claudeCodeBlock := map[string]any{ "type": "text", "text": claudeCodeSystemPrompt, "cache_control": map[string]string{"type": "ephemeral"}, } // Opencode plugin applies an extra safeguard: it not only prepends the Claude Code // banner, it also prefixes the next system instruction with the same banner plus // a blank line. This helps when upstream concatenates system instructions. claudeCodePrefix := strings.TrimSpace(claudeCodeSystemPrompt) var newSystem []any switch v := system.(type) { case nil: newSystem = []any{claudeCodeBlock} case string: // Be tolerant of older/newer clients that may differ only by trailing whitespace/newlines. if strings.TrimSpace(v) == "" || strings.TrimSpace(v) == strings.TrimSpace(claudeCodeSystemPrompt) { newSystem = []any{claudeCodeBlock} } else { // Mirror opencode behavior: keep the banner as a separate system entry, // but also prefix the next system text with the banner. merged := v if !strings.HasPrefix(v, claudeCodePrefix) { merged = claudeCodePrefix + "\n\n" + v } newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": merged}} } case []any: newSystem = make([]any, 0, len(v)+1) newSystem = append(newSystem, claudeCodeBlock) prefixedNext := false for _, item := range v { if m, ok := item.(map[string]any); ok { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) == strings.TrimSpace(claudeCodeSystemPrompt) { continue } // Prefix the first subsequent text system block once. if !prefixedNext { if blockType, _ := m["type"].(string); blockType == "text" { if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" && !strings.HasPrefix(text, claudeCodePrefix) { m["text"] = claudeCodePrefix + "\n\n" + text prefixedNext = true } } } } newSystem = append(newSystem, item) } default: newSystem = []any{claudeCodeBlock} } result, err := sjson.SetBytes(body, "system", newSystem) if err != nil { log.Printf("Warning: failed to inject Claude Code prompt: %v", err) return body } return result } // enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个) // 超限时优先从 messages 中移除 cache_control,保护 system 中的缓存控制 func enforceCacheControlLimit(body []byte) []byte { var data map[string]any if err := json.Unmarshal(body, &data); err != nil { return body } // 清理 thinking 块中的非法 cache_control(thinking 块不支持该字段) removeCacheControlFromThinkingBlocks(data) // 计算当前 cache_control 块数量 count := countCacheControlBlocks(data) if count <= maxCacheControlBlocks { return body } // 超限:优先从 messages 中移除,再从 system 中移除 for count > maxCacheControlBlocks { if removeCacheControlFromMessages(data) { count-- continue } if removeCacheControlFromSystem(data) { count-- continue } break } result, err := json.Marshal(data) if err != nil { return body } return result } // countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量 // 注意:thinking 块不支持 cache_control,统计时跳过 func countCacheControlBlocks(data map[string]any) int { count := 0 // 统计 system 中的块 if system, ok := data["system"].([]any); ok { for _, item := range system { if m, ok := item.(map[string]any); ok { // thinking 块不支持 cache_control,跳过 if blockType, _ := m["type"].(string); blockType == "thinking" { continue } if _, has := m["cache_control"]; has { count++ } } } } // 统计 messages 中的块 if messages, ok := data["messages"].([]any); ok { for _, msg := range messages { if msgMap, ok := msg.(map[string]any); ok { if content, ok := msgMap["content"].([]any); ok { for _, item := range content { if m, ok := item.(map[string]any); ok { // thinking 块不支持 cache_control,跳过 if blockType, _ := m["type"].(string); blockType == "thinking" { continue } if _, has := m["cache_control"]; has { count++ } } } } } } } return count } // removeCacheControlFromMessages 从 messages 中移除一个 cache_control(从头开始) // 返回 true 表示成功移除,false 表示没有可移除的 // 注意:跳过 thinking 块(它不支持 cache_control) func removeCacheControlFromMessages(data map[string]any) bool { messages, ok := data["messages"].([]any) if !ok { return false } for _, msg := range messages { msgMap, ok := msg.(map[string]any) if !ok { continue } content, ok := msgMap["content"].([]any) if !ok { continue } for _, item := range content { if m, ok := item.(map[string]any); ok { // thinking 块不支持 cache_control,跳过 if blockType, _ := m["type"].(string); blockType == "thinking" { continue } if _, has := m["cache_control"]; has { delete(m, "cache_control") return true } } } } return false } // removeCacheControlFromSystem 从 system 中移除一个 cache_control(从尾部开始,保护注入的 prompt) // 返回 true 表示成功移除,false 表示没有可移除的 // 注意:跳过 thinking 块(它不支持 cache_control) func removeCacheControlFromSystem(data map[string]any) bool { system, ok := data["system"].([]any) if !ok { return false } // 从尾部开始移除,保护开头注入的 Claude Code prompt for i := len(system) - 1; i >= 0; i-- { if m, ok := system[i].(map[string]any); ok { // thinking 块不支持 cache_control,跳过 if blockType, _ := m["type"].(string); blockType == "thinking" { continue } if _, has := m["cache_control"]; has { delete(m, "cache_control") return true } } } return false } // removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control // thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段 func removeCacheControlFromThinkingBlocks(data map[string]any) { // 清理 system 中的 thinking 块 if system, ok := data["system"].([]any); ok { for _, item := range system { if m, ok := item.(map[string]any); ok { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") log.Printf("[Warning] Removed illegal cache_control from thinking block in system") } } } } } // 清理 messages 中的 thinking 块 if messages, ok := data["messages"].([]any); ok { for msgIdx, msg := range messages { if msgMap, ok := msg.(map[string]any); ok { if content, ok := msgMap["content"].([]any); ok { for contentIdx, item := range content { if m, ok := item.(map[string]any); ok { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) } } } } } } } } } // Forward 转发请求到Claude API func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) { startTime := time.Now() if parsed == nil { return nil, fmt.Errorf("parse request: empty request") } body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream originalModel := reqModel isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { // 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要) // 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词 if !strings.Contains(strings.ToLower(reqModel), "haiku") && !systemIncludesClaudeCodePrompt(parsed.System) { body = injectClaudeCodePrompt(body, parsed.System) } normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} if s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil && fp != nil { if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" { normalizeOpts.injectMetadata = true normalizeOpts.metadataUserID = metadataUserID } } } body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // 强制执行 cache_control 块数量限制(最多 4 个) body = enforceCacheControlLimit(body) // 应用模型映射: // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) mappedModel := reqModel mappingSource := "" if account.Type == AccountTypeAPIKey { mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { mappingSource = "account" } } if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(reqModel) if normalized != reqModel { mappedModel = normalized mappingSource = "prefix" } } if mappedModel != reqModel { // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { return nil, err } // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } // 调试日志:记录即将转发的账号信息 log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) // 重试循环 var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) // Capture upstream request body for ops retry of this attempt. c.Set(OpsUpstreamRequestBodyKey, string(body)) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err } // 发送请求 resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() } // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). safeErr := sanitizeUpstreamErrorMessage(err.Error()) setOpsUpstreamError(c, 0, safeErr, "") appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, Kind: "request_error", Message: safeErr, }) c.JSON(http.StatusBadGateway, gin.H{ "type": "error", "error": gin.H{ "type": "upstream_error", "message": "Upstream request failed", }, }) return nil, fmt.Errorf("upstream request failed: %s", safeErr) } // 优先检测thinking block签名错误(400)并重试一次 if resp.StatusCode == 400 { respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if readErr == nil { _ = resp.Body.Close() if s.isThinkingBlockSignatureError(respBody) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "signature_error", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) } return "" }(), }) looksLikeToolSignatureError := func(msg string) bool { m := strings.ToLower(msg) return strings.Contains(m, "tool_use") || strings.Contains(m, "tool_result") || strings.Contains(m, "functioncall") || strings.Contains(m, "function_call") || strings.Contains(m, "functionresponse") || strings.Contains(m, "function_response") } // 避免在重试预算已耗尽时再发起额外请求 if time.Since(retryStart) >= maxRetryElapsed { resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) // 2) Only if upstream still errors AND error message points to tool/function signature issues: // also downgrade tool_use/tool_result blocks to text. filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) resp = retryResp break } retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) _ = retryResp.Body.Close() if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: retryResp.StatusCode, UpstreamRequestID: retryResp.Header.Get("x-request-id"), Kind: "signature_retry_thinking", Message: extractUpstreamErrorMessage(retryRespBody), Detail: func() string { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) } return "" }(), }) msg2 := extractUpstreamErrorMessage(retryRespBody) if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { resp = retryResp2 break } if retryResp2 != nil && retryResp2.Body != nil { _ = retryResp2.Body.Close() } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: 0, Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) } else { log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) } } } // Fall back to the original retry response context. resp = &http.Response{ StatusCode: retryResp.StatusCode, Header: retryResp.Header.Clone(), Body: io.NopCloser(bytes.NewReader(retryRespBody)), } break } if retryResp != nil && retryResp.Body != nil { _ = retryResp.Body.Close() } log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) } else { log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) } // Retry failed: restore original response body and continue handling. resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } // 不是thinking签名错误,恢复响应体 resp.Body = io.NopCloser(bytes.NewReader(respBody)) } } // 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了) if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if attempt < maxRetryAttempts { elapsed := time.Since(retryStart) if elapsed >= maxRetryElapsed { break } delay := retryBackoffDelay(attempt) remaining := maxRetryElapsed - elapsed if delay > remaining { delay = remaining } if delay <= 0 { break } respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "retry", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) } return "" }(), }) log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) if err := sleepWithContext(ctx, delay); err != nil { return nil, err } continue } // 最后一次尝试也失败,跳出循环处理重试耗尽 break } // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) if account.Platform == PlatformGemini && resp.StatusCode < 400 { log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { log.Printf("[DEBUG] %s: %v", k, v) } } break } if resp == nil || resp.Body == nil { return nil, errors.New("upstream request failed: empty response") } defer func() { _ = resp.Body.Close() }() // 处理重试耗尽的情况 if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { if s.shouldFailoverUpstreamError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印重试耗尽后的错误响应 log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleRetryExhaustedSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "retry_exhausted_failover", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) } return "" }(), }) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleRetryExhaustedError(ctx, resp, c, account) } // 处理可切换账号的错误 if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印上游错误响应 log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleFailoverSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "failover", Message: extractUpstreamErrorMessage(respBody), Detail: func() string { if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) } return "" }(), }) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) if readErr != nil { // ReadAll failed, fall back to normal error handling without consuming the stream return s.handleErrorResponse(ctx, resp, c, account) } _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) if s.shouldFailoverOn400(respBody) { upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, AccountName: account.Name, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "failover_on_400", Message: upstreamMsg, Detail: upstreamDetail, }) if s.cfg.Gateway.LogUpstreamErrorBody { log.Printf( "Account %d: 400 error, attempting failover: %s", account.ID, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } else { log.Printf("Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } } return s.handleErrorResponse(ctx, resp, c, account) } // 处理正常响应 var usage *ClaudeUsage var firstTokenMs *int var clientDisconnect bool if reqStream { streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, shouldMimicClaudeCode) if err != nil { if err.Error() == "have error in stream" { return nil, &UpstreamFailoverError{ StatusCode: 403, } } return nil, err } usage = streamResult.usage firstTokenMs = streamResult.firstTokenMs clientDisconnect = streamResult.clientDisconnect } else { usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel) if err != nil { return nil, err } } return &ForwardResult{ RequestID: resp.Header.Get("x-request-id"), Usage: *usage, Model: originalModel, // 使用原始模型用于计费和日志 Stream: reqStream, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, ClientDisconnect: clientDisconnect, }, nil } func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL if account.Type == AccountTypeAPIKey { baseURL := account.GetBaseURL() if baseURL != "" { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err } targetURL = validatedURL + "/v1/messages" } } clientHeaders := http.Header{} if c != nil && c.Request != nil { clientHeaders = c.Request.Header } // OAuth账号:应用统一指纹 var fingerprint *Fingerprint if account.IsOAuth() && s.identityService != nil { // 1. 获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers } else { fingerprint = fp // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } } } req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err } // 设置认证头 if tokenType == "oauth" { req.Header.Set("authorization", "Bearer "+token) } else { req.Header.Set("x-api-key", token) } // 白名单透传headers for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { req.Header.Add(key, v) } } } // OAuth账号:应用缓存的指纹到请求头(覆盖白名单透传的头) if fingerprint != nil { s.identityService.ApplyFingerprint(req, fingerprint) } // 确保必要的headers存在 if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") } if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } if tokenType == "oauth" { applyClaudeOAuthHeaderDefaults(req, reqStream) } // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta) if tokenType == "oauth" { if mimicClaudeCode { // 非 Claude Code 客户端:按 opencode 的策略处理: // - 强制 Claude Code 指纹相关请求头(尤其是 user-agent/x-stainless/x-app) // - 保留 incoming beta 的同时,确保 OAuth 所需 beta 存在 applyClaudeCodeMimicHeaders(req, reqStream) incomingBeta := req.Header.Get("anthropic-beta") // Match real Claude CLI traffic (per mitmproxy reports): // messages requests typically use only oauth + interleaved-thinking. // Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} drop := map[string]struct{}{claude.BetaClaudeCode: {}} req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) if requestNeedsBetaFeatures(body) { if beta := defaultAPIKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", beta) } } } // Always capture a compact fingerprint line for later error diagnostics. // We only print it when needed (or when the explicit debug flag is enabled). if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } if s.debugClaudeMimicEnabled() { logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) } return req, nil } // getBetaHeader 处理anthropic-beta header // 对于OAuth账号,需要确保包含oauth-2025-04-20 func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string { // 如果客户端传了anthropic-beta if clientBetaHeader != "" { // 已包含oauth beta则直接返回 if strings.Contains(clientBetaHeader, claude.BetaOAuth) { return clientBetaHeader } // 需要添加oauth beta parts := strings.Split(clientBetaHeader, ",") for i, p := range parts { parts[i] = strings.TrimSpace(p) } // 在claude-code-20250219后面插入oauth beta claudeCodeIdx := -1 for i, p := range parts { if p == claude.BetaClaudeCode { claudeCodeIdx = i break } } if claudeCodeIdx >= 0 { // 在claude-code后面插入 newParts := make([]string, 0, len(parts)+1) newParts = append(newParts, parts[:claudeCodeIdx+1]...) newParts = append(newParts, claude.BetaOAuth) newParts = append(newParts, parts[claudeCodeIdx+1:]...) return strings.Join(newParts, ",") } // 没有claude-code,放在第一位 return claude.BetaOAuth + "," + clientBetaHeader } // 客户端没传,根据模型生成 // haiku 模型不需要 claude-code beta if strings.Contains(strings.ToLower(modelID), "haiku") { return claude.HaikuBetaHeader } return claude.DefaultBetaHeader } func requestNeedsBetaFeatures(body []byte) bool { tools := gjson.GetBytes(body, "tools") if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 { return true } if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") { return true } return false } func defaultAPIKeyBetaHeader(body []byte) string { modelID := gjson.GetBytes(body, "model").String() if strings.Contains(strings.ToLower(modelID), "haiku") { return claude.APIKeyHaikuBetaHeader } return claude.APIKeyBetaHeader } func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) { if req == nil { return } if req.Header.Get("accept") == "" { req.Header.Set("accept", "application/json") } for key, value := range claude.DefaultHeaders { if value == "" { continue } if req.Header.Get(key) == "" { req.Header.Set(key, value) } } if isStream && req.Header.Get("x-stainless-helper-method") == "" { req.Header.Set("x-stainless-helper-method", "stream") } } func mergeAnthropicBeta(required []string, incoming string) string { seen := make(map[string]struct{}, len(required)+8) out := make([]string, 0, len(required)+8) add := func(v string) { v = strings.TrimSpace(v) if v == "" { return } if _, ok := seen[v]; ok { return } seen[v] = struct{}{} out = append(out, v) } for _, r := range required { add(r) } for _, p := range strings.Split(incoming, ",") { add(p) } return strings.Join(out, ",") } func mergeAnthropicBetaDropping(required []string, incoming string, drop map[string]struct{}) string { merged := mergeAnthropicBeta(required, incoming) if merged == "" || len(drop) == 0 { return merged } out := make([]string, 0, 8) for _, p := range strings.Split(merged, ",") { p = strings.TrimSpace(p) if p == "" { continue } if _, ok := drop[p]; ok { continue } out = append(out, p) } return strings.Join(out, ",") } // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. func applyClaudeCodeMimicHeaders(req *http.Request, isStream bool) { if req == nil { return } // Start with the standard defaults (fill missing). applyClaudeOAuthHeaderDefaults(req, isStream) // Then force key headers to match Claude Code fingerprint regardless of what the client sent. for key, value := range claude.DefaultHeaders { if value == "" { continue } req.Header.Set(key, value) } // Real Claude CLI uses Accept: application/json (even for streaming). req.Header.Set("accept", "application/json") if isStream { req.Header.Set("x-stainless-helper-method", "stream") } } func truncateForLog(b []byte, maxBytes int) string { if maxBytes <= 0 { maxBytes = 2048 } if len(b) > maxBytes { b = b[:maxBytes] } s := string(b) // 保持一行,避免污染日志格式 s = strings.ReplaceAll(s, "\n", "\\n") s = strings.ReplaceAll(s, "\r", "\\r") return s } // isThinkingBlockSignatureError 检测是否是thinking block相关错误 // 这类错误可以通过过滤thinking blocks并重试来解决 func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { return false } // Log for debugging log.Printf("[SignatureCheck] Checking error message: %s", msg) // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { log.Printf("[SignatureCheck] Detected signature error") return true } // 检测 thinking block 顺序/类型错误 // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { log.Printf("[SignatureCheck] Detected thinking block type error") return true } // 检测 thinking block 被修改的错误 // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { log.Printf("[SignatureCheck] Detected thinking block modification error") return true } // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { log.Printf("[SignatureCheck] Detected empty content error") return true } return false } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { return false } // 缺少/错误的 beta header:换账号/链路可能成功(尤其是混合调度时)。 // 更精确匹配 beta 相关的兼容性问题,避免误触发切换。 if strings.Contains(msg, "anthropic-beta") || strings.Contains(msg, "beta feature") || strings.Contains(msg, "requires beta") { return true } // thinking/tool streaming 等兼容性约束(常见于中间转换链路) if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") { return true } if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") { return true } return false } // ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 // 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} func ExtractUpstreamErrorMessage(body []byte) string { return extractUpstreamErrorMessage(body) } func extractUpstreamErrorMessage(body []byte) string { // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { inner := strings.TrimSpace(m) // 有些上游会把完整 JSON 作为字符串塞进 message if strings.HasPrefix(inner, "{") { if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" { return innerMsg } } return m } // 兜底:尝试顶层 message return gjson.GetBytes(body, "message").String() } func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 调试日志:打印上游错误响应 log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) // Print a compact upstream request fingerprint when we hit the Claude Code OAuth // credential scope error. This avoids requiring env-var tweaks in a fixed deploy. if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, ) } } } // Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet. upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(body), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "http_error", Message: upstreamMsg, Detail: upstreamDetail, }) // 处理上游错误,标记账号状态 shouldDisable := false if s.rateLimitService != nil { shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } if shouldDisable { return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { log.Printf( "Upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, account.Platform, account.Type, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } // 非 failover 错误也支持错误透传规则匹配。 if status, errType, errMsg, matched := applyErrorPassthroughRule( c, account.Platform, resp.StatusCode, body, http.StatusBadGateway, "upstream_error", "Upstream request failed", ); matched { c.JSON(status, gin.H{ "type": "error", "error": gin.H{ "type": errType, "message": errMsg, }, }) summary := upstreamMsg if summary == "" { summary = errMsg } if summary == "" { return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) } // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) var errType, errMsg string var statusCode int switch resp.StatusCode { case 400: c.Data(http.StatusBadRequest, "application/json", body) summary := upstreamMsg if summary == "" { summary = truncateForLog(body, 512) } if summary == "" { return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary) case 401: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream authentication failed, please contact administrator" case 403: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream access forbidden, please contact administrator" case 429: statusCode = http.StatusTooManyRequests errType = "rate_limit_error" errMsg = "Upstream rate limit exceeded, please retry later" case 529: statusCode = http.StatusServiceUnavailable errType = "overloaded_error" errMsg = "Upstream service overloaded, please retry later" case 500, 502, 503, 504: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream service temporarily unavailable" default: statusCode = http.StatusBadGateway errType = "upstream_error" errMsg = "Upstream request failed" } // 返回自定义错误响应 c.JSON(statusCode, gin.H{ "type": "error", "error": gin.H{ "type": errType, "message": errMsg, }, }) if upstreamMsg == "" { return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) statusCode := resp.StatusCode // OAuth/Setup Token 账号的 403:标记账号异常 if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) } else { // API Key 未配置错误码:不标记账号状态 log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } } func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } // handleRetryExhaustedError 处理重试耗尽后的错误 // OAuth 403:标记账号异常 // API Key 未配置错误码:仅返回错误,不标记账号 func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { // Capture upstream error body before side-effects consume the stream. respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) s.handleRetryExhaustedSideEffects(ctx, resp, account) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, ) } } } upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, AccountID: account.ID, UpstreamStatusCode: resp.StatusCode, UpstreamRequestID: resp.Header.Get("x-request-id"), Kind: "retry_exhausted", Message: upstreamMsg, Detail: upstreamDetail, }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { log.Printf( "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, account.Platform, account.Type, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } if status, errType, errMsg, matched := applyErrorPassthroughRule( c, account.Platform, resp.StatusCode, respBody, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries", ); matched { c.JSON(status, gin.H{ "type": "error", "error": gin.H{ "type": errType, "message": errMsg, }, }) summary := upstreamMsg if summary == "" { summary = errMsg } if summary == "" { return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) } // 返回统一的重试耗尽错误响应 c.JSON(http.StatusBadGateway, gin.H{ "type": "error", "error": gin.H{ "type": "upstream_error", "message": "Upstream request failed after retries", }, }) if upstreamMsg == "" { return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode) } return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg) } // streamingResult 流式响应结果 type streamingResult struct { usage *ClaudeUsage firstTokenMs *int clientDisconnect bool // 客户端是否在流式传输过程中断开 } func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, mimicClaudeCode bool) (*streamingResult, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) if s.cfg != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) } // 设置SSE响应头 c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") // 透传其他响应头 if v := resp.Header.Get("x-request-id"); v != "" { c.Header("x-request-id", v) } w := c.Writer flusher, ok := w.(http.Flusher) if !ok { return nil, errors.New("streaming not supported") } usage := &ClaudeUsage{} var firstTokenMs *int scanner := bufio.NewScanner(resp.Body) // 设置更大的buffer以处理长行 maxLineSize := defaultMaxLineSize if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } scanner.Buffer(make([]byte, 64*1024), maxLineSize) type scanEvent struct { line string err error } // 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理 events := make(chan scanEvent, 16) done := make(chan struct{}) sendEvent := func(ev scanEvent) bool { select { case events <- ev: return true case <-done: return false } } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) go func() { defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) if !sendEvent(scanEvent{line: scanner.Text()}) { return } } if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } }() defer close(done) streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second } // 仅监控上游数据间隔超时,避免下游写入阻塞导致误判 var intervalTicker *time.Ticker if streamInterval > 0 { intervalTicker = time.NewTicker(streamInterval) defer intervalTicker.Stop() } var intervalCh <-chan time.Time if intervalTicker != nil { intervalCh = intervalTicker.C } // 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端) errorEventSent := false sendErrorEvent := func(reason string) { if errorEventSent { return } errorEventSent = true _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) flusher.Flush() } needModelReplace := originalModel != mappedModel clientDisconnected := false // 客户端断开标志,断开后继续读取上游以获取完整usage pendingEventLines := make([]string, 0, 4) processSSEEvent := func(lines []string) ([]string, string, error) { if len(lines) == 0 { return nil, "", nil } eventName := "" dataLine := "" for _, line := range lines { trimmed := strings.TrimSpace(line) if strings.HasPrefix(trimmed, "event:") { eventName = strings.TrimSpace(strings.TrimPrefix(trimmed, "event:")) continue } if dataLine == "" && sseDataRe.MatchString(trimmed) { dataLine = sseDataRe.ReplaceAllString(trimmed, "") } } if eventName == "error" { return nil, dataLine, errors.New("have error in stream") } if dataLine == "" { return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil } if dataLine == "[DONE]" { block := "" if eventName != "" { block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" return []string{block}, dataLine, nil } var event map[string]any if err := json.Unmarshal([]byte(dataLine), &event); err != nil { // JSON 解析失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" return []string{block}, dataLine, nil } eventType, _ := event["type"].(string) if eventName == "" { eventName = eventType } // 兼容 Kimi cached_tokens → cache_read_input_tokens if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { reconcileCachedTokens(u) } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { reconcileCachedTokens(u) } } if needModelReplace { if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { msg["model"] = originalModel } } } newData, err := json.Marshal(event) if err != nil { // 序列化失败,直接透传原始数据 block := "" if eventName != "" { block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" return []string{block}, dataLine, nil } block := "" if eventName != "" { block = "event: " + eventName + "\n" } block += "data: " + string(newData) + "\n\n" return []string{block}, string(newData), nil } for { select { case ev, ok := <-events: if !ok { // 上游完成,返回结果 return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil } if ev.err != nil { // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { log.Printf("Context canceled during streaming, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } sendErrorEvent("stream_read_error") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) } line := ev.line trimmed := strings.TrimSpace(line) if trimmed == "" { if len(pendingEventLines) == 0 { continue } outputBlocks, data, err := processSSEEvent(pendingEventLines) pendingEventLines = pendingEventLines[:0] if err != nil { if clientDisconnected { return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } return nil, err } for _, block := range outputBlocks { if !clientDisconnected { if _, werr := fmt.Fprint(w, block); werr != nil { clientDisconnected = true log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") break } flusher.Flush() } if data != "" { if firstTokenMs == nil && data != "[DONE]" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } s.parseSSEUsage(data, usage) } } continue } pendingEventLines = append(pendingEventLines, line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } if clientDisconnected { // 客户端已断开,上游也超时了,返回已收集的 usage log.Printf("Upstream timeout after client disconnect, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } sendErrorEvent("stream_timeout") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { // 解析message_start获取input tokens(标准Claude API格式) var msgStart struct { Type string `json:"type"` Message struct { Usage ClaudeUsage `json:"usage"` } `json:"message"` } if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { usage.InputTokens = msgStart.Message.Usage.InputTokens usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens } // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) var msgDelta struct { Type string `json:"type"` Usage struct { InputTokens int `json:"input_tokens"` OutputTokens int `json:"output_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` } `json:"usage"` } if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { // message_delta 仅覆盖存在且非0的字段 // 避免覆盖 message_start 中已有的值(如 input_tokens) // Claude API 的 message_delta 通常只包含 output_tokens if msgDelta.Usage.InputTokens > 0 { usage.InputTokens = msgDelta.Usage.InputTokens } if msgDelta.Usage.OutputTokens > 0 { usage.OutputTokens = msgDelta.Usage.OutputTokens } if msgDelta.Usage.CacheCreationInputTokens > 0 { usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens } if msgDelta.Usage.CacheReadInputTokens > 0 { usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens } } } func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } // 解析usage var response struct { Usage ClaudeUsage `json:"usage"` } if err := json.Unmarshal(body, &response); err != nil { return nil, fmt.Errorf("parse response: %w", err) } // 兼容 Kimi cached_tokens → cache_read_input_tokens if response.Usage.CacheReadInputTokens == 0 { cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int() if cachedTokens > 0 { response.Usage.CacheReadInputTokens = int(cachedTokens) if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil { body = newBody } } } // 如果有模型映射,替换响应中的model字段 if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { contentType = upstreamType } } // 写入响应 c.Data(resp.StatusCode, contentType, body) return &response.Usage, nil } // replaceModelInResponseBody 替换响应体中的model字段 func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { var resp map[string]any if err := json.Unmarshal(body, &resp); err != nil { return body } model, ok := resp["model"].(string) if !ok || model != fromModel { return body } resp["model"] = toModel newBody, err := json.Marshal(resp) if err != nil { return body } return newBody } // RecordUsageInput 记录使用量的输入参数 type RecordUsageInput struct { Result *ForwardResult APIKey *APIKey User *User Account *Account Subscription *UserSubscription // 可选:订阅信息 UserAgent string // 请求的 User-Agent IPAddress string // 请求的客户端 IP 地址 ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService APIKeyQuotaUpdater // 可选:用于更新API Key配额 } // APIKeyQuotaUpdater defines the interface for updating API Key quota type APIKeyQuotaUpdater interface { UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error } // RecordUsage 记录使用量并扣费(或更新订阅用量) func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error { result := input.Result apiKey := input.APIKey user := input.User account := input.Account subscription := input.Subscription // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier // 检查用户专属倍率 if s.userGroupRateRepo != nil { if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { multiplier = *userRate } } } var cost *CostBreakdown // 根据请求类型选择计费方式 if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { groupConfig = &ImagePriceConfig{ Price1K: apiKey.Group.ImagePrice1K, Price2K: apiKey.Group.ImagePrice2K, Price4K: apiKey.Group.ImagePrice4K, } } cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费 tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, } var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { log.Printf("Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() billingType := BillingTypeBalance if isSubscriptionBilling { billingType = BillingTypeSubscription } // 创建使用日志 durationMs := int(result.Duration.Milliseconds()) var imageSize *string if result.ImageSize != "" { imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, CacheReadCost: cost.CacheReadCost, TotalCost: cost.TotalCost, ActualCost: cost.ActualCost, RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, CreatedAt: time.Now(), } // 添加 UserAgent if input.UserAgent != "" { usageLog.UserAgent = &input.UserAgent } // 添加 IPAddress if input.IPAddress != "" { usageLog.IPAddress = &input.IPAddress } // 添加分组和订阅关联 if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID } if subscription != nil { usageLog.SubscriptionID = &subscription.ID } inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { log.Printf("Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } shouldBill := inserted || err != nil // 根据计费类型执行扣费 if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { log.Printf("Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { log.Printf("Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } } // 更新 API Key 配额(如果设置了配额限制) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { log.Printf("Update API key quota failed: %v", err) } } // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } // RecordUsageLongContextInput 记录使用量的输入参数(支持长上下文双倍计费) type RecordUsageLongContextInput struct { Result *ForwardResult APIKey *APIKey User *User Account *Account Subscription *UserSubscription // 可选:订阅信息 UserAgent string // 请求的 User-Agent IPAddress string // 请求的客户端 IP 地址 LongContextThreshold int // 长上下文阈值(如 200000) LongContextMultiplier float64 // 超出阈值部分的倍率(如 2.0) ForceCacheBilling bool // 强制缓存计费:将 input_tokens 转为 cache_read 计费(用于粘性会话切换) APIKeyService *APIKeyService // API Key 配额服务(可选) } // RecordUsageWithLongContext 记录使用量并扣费,支持长上下文双倍计费(用于 Gemini) func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *RecordUsageLongContextInput) error { result := input.Result apiKey := input.APIKey user := input.User account := input.Account subscription := input.Subscription // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier // 检查用户专属倍率 if s.userGroupRateRepo != nil { if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { multiplier = *userRate } } } var cost *CostBreakdown // 根据请求类型选择计费方式 if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { groupConfig = &ImagePriceConfig{ Price1K: apiKey.Group.ImagePrice1K, Price2K: apiKey.Group.ImagePrice2K, Price4K: apiKey.Group.ImagePrice4K, } } cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier) } else { // Token 计费(使用长上下文计费方法) tokens := UsageTokens{ InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, } var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { log.Printf("Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } // 判断计费方式:订阅模式 vs 余额模式 isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType() billingType := BillingTypeBalance if isSubscriptionBilling { billingType = BillingTypeSubscription } // 创建使用日志 durationMs := int(result.Duration.Milliseconds()) var imageSize *string if result.ImageSize != "" { imageSize = &result.ImageSize } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, RequestID: result.RequestID, Model: result.Model, InputTokens: result.Usage.InputTokens, OutputTokens: result.Usage.OutputTokens, CacheCreationTokens: result.Usage.CacheCreationInputTokens, CacheReadTokens: result.Usage.CacheReadInputTokens, InputCost: cost.InputCost, OutputCost: cost.OutputCost, CacheCreationCost: cost.CacheCreationCost, CacheReadCost: cost.CacheReadCost, TotalCost: cost.TotalCost, ActualCost: cost.ActualCost, RateMultiplier: multiplier, AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, CreatedAt: time.Now(), } // 添加 UserAgent if input.UserAgent != "" { usageLog.UserAgent = &input.UserAgent } // 添加 IPAddress if input.IPAddress != "" { usageLog.IPAddress = &input.IPAddress } // 添加分组和订阅关联 if apiKey.GroupID != nil { usageLog.GroupID = apiKey.GroupID } if subscription != nil { usageLog.SubscriptionID = &subscription.ID } inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { log.Printf("Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } shouldBill := inserted || err != nil // 根据计费类型执行扣费 if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { log.Printf("Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { log.Printf("Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) // API Key 独立配额扣费 if input.APIKeyService != nil && apiKey.Quota > 0 { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { log.Printf("Add API key quota used failed: %v", err) } } } } // Schedule batch update for account last_used_at s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } // ForwardCountTokens 转发 count_tokens 请求到上游 API // 特点:不记录使用量、仅支持非流式响应 func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error { if parsed == nil { s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") return fmt.Errorf("parse request: empty request") } body := parsed.Body reqModel := parsed.Model isClaudeCode := isClaudeCodeRequest(ctx, c, parsed) shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode if shouldMimicClaudeCode { normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true} body, reqModel = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts) } // Antigravity 账户不支持 count_tokens 转发,直接返回空值 if account.Platform == PlatformAntigravity { c.JSON(http.StatusOK, gin.H{"input_tokens": 0}) return nil } // 应用模型映射: // - APIKey 账号:使用账号级别的显式映射(如果配置),否则透传原始模型名 // - OAuth/SetupToken 账号:使用 Anthropic 标准映射(短ID → 长ID) if reqModel != "" { mappedModel := reqModel mappingSource := "" if account.Type == AccountTypeAPIKey { mappedModel = account.GetMappedModel(reqModel) if mappedModel != reqModel { mappingSource = "account" } } if mappingSource == "" && account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { normalized := claude.NormalizeModelID(reqModel) if normalized != reqModel { mappedModel = normalized mappingSource = "prefix" } } if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) } } // 获取凭证 token, tokenType, err := s.GetAccessToken(ctx, account) if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") return err } // 构建上游请求 upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode) if err != nil { s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") return err } // 获取代理URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } // 发送请求 resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") return fmt.Errorf("upstream request failed: %w", err) } // 读取响应体 respBody, err := io.ReadAll(resp.Body) _ = resp.Body.Close() if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) if buildErr == nil { retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp respBody, err = io.ReadAll(resp.Body) _ = resp.Body.Close() if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } } } } // 处理错误响应 if resp.StatusCode >= 400 { // 标记账号状态(429/529等) s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := "" if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes if maxBytes <= 0 { maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) // 记录上游错误摘要便于排障(不回显请求内容) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { log.Printf( "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, account.Platform, account.Type, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } // 返回简化的错误响应 errMsg := "Upstream request failed" switch resp.StatusCode { case 429: errMsg = "Rate limit exceeded" case 529: errMsg = "Service overloaded" } s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) if upstreamMsg == "" { return fmt.Errorf("upstream error: %d", resp.StatusCode) } return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } // 透传成功响应 c.Data(resp.StatusCode, "application/json", respBody) return nil } // buildCountTokensRequest 构建 count_tokens 上游请求 func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL targetURL := claudeAPICountTokensURL if account.Type == AccountTypeAPIKey { baseURL := account.GetBaseURL() if baseURL != "" { validatedURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err } targetURL = validatedURL + "/v1/messages/count_tokens" } } clientHeaders := http.Header{} if c != nil && c.Request != nil { clientHeaders = c.Request.Header } // OAuth 账号:应用统一指纹和重写 userID // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 if account.IsOAuth() && s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err == nil { accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } } } req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) if err != nil { return nil, err } // 设置认证头 if tokenType == "oauth" { req.Header.Set("authorization", "Bearer "+token) } else { req.Header.Set("x-api-key", token) } // 白名单透传 headers for key, values := range clientHeaders { lowerKey := strings.ToLower(key) if allowedHeaders[lowerKey] { for _, v := range values { req.Header.Add(key, v) } } } // OAuth 账号:应用指纹到请求头 if account.IsOAuth() && s.identityService != nil { fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if fp != nil { s.identityService.ApplyFingerprint(req, fp) } } // 确保必要的 headers 存在 if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") } if req.Header.Get("anthropic-version") == "" { req.Header.Set("anthropic-version", "2023-06-01") } if tokenType == "oauth" { applyClaudeOAuthHeaderDefaults(req, false) } // OAuth 账号:处理 anthropic-beta header if tokenType == "oauth" { if mimicClaudeCode { applyClaudeCodeMimicHeaders(req, false) incomingBeta := req.Header.Get("anthropic-beta") requiredBetas := []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking, claude.BetaTokenCounting} req.Header.Set("anthropic-beta", mergeAnthropicBeta(requiredBetas, incomingBeta)) } else { clientBetaHeader := req.Header.Get("anthropic-beta") if clientBetaHeader == "" { req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader) } else { beta := s.getBetaHeader(modelID, clientBetaHeader) if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } req.Header.Set("anthropic-beta", beta) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:与 messages 同步的按需 beta 注入(默认关闭) if requestNeedsBetaFeatures(body) { if beta := defaultAPIKeyBetaHeader(body); beta != "" { req.Header.Set("anthropic-beta", beta) } } } if c != nil && tokenType == "oauth" { c.Set(claudeMimicDebugInfoKey, buildClaudeMimicDebugLine(req, body, account, tokenType, mimicClaudeCode)) } if s.debugClaudeMimicEnabled() { logClaudeMimicDebug(req, body, account, tokenType, mimicClaudeCode) } return req, nil } // countTokensError 返回 count_tokens 错误响应 func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ "type": "error", "error": gin.H{ "type": errType, "message": message, }, }) } func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { return "", fmt.Errorf("invalid base_url: %w", err) } return normalized, nil } normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{ AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts, RequireAllowlist: true, AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { return "", fmt.Errorf("invalid base_url: %w", err) } return normalized, nil } // checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { scope, ok := ResolveAntigravityQuotaScope(requestedModel) if !ok { return nil // 无法解析 scope,跳过检查 } group, err := s.resolveGroupByID(ctx, groupID) if err != nil { return nil // 查询失败时放行 } if group == nil { return nil // 分组不存在时放行 } if !IsScopeSupported(group.SupportedModelScopes, scope) { return ErrModelScopeNotSupported } return nil } // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { var accounts []Account var err error if groupID != nil { accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) } else { accounts, err = s.accountRepo.ListSchedulable(ctx) } if err != nil || len(accounts) == 0 { return nil } // Filter by platform if specified if platform != "" { filtered := make([]Account, 0) for _, acc := range accounts { if acc.Platform == platform { filtered = append(filtered, acc) } } accounts = filtered } // Collect unique models from all accounts modelSet := make(map[string]struct{}) hasAnyMapping := false for _, acc := range accounts { mapping := acc.GetModelMapping() if len(mapping) > 0 { hasAnyMapping = true for model := range mapping { modelSet[model] = struct{}{} } } } // If no account has model_mapping, return nil (use default) if !hasAnyMapping { return nil } // Convert to slice models := make([]string, 0, len(modelSet)) for model := range modelSet { models = append(models, model) } return models } // reconcileCachedTokens 兼容 Kimi 等上游: // 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens func reconcileCachedTokens(usage map[string]any) bool { if usage == nil { return false } cacheRead, _ := usage["cache_read_input_tokens"].(float64) if cacheRead > 0 { return false // 已有标准字段,无需处理 } cached, _ := usage["cached_tokens"].(float64) if cached <= 0 { return false } usage["cache_read_input_tokens"] = cached return true }