Files
sub2api/backend/internal/service/gateway_service.go
2026-02-02 22:20:08 +08:00

3993 lines
136 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"log/slog"
mathrand "math/rand"
"net/http"
"os"
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
)
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
defaultMaxLineSize = 40 * 1024 * 1024
claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude."
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
func (s *GatewayService) debugModelRoutingEnabled() bool {
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
return v == "1" || v == "true" || v == "yes" || v == "on"
}
func shortSessionHash(sessionHash string) string {
if sessionHash == "" {
return ""
}
if len(sessionHash) <= 8 {
return sessionHash
}
return sessionHash[:8]
}
func normalizeClaudeModelForAnthropic(requestedModel string) string {
for _, prefix := range anthropicPrefixMappings {
if strings.HasPrefix(requestedModel, prefix) {
return prefix
}
}
return requestedModel
}
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
// claudeCodePromptPrefixes 用于检测 Claude Code 系统提示词的前缀列表
// 支持多种变体标准版、Agent SDK 版、Explore Agent 版、Compact 版等
// 注意:前缀之间不应存在包含关系,否则会导致冗余匹配
claudeCodePromptPrefixes = []string{
"You are Claude Code, Anthropic's official CLI for Claude", // 标准版 & Agent SDK 版(含 running within...
"You are a Claude agent, built on Anthropic's Claude Agent SDK", // Agent SDK 变体
"You are a file search specialist for Claude Code", // Explore Agent 版
"You are a helpful AI assistant tasked with summarizing conversations", // Compact 版
}
anthropicPrefixMappings = []string{
"claude-opus-4-5",
"claude-haiku-4-5",
"claude-sonnet-4-5",
}
)
// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问
var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients")
// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内
var ErrModelScopeNotSupported = errors.New("model scope not supported by this group")
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
"x-stainless-retry-count": true,
"x-stainless-timeout": true,
"x-stainless-lang": true,
"x-stainless-package-version": true,
"x-stainless-os": true,
"x-stainless-arch": true,
"x-stainless-runtime": true,
"x-stainless-runtime-version": true,
"x-stainless-helper-method": true,
"anthropic-dangerous-direct-browser-access": true,
"anthropic-version": true,
"x-app": true,
"anthropic-beta": true,
"accept-language": true,
"sec-fetch-mode": true,
"user-agent": true,
"content-type": true,
}
// GatewayCache 定义网关服务的缓存操作接口。
// 提供粘性会话Sticky Session的存储、查询、刷新和删除功能。
//
// GatewayCache defines cache operations for gateway service.
// Provides sticky session storage, retrieval, refresh and deletion capabilities.
type GatewayCache interface {
// GetSessionAccountID 获取粘性会话绑定的账号 ID
// Get the account ID bound to a sticky session
GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error)
// SetSessionAccountID 设置粘性会话与账号的绑定关系
// Set the binding between sticky session and account
SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error
// RefreshSessionTTL 刷新粘性会话的过期时间
// Refresh the expiration time of a sticky session
RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error
// DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理
// Delete sticky session binding, used to proactively clean up when account becomes unavailable
DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error
}
// derefGroupID safely dereferences *int64 to int64, returning 0 if nil
func derefGroupID(groupID *int64) int64 {
if groupID == nil {
return 0
}
return *groupID
}
// shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。
// 当账号状态为错误、禁用、不可调度,或处于临时不可调度期间时,返回 true。
// 这确保后续请求不会继续使用不可用的账号。
//
// shouldClearStickySession checks if an account is in an unschedulable state
// and the sticky session binding should be cleared.
// Returns true when account status is error/disabled, schedulable is false,
// or within temporary unschedulable period.
// This ensures subsequent requests won't continue using unavailable accounts.
func shouldClearStickySession(account *Account) bool {
if account == nil {
return false
}
if account.Status == StatusError || account.Status == StatusDisabled || !account.Schedulable {
return true
}
if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) {
return true
}
return false
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
ClientDisconnect bool // 客户端是否在流式传输过程中断开
// 图片生成计费字段(仅 gemini-3-pro-image 使用)
ImageCount int // 生成的图片数量
ImageSize string // 图片尺寸 "1K", "2K", "4K"
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(
accountRepo AccountRepository,
groupRepo GroupRepository,
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
}
}
// GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
}
}
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
}
}
}
return ""
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
}
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
// Returns 0 if no binding exists or on error.
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
if sessionHash == "" || s.cache == nil {
return 0, nil
}
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err != nil {
return 0, err
}
return accountID, nil
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var builder strings.Builder
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok {
_, _ = builder.WriteString(text)
}
}
}
}
}
}
systemText := builder.String()
// 检查 messages 中的 cacheable 内容
for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
return s.extractTextFromContent(msgMap["content"])
}
}
}
}
}
}
}
return systemText
}
func (s *GatewayService) extractTextFromSystem(system any) string {
switch v := system.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) extractTextFromContent(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if partMap["type"] == "text" {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
}
// replaceModelInBody 替换请求体中的model字段
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
req["model"] = newModel
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, err
}
groupID = resolvedGroupID
ctx = s.withGroupContext(ctx, group)
platform = group.Platform
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
// 调试日志:记录调度入口参数
excludedIDsList := make([]int64, 0, len(excludedIDs))
for id := range excludedIDs {
excludedIDsList = append(excludedIDsList, id)
}
slog.Debug("account_scheduling_starting",
"group_id", derefGroupID(groupID),
"model", requestedModel,
"session", shortSessionHash(sessionHash),
"excluded_ids", excludedIDsList)
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
stickyAccountID = accountID
}
}
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
if err != nil {
return nil, err
}
ctx = s.withGroupContext(ctx, group)
if s.debugModelRoutingEnabled() && requestedModel != "" {
groupPlatform := ""
if group != nil {
groupPlatform = group.Platform
}
log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
// 复制排除列表,用于会话限制拒绝时的重试
localExcluded := make(map[int64]struct{})
for k, v := range excludedIDs {
localExcluded[k] = v
}
for {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位
localExcluded[account.ID] = struct{}{} // 排除此账号
continue // 重新选择
}
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
// 对于等待计划的情况,也需要先检查会话限制
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
localExcluded[account.ID] = struct{}{}
continue
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
}
// Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查)
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// 提前构建 accountByID供 Layer 1 和 Layer 1.5 使用)
accountByID := make(map[int64]*Account, len(accounts))
for i := range accounts {
accountByID[accounts[i].ID] = &accounts[i]
}
// 获取模型路由配置(仅 anthropic 平台)
var routingAccountIDs []int64
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
keys := make([]string, 0, len(group.ModelRouting))
for k := range group.ModelRouting {
keys = append(keys, k)
}
sort.Strings(keys)
const maxKeys = 20
if len(keys) > maxKeys {
keys = keys[:maxKeys]
}
log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
}
}
}
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
// 1. 过滤出路由列表中可调度的账号
var routingCandidates []*Account
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
for _, routingAccountID := range routingAccountIDs {
if isExcluded(routingAccountID) {
filteredExcluded++
continue
}
account, ok := accountByID[routingAccountID]
if !ok || !account.IsSchedulable() {
if !ok {
filteredMissing++
} else {
filteredUnsched++
}
continue
}
if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
filteredPlatform++
continue
}
if !account.IsSchedulableForModel(requestedModel) {
filteredModelScope++
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
filteredModelMapping++
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
filteredWindowCost++
continue
}
routingCandidates = append(routingCandidates, account)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
}
if len(routingCandidates) > 0 {
// 1.5. 在路由账号范围内检查粘性会话
if sessionHash != "" && s.cache != nil {
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
// 粘性账号在路由列表中,优先使用
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
if stickyAccount.IsSchedulable() &&
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
stickyAccount.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
result.ReleaseFunc() // 释放槽位
// 继续到负载感知选择
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
}
return &AccountSelectionResult{
Account: stickyAccount,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
// 会话限制已满,继续到负载感知选择
} else {
return &AccountSelectionResult{
Account: stickyAccount,
WaitPlan: &AccountWaitPlan{
AccountID: stickyAccountID,
MaxConcurrency: stickyAccount.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
}
} else {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
}
}
// 2. 批量获取负载信息
routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
for _, acc := range routingCandidates {
routingLoads = append(routingLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
// 3. 按负载感知排序
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var routingAvailable []accountWithLoad
for _, acc := range routingCandidates {
loadInfo := routingLoadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
}
}
if len(routingAvailable) > 0 {
// 排序:优先级 > 负载率 > 最后使用时间
sort.SliceStable(routingAvailable, func(i, j int) bool {
a, b := routingAvailable[i], routingAvailable[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
// 4. 尝试获取槽位
for _, item := range routingAvailable {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
// 遍历找到第一个满足会话限制的账号
for _, item := range routingAvailable {
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
continue // 会话限制已满,尝试下一个
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
}
return &AccountSelectionResult{
Account: item.account,
WaitPlan: &AccountWaitPlan{
AccountID: item.account.ID,
MaxConcurrency: item.account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
}
// 路由列表中的账号都不可用(负载率 >= 100继续到 Layer 2 回退
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
}
}
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, ok := accountByID[accountID]
if ok {
// 检查账户是否需要清理粘性会话绑定
// Check if the account needs sticky session cleanup
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
// Session count limit check
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
} else {
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
// 会话数量限制检查(等待计划也需要占用会话配额)
// Session count limit check (wait plan also requires session quota)
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
// 会话限制已满,继续到 Layer 2
// Session limit full, continue to Layer 2
} else {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
// Scheduler snapshots can be temporarily stale (bucket rebuild is throttled);
// re-check schedulability here so recently rate-limited/overloaded accounts
// are not selected again before the bucket is rebuilt.
if !acc.IsSchedulable() {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
// 窗口费用检查(非粘性会话路径)
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
for _, acc := range candidates {
// 会话数量限制检查(等待计划也需要占用会话配额)
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
continue // 会话限制已满,尝试下一个账号
}
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
continue
}
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
if !IsGroupContextValid(group) {
return ctx
}
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
return ctx
}
return context.WithValue(ctx, ctxkey.Group, group)
}
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
return group
}
return nil
}
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
if group := s.groupFromContext(ctx, groupID); group != nil {
return group, nil
}
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
return group, nil
}
func (s *GatewayService) ResolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
return s.resolveGroupByID(ctx, groupID)
}
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
return nil
}
group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil || group == nil {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
}
return nil
}
// Preserve existing behavior: model routing only applies to anthropic groups.
if group.Platform != PlatformAnthropic {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
}
return nil
}
ids := group.GetRoutingAccountIDs(requestedModel)
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
}
return ids
}
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, nil, nil
}
currentID := *groupID
visited := map[int64]struct{}{}
for {
if _, seen := visited[currentID]; seen {
return nil, nil, fmt.Errorf("fallback group cycle detected")
}
visited[currentID] = struct{}{}
group, err := s.resolveGroupByID(ctx, currentID)
if err != nil {
return nil, nil, err
}
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
return group, &currentID, nil
}
if group.FallbackGroupID == nil {
return nil, nil, ErrClaudeCodeOnly
}
currentID = *group.FallbackGroupID
}
}
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
// - 有降级分组:返回降级分组的 ID
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, groupID, nil
}
// 强制平台模式不检查 Claude Code 限制
if forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform && forcePlatform != "" {
return nil, groupID, nil
}
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
if err != nil {
return nil, nil, err
}
return group, resolvedID, nil
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if group != nil {
return group.Platform, false, nil
}
if groupID != nil {
group, err := s.resolveGroupByID(ctx, *groupID)
if err != nil {
return "", false, err
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
if s.schedulerSnapshot != nil {
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err == nil {
slog.Debug("account_scheduling_list_snapshot",
"group_id", derefGroupID(groupID),
"platform", platform,
"use_mixed", useMixed,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
}
return accounts, useMixed, err
}
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
slog.Debug("account_scheduling_list_mixed",
"group_id", derefGroupID(groupID),
"platform", platform,
"raw_count", len(accounts),
"filtered_count", len(filtered))
for _, acc := range filtered {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
// 分组内无账号则返回空列表,由上层处理错误,不再回退到全平台查询
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
slog.Debug("account_scheduling_list_failed",
"group_id", derefGroupID(groupID),
"platform", platform,
"error", err)
return nil, useMixed, err
}
slog.Debug("account_scheduling_list_single",
"group_id", derefGroupID(groupID),
"platform", platform,
"count", len(accounts))
for _, acc := range accounts {
slog.Debug("account_scheduling_account_detail",
"account_id", acc.ID,
"name", acc.Name,
"platform", acc.Platform,
"type", acc.Type,
"status", acc.Status,
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
// isAccountInGroup checks if the account belongs to the specified group.
// Returns true if groupID is nil (no group restriction) or account belongs to the group.
func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool {
if groupID == nil {
return true // 无分组限制
}
if account == nil {
return false
}
for _, ag := range account.AccountGroups {
if ag.GroupID == *groupID {
return true
}
}
return false
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
// 仅适用于 Anthropic OAuth/SetupToken 账号
// 返回 true 表示可调度false 表示不可调度
func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
limit := account.GetWindowCostLimit()
if limit <= 0 {
return true // 未启用窗口费用限制
}
// 尝试从缓存获取窗口费用
var currentCost float64
if s.sessionLimitCache != nil {
if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
currentCost = cost
goto checkSchedulability
}
}
// 缓存未命中,从数据库查询
{
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
startTime := account.GetCurrentWindowStartTime()
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
// 失败开放:查询失败时允许调度
return true
}
// 使用标准费用(不含账号倍率)
currentCost = stats.StandardCost
// 设置缓存(忽略错误)
if s.sessionLimitCache != nil {
_ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
}
}
checkSchedulability:
schedulability := account.CheckWindowCostSchedulability(currentCost)
switch schedulability {
case WindowCostSchedulable:
return true
case WindowCostStickyOnly:
return isSticky
case WindowCostNotSchedulable:
return false
}
return true
}
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash
// 返回 true 表示允许在限制内或会话已存在false 表示拒绝(超出限制且是新会话)
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
// 只检查 Anthropic OAuth/SetupToken 账号
if !account.IsAnthropicOAuthOrSetupToken() {
return true
}
maxSessions := account.GetMaxSessions()
if maxSessions <= 0 || sessionID == "" {
return true // 未启用会话限制或无会话ID
}
if s.sessionLimitCache == nil {
return true // 缓存不可用时允许通过
}
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
if err != nil {
// 失败开放:缓存错误时允许通过
return true
}
return allowed
}
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
}
return s.accountRepo.GetByID(ctx, accountID)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// sortCandidatesForFallback 根据配置选择排序策略
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
if mode == "random" {
// 先按优先级排序,然后在同优先级内随机打乱
sortAccountsByPriorityOnly(accounts, preferOAuth)
shuffleWithinPriority(accounts)
} else {
// 默认按最后使用时间排序
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
}
}
// sortAccountsByPriorityOnly 仅按优先级排序
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
})
}
// shuffleWithinPriority 在同优先级内随机打乱顺序
func shuffleWithinPriority(accounts []*Account) {
if len(accounts) <= 1 {
return
}
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
start := 0
for start < len(accounts) {
priority := accounts[start].Priority
end := start + 1
for end < len(accounts) && accounts[end].Priority == priority {
end++
}
// 对 [start, end) 范围内的账户随机打乱
if end-start > 1 {
r.Shuffle(end-start, func(i, j int) {
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
})
}
start = end
}
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
// 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内
if platform == PlatformAntigravity && groupID != nil && requestedModel != "" {
if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil {
return nil, err
}
}
preferOAuth := platform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
var accounts []Account
accountsLoaded := false
// ============ Model Routing (legacy path): apply before sticky session ============
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
// so switching model can switch upstream account within the same sticky session.
if len(routingAccountIDs) > 0 {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
}
}
// 2) Select an account from the routed candidates.
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform == "" {
hasForcePlatform = false
}
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
accountsLoaded = true
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
routingSet[id] = struct{}{}
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, ok := routingSet[acc.ID]; !ok {
continue
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected != nil {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
}
return selected, nil
}
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
}
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
}
// 2. 获取可调度账号列表(单平台)
if !accountsLoaded {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform == "" {
hasForcePlatform = false
}
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
preferOAuth := nativePlatform == PlatformGemini
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
var accounts []Account
accountsLoaded := false
// ============ Model Routing (legacy path): apply before sticky session ============
if len(routingAccountIDs) > 0 {
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
}
// 1) Sticky session only applies if the bound account is within the routing set.
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
return account, nil
}
}
}
}
}
}
// 2) Select an account from the routed candidates.
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
accountsLoaded = true
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
routingSet[id] = struct{}{}
}
}
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, ok := routingSet[acc.ID]; !ok {
continue
}
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected != nil {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
if s.debugModelRoutingEnabled() {
log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
}
return selected, nil
}
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
}
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.getSchedulableAccount(ctx, accountID)
// 检查账号分组归属和有效性原生平台直接匹配antigravity 需要启用混合调度
if err == nil {
clearSticky := shouldClearStickySession(account)
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
if !clearSticky && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
}
}
// 2. 获取可调度账号列表
if !accountsLoaded {
var err error
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
// avoid selecting accounts that were recently rate-limited/overloaded.
if !acc.IsSchedulable() {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if !acc.IsSchedulableForModel(requestedModel) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
if account.Platform == PlatformAnthropic {
requestedModel = normalizeClaudeModelForAnthropic(requestedModel)
}
// Gemini API Key 账户直接透传,由上游判断模型是否支持
if account.Platform == PlatformGemini && account.Type == AccountTypeAPIKey {
return true
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
return apiKey, "apikey", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
// 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
if err != nil {
return "", "", err
}
return accessToken, "oauth", nil
}
// 其他情况Gemini 有自己的 TokenProvidersetup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
// 重试相关常量
const (
// 最大尝试次数(包含首次请求)。过多重试会导致请求堆积与资源耗尽。
maxRetryAttempts = 5
// 指数退避:第 N 次失败后的等待 = retryBaseDelay * 2^(N-1),并且上限为 retryMaxDelay。
retryBaseDelay = 300 * time.Millisecond
retryMaxDelay = 3 * time.Second
// 最大重试耗时(包含请求本身耗时 + 退避等待时间)。
// 用于防止极端情况下 goroutine 长时间堆积导致资源耗尽。
maxRetryElapsed = 10 * time.Second
)
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
}
// API Key 账号:未配置的错误码重试
return !account.ShouldHandleErrorCode(statusCode)
}
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
func retryBackoffDelay(attempt int) time.Duration {
// attempt 从 1 开始,表示第 attempt 次请求刚失败,需要等待后进行第 attempt+1 次请求。
if attempt <= 0 {
return retryBaseDelay
}
delay := retryBaseDelay * time.Duration(1<<(attempt-1))
if delay > retryMaxDelay {
return retryMaxDelay
}
return delay
}
func sleepWithContext(ctx context.Context, d time.Duration) error {
if d <= 0 {
return nil
}
timer := time.NewTimer(d)
defer func() {
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-timer.C:
return nil
}
}
// isClaudeCodeClient 判断请求是否来自 Claude Code 客户端
// 简化判断User-Agent 匹配 + metadata.user_id 存在
func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
if metadataUserID == "" {
return false
}
return claudeCliUserAgentRe.MatchString(userAgent)
}
// systemIncludesClaudeCodePrompt 检查 system 中是否已包含 Claude Code 提示词
// 使用前缀匹配支持多种变体标准版、Agent SDK 版等)
func systemIncludesClaudeCodePrompt(system any) bool {
switch v := system.(type) {
case string:
return hasClaudeCodePrefix(v)
case []any:
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && hasClaudeCodePrefix(text) {
return true
}
}
}
}
return false
}
// hasClaudeCodePrefix 检查文本是否以 Claude Code 提示词的特征前缀开头
func hasClaudeCodePrefix(text string) bool {
for _, prefix := range claudeCodePromptPrefixes {
if strings.HasPrefix(text, prefix) {
return true
}
}
return false
}
// injectClaudeCodePrompt 在 system 开头注入 Claude Code 提示词
// 处理 null、字符串、数组三种格式
func injectClaudeCodePrompt(body []byte, system any) []byte {
claudeCodeBlock := map[string]any{
"type": "text",
"text": claudeCodeSystemPrompt,
"cache_control": map[string]string{"type": "ephemeral"},
}
var newSystem []any
switch v := system.(type) {
case nil:
newSystem = []any{claudeCodeBlock}
case string:
if v == "" || v == claudeCodeSystemPrompt {
newSystem = []any{claudeCodeBlock}
} else {
newSystem = []any{claudeCodeBlock, map[string]any{"type": "text", "text": v}}
}
case []any:
newSystem = make([]any, 0, len(v)+1)
newSystem = append(newSystem, claudeCodeBlock)
for _, item := range v {
if m, ok := item.(map[string]any); ok {
if text, ok := m["text"].(string); ok && text == claudeCodeSystemPrompt {
continue
}
}
newSystem = append(newSystem, item)
}
default:
newSystem = []any{claudeCodeBlock}
}
result, err := sjson.SetBytes(body, "system", newSystem)
if err != nil {
log.Printf("Warning: failed to inject Claude Code prompt: %v", err)
return body
}
return result
}
// enforceCacheControlLimit 强制执行 cache_control 块数量限制(最多 4 个)
// 超限时优先从 messages 中移除 cache_control保护 system 中的缓存控制
func enforceCacheControlLimit(body []byte) []byte {
var data map[string]any
if err := json.Unmarshal(body, &data); err != nil {
return body
}
// 清理 thinking 块中的非法 cache_controlthinking 块不支持该字段)
removeCacheControlFromThinkingBlocks(data)
// 计算当前 cache_control 块数量
count := countCacheControlBlocks(data)
if count <= maxCacheControlBlocks {
return body
}
// 超限:优先从 messages 中移除,再从 system 中移除
for count > maxCacheControlBlocks {
if removeCacheControlFromMessages(data) {
count--
continue
}
if removeCacheControlFromSystem(data) {
count--
continue
}
break
}
result, err := json.Marshal(data)
if err != nil {
return body
}
return result
}
// countCacheControlBlocks 统计 system 和 messages 中的 cache_control 块数量
// 注意thinking 块不支持 cache_control统计时跳过
func countCacheControlBlocks(data map[string]any) int {
count := 0
// 统计 system 中的块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
count++
}
}
}
}
// 统计 messages 中的块
if messages, ok := data["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for _, item := range content {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
count++
}
}
}
}
}
}
}
return count
}
// removeCacheControlFromMessages 从 messages 中移除一个 cache_control从头开始
// 返回 true 表示成功移除false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control
func removeCacheControlFromMessages(data map[string]any) bool {
messages, ok := data["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, item := range content {
if m, ok := item.(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
}
return false
}
// removeCacheControlFromSystem 从 system 中移除一个 cache_control从尾部开始保护注入的 prompt
// 返回 true 表示成功移除false 表示没有可移除的
// 注意:跳过 thinking 块(它不支持 cache_control
func removeCacheControlFromSystem(data map[string]any) bool {
system, ok := data["system"].([]any)
if !ok {
return false
}
// 从尾部开始移除,保护开头注入的 Claude Code prompt
for i := len(system) - 1; i >= 0; i-- {
if m, ok := system[i].(map[string]any); ok {
// thinking 块不支持 cache_control跳过
if blockType, _ := m["type"].(string); blockType == "thinking" {
continue
}
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
return true
}
}
}
return false
}
// removeCacheControlFromThinkingBlocks 强制清理所有 thinking 块中的非法 cache_control
// thinking 块不支持 cache_control 字段,这个函数确保所有 thinking 块都不含该字段
func removeCacheControlFromThinkingBlocks(data map[string]any) {
// 清理 system 中的 thinking 块
if system, ok := data["system"].([]any); ok {
for _, item := range system {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
log.Printf("[Warning] Removed illegal cache_control from thinking block in system")
}
}
}
}
}
// 清理 messages 中的 thinking 块
if messages, ok := data["messages"].([]any); ok {
for msgIdx, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if content, ok := msgMap["content"].([]any); ok {
for contentIdx, item := range content {
if m, ok := item.(map[string]any); ok {
if blockType, _ := m["type"].(string); blockType == "thinking" {
if _, has := m["cache_control"]; has {
delete(m, "cache_control")
log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx)
}
}
}
}
}
}
}
}
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
if parsed == nil {
return nil, fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
// 条件1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
if account.IsOAuth() &&
!isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
!strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
}
// 强制执行 cache_control 块数量限制(最多 4 个)
body = enforceCacheControlLimit(body)
// 应用模型映射APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
originalModel := reqModel
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic {
normalized := normalizeClaudeModelForAnthropic(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource)
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 调试日志:记录即将转发的账号信息
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
// 重试循环
var resp *http.Response
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
if err != nil {
return nil, err
}
// 发送请求
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
if resp != nil && resp.Body != nil {
_ = resp.Body.Close()
}
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
})
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed",
},
})
return nil, fmt.Errorf("upstream request failed: %s", safeErr)
}
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
looksLikeToolSignatureError := func(msg string) bool {
m := strings.ToLower(msg)
return strings.Contains(m, "tool_use") ||
strings.Contains(m, "tool_result") ||
strings.Contains(m, "functioncall") ||
strings.Contains(m, "function_call") ||
strings.Contains(m, "functionresponse") ||
strings.Contains(m, "function_response")
}
// 避免在重试预算已耗尽时再发起额外请求
if time.Since(retryStart) >= maxRetryElapsed {
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// Conservative two-stage fallback:
// 1) Disable thinking + thinking->text (preserve content)
// 2) Only if upstream still errors AND error message points to tool/function signature issues:
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
resp = retryResp
break
}
retryRespBody, retryReadErr := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
if retryReadErr == nil && retryResp.StatusCode == 400 && s.isThinkingBlockSignatureError(retryRespBody) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry_thinking",
Message: extractUpstreamErrorMessage(retryRespBody),
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(retryRespBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
msg2 := extractUpstreamErrorMessage(retryRespBody)
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr2 == nil {
resp = retryResp2
break
}
if retryResp2 != nil && retryResp2.Body != nil {
_ = retryResp2.Body.Close()
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
})
log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2)
} else {
log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2)
}
}
}
// Fall back to the original retry response context.
resp = &http.Response{
StatusCode: retryResp.StatusCode,
Header: retryResp.Header.Clone(),
Body: io.NopCloser(bytes.NewReader(retryRespBody)),
}
break
}
if retryResp != nil && retryResp.Body != nil {
_ = retryResp.Body.Close()
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// Retry failed: restore original response body and continue handling.
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetryAttempts {
elapsed := time.Since(retryStart)
if elapsed >= maxRetryElapsed {
break
}
delay := retryBackoffDelay(attempt)
remaining := maxRetryElapsed - elapsed
if delay > remaining {
delay = remaining
}
if delay <= 0 {
break
}
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)",
account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed)
if err := sleepWithContext(ctx, delay); err != nil {
return nil, err
}
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
break
}
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
log.Printf("[DEBUG] %s: %v", k, v)
}
}
break
}
if resp == nil || resp.Body == nil {
return nil, errors.New("upstream request failed: empty response")
}
defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印重试耗尽后的错误响应
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleRetryExhaustedSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry_exhausted_failover",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
s.handleFailoverSideEffects(ctx, resp, account)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
Message: extractUpstreamErrorMessage(respBody),
Detail: func() string {
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)
}
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover_on_400",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
StatusCode: 403,
}
}
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
ClientDisconnect: clientDisconnect,
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages"
}
}
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
} else {
fingerprint = fp
// 2. 重写metadata.user_id需要指纹中的ClientID和账号的account_uuid
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth账号应用缓存的指纹到请求头覆盖白名单透传的头
if fingerprint != nil {
s.identityService.ApplyFingerprint(req, fingerprint)
}
// 确保必要的headers存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader
}
// 需要添加oauth beta
parts := strings.Split(clientBetaHeader, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1
for i, p := range parts {
if p == claude.BetaClaudeCode {
claudeCodeIdx = i
break
}
}
if claudeCodeIdx >= 0 {
// 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",")
}
// 没有claude-code放在第一位
return claude.BetaOAuth + "," + clientBetaHeader
}
// 客户端没传,根据模型生成
// haiku 模型不需要 claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader
}
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误更宽松的匹配
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
log.Printf("[SignatureCheck] Detected signature error")
return true
}
// 检测 thinking block 顺序/类型错误
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block type error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
log.Printf("[SignatureCheck] Detected empty content error")
return true
}
return false
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 调试日志:打印上游错误响应
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
// Enrich Ops error logs with upstream status + message, and optionally a truncated body snippet.
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(body), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
Message: upstreamMsg,
Detail: upstreamDetail,
})
// 处理上游错误,标记账号状态
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
switch resp.StatusCode {
case 400:
c.Data(http.StatusBadRequest, "application/json", body)
summary := upstreamMsg
if summary == "" {
summary = truncateForLog(body, 512)
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, summary)
case 401:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream authentication failed, please contact administrator"
case 403:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream access forbidden, please contact administrator"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded, please retry later"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream service temporarily unavailable"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
// 返回自定义错误响应
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
statusCode := resp.StatusCode
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts)
}
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
// Capture upstream error body before side-effects consume the stream.
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
s.handleRetryExhaustedSideEffects(ctx, resp, account)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry_exhausted",
Message: upstreamMsg,
Detail: upstreamDetail,
})
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed after retries",
},
})
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (retries exhausted) message=%s", resp.StatusCode, upstreamMsg)
}
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
clientDisconnect bool // 客户端是否在流式传输过程中断开
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
}
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 透传其他响应头
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &ClaudeUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
maxLineSize := defaultMaxLineSize
if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
maxLineSize = s.cfg.Gateway.MaxLineSize
}
scanner.Buffer(make([]byte, 64*1024), maxLineSize)
type scanEvent struct {
line string
err error
}
// 独立 goroutine 读取上游,避免读取阻塞导致超时/keepalive无法处理
events := make(chan scanEvent, 16)
done := make(chan struct{})
sendEvent := func(ev scanEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt int64
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
go func() {
defer close(events)
for scanner.Scan() {
atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
if !sendEvent(scanEvent{line: scanner.Text()}) {
return
}
}
if err := scanner.Err(); err != nil {
_ = sendEvent(scanEvent{err: err})
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
// 仅监控上游数据间隔超时,避免下游写入阻塞导致误判
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
// 仅发送一次错误事件,避免多次写入导致协议混乱(写失败时尽力通知客户端)
errorEventSent := false
sendErrorEvent := func(reason string) {
if errorEventSent {
return
}
errorEventSent = true
_, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason)
flusher.Flush()
}
needModelReplace := originalModel != mappedModel
clientDisconnected := false // 客户端断开标志断开后继续读取上游以获取完整usage
for {
select {
case ev, ok := <-events:
if !ok {
// 上游完成,返回结果
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
// 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取)
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
log.Printf("Context canceled during streaming, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage
if clientDisconnected {
log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
// 客户端未断开,正常的错误处理
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err
}
sendErrorEvent("stream_read_error")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err)
}
line := ev.line
if line == "event: error" {
// 上游返回错误事件,如果客户端已断开仍返回已收集的 usage
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
data = sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
}
// 写入客户端(统一处理 data 行和非 data 行)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
clientDisconnected = true
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
} else {
flusher.Flush()
}
}
// 无论客户端是否断开,都解析 usage仅对 data 行)
if data != "" {
if firstTokenMs == nil && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
}
case <-intervalCh:
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt))
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
// 客户端已断开,上游也超时了,返回已收集的 usage
log.Printf("Upstream timeout after client disconnect, returning collected usage")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// 只替换 message_start 事件中的 message.model
if event["type"] != "message_start" {
return line
}
msg, ok := event["message"].(map[string]any)
if !ok {
return line
}
model, ok := msg["model"].(string)
if !ok || model != fromModel {
return line
}
msg["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens标准Claude API格式
var msgStart struct {
Type string `json:"type"`
Message struct {
Usage ClaudeUsage `json:"usage"`
} `json:"message"`
}
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
}
// 解析message_delta获取tokens兼容GLM等把所有usage放在delta中的API
var msgDelta struct {
Type string `json:"type"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// message_delta 仅覆盖存在且非0的字段
// 避免覆盖 message_start 中已有的值(如 input_tokens
// Claude API 的 message_delta 通常只包含 output_tokens
if msgDelta.Usage.InputTokens > 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if msgDelta.Usage.OutputTokens > 0 {
usage.OutputTokens = msgDelta.Usage.OutputTokens
}
if msgDelta.Usage.CacheCreationInputTokens > 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if msgDelta.Usage.CacheReadInputTokens > 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解析usage
var response struct {
Usage ClaudeUsage `json:"usage"`
}
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
// 如果有模型映射替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
contentType := "application/json"
if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled {
if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" {
contentType = upstreamType
}
}
// 写入响应
c.Data(resp.StatusCode, contentType, body)
return &response.Usage, nil
}
// replaceModelInResponseBody 替换响应体中的model字段
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
UserAgent string // 请求的 User-Agent
IPAddress string // 请求的客户端 IP 地址
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
// 获取费率倍数
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
}
var cost *CostBreakdown
// 根据请求类型选择计费方式
if result.ImageCount > 0 {
// 图片生成计费
var groupConfig *ImagePriceConfig
if apiKey.Group != nil {
groupConfig = &ImagePriceConfig{
Price1K: apiKey.Group.ImagePrice1K,
Price2K: apiKey.Group.ImagePrice2K,
Price4K: apiKey.Group.ImagePrice4K,
}
}
cost = s.billingService.CalculateImageCost(result.Model, result.ImageSize, result.ImageCount, groupConfig, multiplier)
} else {
// Token 计费
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
var err error
cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
log.Printf("Calculate cost failed: %v", err)
cost = &CostBreakdown{ActualCost: 0}
}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
var imageSize *string
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
AccountRateMultiplier: &accountRateMultiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
ImageCount: result.ImageCount,
ImageSize: imageSize,
CreatedAt: time.Now(),
}
// 添加 UserAgent
if input.UserAgent != "" {
usageLog.UserAgent = &input.UserAgent
}
// 添加 IPAddress
if input.IPAddress != "" {
usageLog.IPAddress = &input.IPAddress
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil
}
// 应用模型映射APIKey 明确映射优先,其次使用 Anthropic 前缀映射)
if reqModel != "" {
mappedModel := reqModel
mappingSource := ""
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(reqModel)
if mappedModel != reqModel {
mappingSource = "account"
}
}
if mappingSource == "" && account.Platform == PlatformAnthropic {
normalized := normalizeClaudeModelForAnthropic(reqModel)
if normalized != reqModel {
mappedModel = normalized
mappingSource = "prefix"
}
}
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource)
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if err != nil {
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 检测 thinking block 签名错误400并重试一次过滤 thinking blocks
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
}
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
upstreamDetail := ""
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
if maxBytes <= 0 {
maxBytes = 2048
}
upstreamDetail = truncateString(string(respBody), maxBytes)
}
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
if err != nil {
return nil, err
}
targetURL = validatedURL + "/v1/messages/count_tokens"
}
}
// OAuth 账号:应用统一指纹和重写 userID
// 如果启用了会话ID伪装会在重写后替换 session 部分为固定值
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}
func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{
AllowedHosts: s.cfg.Security.URLAllowlist.UpstreamHosts,
RequireAllowlist: true,
AllowPrivate: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
return "", fmt.Errorf("invalid base_url: %w", err)
}
return normalized, nil
}
// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内
func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error {
scope, ok := ResolveAntigravityQuotaScope(requestedModel)
if !ok {
return nil // 无法解析 scope跳过检查
}
group, err := s.resolveGroupByID(ctx, groupID)
if err != nil {
return nil // 查询失败时放行
}
if group == nil {
return nil // 分组不存在时放行
}
if !IsScopeSupported(group.SupportedModelScopes, scope) {
return ErrModelScopeNotSupported
}
return nil
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil || len(accounts) == 0 {
return nil
}
// Filter by platform if specified
if platform != "" {
filtered := make([]Account, 0)
for _, acc := range accounts {
if acc.Platform == platform {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
// Collect unique models from all accounts
modelSet := make(map[string]struct{})
hasAnyMapping := false
for _, acc := range accounts {
mapping := acc.GetModelMapping()
if len(mapping) > 0 {
hasAnyMapping = true
for model := range mapping {
modelSet[model] = struct{}{}
}
}
}
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
return nil
}
// Convert to slice
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
return models
}