Files
sub2api/backend/internal/service/gateway_service.go
IanShaw 7fdc2b2d29 Fix/multiple issues (#24)
* fix(gemini): 修复 google_one OAuth 配置和 scopes 问题

- 修复 google_one 类型在 ExchangeCode 和 RefreshToken 中使用内置客户端
- 添加 DefaultGoogleOneScopes,包含 generative-language 和 drive.readonly 权限
- 在 EffectiveOAuthConfig 中为 google_one 类型使用专门的 scopes
- 将 docker-compose.override.yml 重命名为 .example 并添加到 .gitignore
- 完善 docker-compose.override.yml.example 示例文档

解决问题:
1. google_one OAuth 授权后 API 调用返回 403 权限不足
2. 缺少访问 Gemini API 所需的 generative-language scope
3. 缺少获取 Drive 存储配额所需的 drive.readonly scope

* fix(antigravity): 完全跳过 Claude 模型的所有 thinking 块

问题分析:
- 当前代码尝试保留有 signature 的 thinking 块
- 但 Vertex AI 的 signature 是完整性令牌,无法在本地验证
- 导致 400 错误:Invalid signature in thinking block

根本原因:
1. thinking 功能已对非 Gemini 模型禁用 (isThinkingEnabled=false)
2. Vertex AI 要求原样重放 (thinking, signature) 对或完全不发送
3. 本地无法复制 Vertex 的加密验证逻辑

修复方案:
- 对 Claude 模型完全跳过所有 thinking 块(无论是否有 signature)
- 保持 Gemini 模型使用 dummy signature 的行为不变
- 更新测试用例以反映新的预期行为

影响:
- 消除 thinking 相关的 400 错误
- 与现有的 thinking 禁用策略保持一致
- 不影响 Gemini 模型的 thinking 功能

测试:
-  TestBuildParts_ThinkingBlockWithoutSignature 全部通过
-  TestBuildTools_CustomTypeTools 全部通过

参考:Codex review 建议

* fix(gateway): 修复 count_tokens 端点 400 错误

问题分析:
- count_tokens 请求包含 thinking 块时返回 400 错误
- 原因:thinking 块未被过滤,直接转发到上游 API
- 上游 API 拒绝无效的 thinking signature

根本原因:
1. /v1/messages 请求通过 TransformClaudeToGemini 过滤 thinking 块
2. count_tokens 请求绕过转换,直接转发原始请求体
3. 导致包含无效 signature 的 thinking 块被发送到上游

修复方案:
- 创建 FilterThinkingBlocks 工具函数
- 在 buildCountTokensRequest 中应用过滤(1 行修改)
- 与 /v1/messages 行为保持一致

实现细节:
- FilterThinkingBlocks: 解析 JSON,过滤 thinking 块,重新序列化
- 失败安全:解析/序列化失败时返回原始请求体
- 性能优化:仅在发现 thinking 块时重新序列化

测试:
-  6 个单元测试全部通过
-  覆盖正常过滤、无 thinking 块、无效 JSON 等场景
-  现有测试不受影响

影响:
- 消除 count_tokens 的 400 错误
- 不影响 Antigravity 账号(仍返回模拟响应)
- 适用于所有账号类型(OAuth、API Key)

文件修改:
- backend/internal/service/gateway_request.go: +62 行(新函数)
- backend/internal/service/gateway_service.go: +2 行(应用过滤)
- backend/internal/service/gateway_request_test.go: +62 行(测试)

* fix(gateway): 增强 thinking 块过滤逻辑

基于 Codex 分析和建议的改进:

问题分析:
- 新错误:signature: Field required(signature 字段缺失)
- 旧错误:Invalid signature(signature 存在但无效)
- 两者都说明 thinking 块在请求中是危险的

Codex 建议:
- 保持 Option A:完全跳过所有 thinking 块
- 原因:thinking 块应该是只输出的,除非有服务端来源证明
- 在无状态代理中,无法安全区分上游来源 vs 客户端注入

改进内容:

1. 增强 FilterThinkingBlocks 函数
   - 过滤显式的 thinking 块:{"type":"thinking", ...}
   - 过滤无 type 的 thinking 对象:{"thinking": {...}}
   - 保留 tool_use 等其他类型块中的 thinking 字段
   - 修复:只在实际过滤时更新 content 数组

2. 扩展过滤范围
   - 将 FilterThinkingBlocks 应用到 /v1/messages 主路径
   - 之前只应用于 count_tokens,现在两个端点都过滤
   - 防止所有端点的 thinking 相关 400 错误

3. 改进测试
   - 新增:过滤无 type discriminator 的 thinking 块
   - 新增:不过滤 tool_use 中的 thinking 字段
   - 使用 containsThinkingBlock 辅助函数验证

测试:
-  8 个测试用例全部通过
-  覆盖各种 thinking 块格式
-  确保不误伤其他类型的块

影响:
- 消除 signature required 和 invalid signature 错误
- 统一 /v1/messages 和 count_tokens 的行为
- 更健壮的 thinking 块检测逻辑

参考:Codex review 和代码改进

* refactor: 根据 Codex 审查建议进行代码优化

基于 Codex 代码审查的 P1 和 P2 改进:

P1 改进(重要问题):

1. 优化日志输出
   - 移除 thinking 块跳过时的 log.Printf
   - 避免高频请求下的日志噪音
   - 添加注释说明可通过指标监控

2. 清理遗留代码
   - 删除未使用的 isValidThoughtSignature 函数(27行)
   - 该函数在改为完全跳过 thinking 块后不再需要

P2 改进(性能优化):

3. 添加快速路径检查
   - 在 FilterThinkingBlocks 中添加 bytes.Contains 预检查
   - 如果请求体不包含 "thinking" 字符串,直接返回
   - 避免不必要的 JSON 解析,提升性能

技术细节:
- request_transformer.go: -27行(删除函数),+1行(优化注释)
- gateway_request.go: +5行(快速路径 + bytes 导入)

测试:
-  TestBuildParts_ThinkingBlockWithoutSignature 全部通过
-  TestFilterThinkingBlocks 全部通过(8个测试用例)

影响:
- 减少日志噪音
- 提升性能(快速路径)
- 代码更简洁(删除未使用代码)

参考:Codex 代码审查建议

* fix: 修复 golangci-lint 检查问题

- 格式化 gateway_request_test.go
- 使用 switch 语句替代 if-else 链(staticcheck QF1003)

* fix(antigravity): 修复 thinking signature 处理并实现 Auto 模式降级

问题分析:
1. 原先代码错误地禁用了 Claude via Vertex 的 thinkingConfig
2. 历史 thinking 块的 signature 被完全跳过,导致验证失败
3. 跨模型混用时 dummy signature 会导致 400 错误

修复内容:

**request_transformer.go**:
- 删除第 38-43 行的错误逻辑(禁用 thinkingConfig)
- 引入 thoughtSignatureMode(Preserve/Dummy)策略
- Claude 模式:透传真实 signature,过滤空/dummy
- Gemini 模式:使用 dummy signature
- 支持 signature-only thinking 块
- tool_use 的 signature 也透传

**antigravity_gateway_service.go**:
- 新增 isSignatureRelatedError() 检测 signature 相关错误
- 新增 stripThinkingFromClaudeRequest() 移除 thinking 块
- 实现 Auto 模式:检测 400 + signature 关键词时自动降级重试
- 重试时完全移除 thinking 配置和消息中的 thinking 块
- 最多重试一次,避免循环

**测试**:
- 更新并新增测试覆盖 Claude preserve/Gemini dummy 模式
- 新增 tool_use signature 处理测试
- 所有测试通过(6/6)

影响:
-  Claude via Vertex 可以正常使用 thinking 功能
-  历史 signature 正确透传,避免验证失败
-  跨模型混用时自动过滤无效 signature
-  错误驱动降级,自动修复 signature 问题
-  不影响纯 Claude API 和其他渠道

参考:Codex 深度分析和实现建议

* fix(lint): 修复 gofmt 格式问题

* fix(antigravity): 修复 stripThinkingFromClaudeRequest 遗漏 untyped thinking blocks

问题:
- Codex 审查指出 stripThinkingFromClaudeRequest 只移除了 type="thinking" 的块
- 没有处理没有 type 字段的 thinking 对象(如 {"thinking": "...", "signature": "..."})
- 导致重试时仍包含无效 thinking 块,上游 400 错误持续

修复:
- 添加检查:跳过没有 type 但有 thinking 字段的块
- 现在会移除两种格式:
  1. {"type": "thinking", "thinking": "...", "signature": "..."}
  2. {"thinking": "...", "signature": "..."}(untyped)

测试:所有测试通过

参考:Codex P1 审查意见
2026-01-02 17:47:49 +08:00

1990 lines
62 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
)
const (
claudeAPIURL = "https://api.anthropic.com/v1/messages?beta=true"
claudeAPICountTokensURL = "https://api.anthropic.com/v1/messages/count_tokens?beta=true"
stickySessionTTL = time.Hour // 粘性会话TTL
)
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
"accept": true,
"x-stainless-retry-count": true,
"x-stainless-timeout": true,
"x-stainless-lang": true,
"x-stainless-package-version": true,
"x-stainless-os": true,
"x-stainless-arch": true,
"x-stainless-runtime": true,
"x-stainless-runtime-version": true,
"x-stainless-helper-method": true,
"anthropic-dangerous-direct-browser-access": true,
"anthropic-version": true,
"x-app": true,
"anthropic-beta": true,
"accept-language": true,
"sec-fetch-mode": true,
"user-agent": true,
"content-type": true,
}
// GatewayCache defines cache operations for gateway service
type GatewayCache interface {
GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error)
SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
type AccountWaitPlan struct {
AccountID int64
MaxConcurrency int
Timeout time.Duration
MaxWaiting int
}
type AccountSelectionResult struct {
Account *Account
Acquired bool
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
}
// ForwardResult 转发结果
type ForwardResult struct {
RequestID string
Usage ClaudeUsage
Model string
Stream bool
Duration time.Duration
FirstTokenMs *int // 首字时间(流式请求)
}
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
}
func (e *UpstreamFailoverError) Error() string {
return fmt.Sprintf("upstream error: %d (failover)", e.StatusCode)
}
// GatewayService handles API gateway operations
type GatewayService struct {
accountRepo AccountRepository
groupRepo GroupRepository
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(
accountRepo AccountRepository,
groupRepo GroupRepository,
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
groupRepo: groupRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
}
}
// GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
}
}
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用 system 内容
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
}
}
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
}
}
}
return ""
}
// BindStickySession sets session -> account binding with standard TTL.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
if sessionHash == "" || accountID <= 0 || s.cache == nil {
return nil
}
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var builder strings.Builder
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok {
_, _ = builder.WriteString(text)
}
}
}
}
}
}
systemText := builder.String()
// 检查 messages 中的 cacheable 内容
for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
return s.extractTextFromContent(msgMap["content"])
}
}
}
}
}
}
}
return systemText
}
func (s *GatewayService) extractTextFromSystem(system any) string {
switch v := system.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) extractTextFromContent(content any) string {
switch v := content.(type) {
case string:
return v
case []any:
var texts []string
for _, part := range v {
if partMap, ok := part.(map[string]any); ok {
if partMap["type"] == "text" {
if text, ok := partMap["text"].(string); ok {
texts = append(texts, text)
}
}
}
}
return strings.Join(texts, "")
}
return ""
}
func (s *GatewayService) hashContent(content string) string {
hash := sha256.Sum256([]byte(content))
return hex.EncodeToString(hash[:16]) // 32字符
}
// replaceModelInBody 替换请求体中的model字段
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
req["model"] = newModel
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// SelectAccount 选择账号(粘性会话+优先级)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
}
// SelectAccountForModel 选择支持指定模型的账号(粘性会话+优先级+模型映射)
func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int64, sessionHash string, requestedModel string) (*Account, error) {
return s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, nil)
}
// SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts.
func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) {
// 优先检查 context 中的强制平台(/antigravity 路由)
var platform string
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
platform = forcePlatform
} else if groupID != nil {
// 根据分组 platform 决定查询哪种账号
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, fmt.Errorf("get group failed: %w", err)
}
platform = group.Platform
} else {
// 无分组时只使用原生 anthropic 平台
platform = PlatformAnthropic
}
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
// 注意:强制平台模式不走混合调度
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// 强制平台模式:优先按分组查找,找不到再查全部该平台账户
if hasForcePlatform && groupID != nil {
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
if err == nil {
return account, nil
}
// 分组中找不到,回退查询全部该平台账户
groupID = nil
}
// antigravity 分组、强制平台模式或无分组使用单平台选择
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
stickyAccountID = accountID
}
}
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
return nil, err
}
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
if err == nil && result.Acquired {
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: account.ID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
if err != nil {
return nil, err
}
preferOAuth := platform == PlatformGemini
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
return nil, err
}
if len(accounts) == 0 {
return nil, errors.New("no available accounts")
}
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
return false
}
_, excluded := excludedIDs[accountID]
return excluded
}
// ============ Layer 1: 粘性会话优先 ============
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulable() &&
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL)
return &AccountSelectionResult{
Account: account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
if waitingCount < cfg.StickySessionMaxWaiting {
return &AccountSelectionResult{
Account: account,
WaitPlan: &AccountWaitPlan{
AccountID: accountID,
MaxConcurrency: account.Concurrency,
Timeout: cfg.StickySessionWaitTimeout,
MaxWaiting: cfg.StickySessionMaxWaiting,
},
}, nil
}
}
}
}
// ============ Layer 2: 负载感知选择 ============
candidates := make([]*Account, 0, len(accounts))
for i := range accounts {
acc := &accounts[i]
if isExcluded(acc.ID) {
continue
}
if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
candidates = append(candidates, acc)
}
if len(candidates) == 0 {
return nil, errors.New("no available accounts")
}
accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
for _, acc := range candidates {
accountLoads = append(accountLoads, AccountWithConcurrency{
ID: acc.ID,
MaxConcurrency: acc.Concurrency,
})
}
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
return result, nil
}
} else {
type accountWithLoad struct {
account *Account
loadInfo *AccountLoadInfo
}
var available []accountWithLoad
for _, acc := range candidates {
loadInfo := loadMap[acc.ID]
if loadInfo == nil {
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
}
if loadInfo.LoadRate < 100 {
available = append(available, accountWithLoad{
account: acc,
loadInfo: loadInfo,
})
}
}
if len(available) > 0 {
sort.SliceStable(available, func(i, j int) bool {
a, b := available[i], available[j]
if a.account.Priority != b.account.Priority {
return a.account.Priority < b.account.Priority
}
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
}
switch {
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
return true
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
return false
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
if preferOAuth && a.account.Type != b.account.Type {
return a.account.Type == AccountTypeOAuth
}
return false
default:
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
}
})
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: item.account,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, nil
}
}
}
}
// ============ Layer 3: 兜底排队 ============
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
for _, acc := range candidates {
return &AccountSelectionResult{
Account: acc,
WaitPlan: &AccountWaitPlan{
AccountID: acc.ID,
MaxConcurrency: acc.Concurrency,
Timeout: cfg.FallbackWaitTimeout,
MaxWaiting: cfg.FallbackMaxWaiting,
},
}, nil
}
return nil, errors.New("no available accounts")
}
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
Account: acc,
Acquired: true,
ReleaseFunc: result.ReleaseFunc,
}, true
}
}
return nil, false
}
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
if s.cfg != nil {
return s.cfg.Gateway.Scheduling
}
return config.GatewaySchedulingConfig{
StickySessionMaxWaiting: 3,
StickySessionWaitTimeout: 45 * time.Second,
FallbackWaitTimeout: 30 * time.Second,
FallbackMaxWaiting: 100,
LoadBatchEnabled: true,
SlotCleanupInterval: 30 * time.Second,
}
}
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
if hasForcePlatform && forcePlatform != "" {
return forcePlatform, true, nil
}
if groupID != nil {
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return "", false, fmt.Errorf("get group failed: %w", err)
}
return group.Platform, false, nil
}
return PlatformAnthropic, false, nil
}
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
if useMixed {
platforms := []string{platform, PlatformAntigravity}
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, useMixed, err
}
filtered := make([]Account, 0, len(accounts))
for _, acc := range accounts {
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
filtered = append(filtered, acc)
}
return filtered, useMixed, nil
}
var accounts []Account
var err error
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
if err == nil && len(accounts) == 0 && hasForcePlatform {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, useMixed, err
}
return accounts, useMixed, nil
}
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
if account == nil {
return false
}
if useMixed {
if account.Platform == platform {
return true
}
return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
}
return account.Platform == platform
}
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
if s.concurrencyService == nil {
return &AcquireResult{Acquired: true, ReleaseFunc: func() {}}, nil
}
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
sort.SliceStable(accounts, func(i, j int) bool {
a, b := accounts[i], accounts[j]
if a.Priority != b.Priority {
return a.Priority < b.Priority
}
switch {
case a.LastUsedAt == nil && b.LastUsedAt != nil:
return true
case a.LastUsedAt != nil && b.LastUsedAt == nil:
return false
case a.LastUsedAt == nil && b.LastUsedAt == nil:
if preferOAuth && a.Type != b.Type {
return a.Type == AccountTypeOAuth
}
return false
default:
return a.LastUsedAt.Before(*b.LastUsedAt)
}
})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号平台是否匹配(确保粘性会话不会跨平台)
if err == nil && account.Platform == platform && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
// 2. 获取可调度账号列表(单平台)
var accounts []Account
var err error
if s.cfg.RunMode == config.RunModeSimple {
// 简易模式:忽略 groupID查询所有可用账号
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
} else if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// selectAccountWithMixedScheduling 选择账户(支持混合调度)
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
account, err := s.accountRepo.GetByID(ctx, accountID)
// 检查账号是否有效原生平台直接匹配antigravity 需要启用混合调度
if err == nil && account.IsSchedulable() && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil {
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
}
return account, nil
}
}
}
}
}
// 2. 获取可调度账号列表
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
} else {
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
acc := &accounts[i]
if _, excluded := excludedIDs[acc.ID]; excluded {
continue
}
// 过滤原生平台直接通过antigravity 需要启用混合调度
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
continue
}
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if selected == nil {
selected = acc
continue
}
if acc.Priority < selected.Priority {
selected = acc
} else if acc.Priority == selected.Priority {
switch {
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
selected = acc
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
}
}
}
}
if selected == nil {
if requestedModel != "" {
return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel)
}
return nil, errors.New("no available accounts")
}
// 4. 建立粘性绑定
if sessionHash != "" {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
}
return selected, nil
}
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
return IsAntigravityModelSupported(requestedModel)
}
// 其他平台使用账户的模型支持检查
return account.IsModelSupported(requestedModel)
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
return apiKey, "apikey", nil
default:
return "", "", fmt.Errorf("unsupported account type: %s", account.Type)
}
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
// Token刷新由后台 TokenRefreshService 处理此处只返回当前token
return accessToken, "oauth", nil
}
// 重试相关常量
const (
maxRetries = 10 // 最大重试次数
retryDelay = 3 * time.Second // 重试等待时间
)
func (s *GatewayService) shouldRetryUpstreamError(account *Account, statusCode int) bool {
// OAuth/Setup Token 账号:仅 403 重试
if account.IsOAuth() {
return statusCode == 403
}
// API Key 账号:未配置的错误码重试
return !account.ShouldHandleErrorCode(statusCode)
}
// shouldFailoverUpstreamError determines whether an upstream error should trigger account failover.
func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
switch statusCode {
case 401, 403, 429, 529:
return true
default:
return statusCode >= 500
}
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
if parsed == nil {
return nil, fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
"cache_control": map[string]string{
"type": "ephemeral",
},
},
})
}
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
return nil, err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 重试循环
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
return nil, err
}
// 发送请求
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
_ = resp.Body.Close()
time.Sleep(retryDelay)
continue
}
// 最后一次尝试也失败,跳出循环处理重试耗尽
break
}
// 不需要重试(成功或不可重试的错误),跳出循环
break
}
defer func() { _ = resp.Body.Close() }()
// 处理重试耗尽的情况
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
// 处理可切换账号的错误
if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) {
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
StatusCode: 403,
}
}
return nil, err
}
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
}
return &ForwardResult{
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
// OAuth账号应用统一指纹
var fingerprint *Fingerprint
if account.IsOAuth() && s.identityService != nil {
// 1. 获取或创建指纹包含随机生成的ClientID
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err)
// 失败时降级为透传原始headers
} else {
fingerprint = fp
// 2. 重写metadata.user_id需要指纹中的ClientID和账号的account_uuid
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth账号应用缓存的指纹到请求头覆盖白名单透传的头
if fingerprint != nil {
s.identityService.ApplyFingerprint(req, fingerprint)
}
// 确保必要的headers存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
if strings.Contains(clientBetaHeader, claude.BetaOAuth) {
return clientBetaHeader
}
// 需要添加oauth beta
parts := strings.Split(clientBetaHeader, ",")
for i, p := range parts {
parts[i] = strings.TrimSpace(p)
}
// 在claude-code-20250219后面插入oauth beta
claudeCodeIdx := -1
for i, p := range parts {
if p == claude.BetaClaudeCode {
claudeCodeIdx = i
break
}
}
if claudeCodeIdx >= 0 {
// 在claude-code后面插入
newParts := make([]string, 0, len(parts)+1)
newParts = append(newParts, parts[:claudeCodeIdx+1]...)
newParts = append(newParts, claude.BetaOAuth)
newParts = append(newParts, parts[claudeCodeIdx+1:]...)
return strings.Join(newParts, ",")
}
// 没有claude-code放在第一位
return claude.BetaOAuth + "," + clientBetaHeader
}
// 客户端没传,根据模型生成
// haiku 模型不需要 claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader
}
return claude.DefaultBetaHeader
}
func requestNeedsBetaFeatures(body []byte) bool {
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
return true
}
if strings.EqualFold(gjson.GetBytes(body, "thinking.type").String(), "enabled") {
return true
}
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
}
if len(b) > maxBytes {
b = b[:maxBytes]
}
s := string(b)
// 保持一行,避免污染日志格式
s = strings.ReplaceAll(s, "\n", "\\n")
s = strings.ReplaceAll(s, "\r", "\\r")
return s
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 缺少/错误的 beta header换账号/链路可能成功(尤其是混合调度时)。
// 更精确匹配 beta 相关的兼容性问题,避免误触发切换。
if strings.Contains(msg, "anthropic-beta") ||
strings.Contains(msg, "beta feature") ||
strings.Contains(msg, "requires beta") {
return true
}
// thinking/tool streaming 等兼容性约束(常见于中间转换链路)
if strings.Contains(msg, "thinking") || strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature") {
return true
}
if strings.Contains(msg, "tool_use") || strings.Contains(msg, "tool_result") || strings.Contains(msg, "tools") {
return true
}
return false
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
inner := strings.TrimSpace(m)
// 有些上游会把完整 JSON 作为字符串塞进 message
if strings.HasPrefix(inner, "{") {
if innerMsg := gjson.Get(inner, "error.message").String(); strings.TrimSpace(innerMsg) != "" {
return innerMsg
}
}
return m
}
// 兜底:尝试顶层 message
return gjson.GetBytes(body, "message").String()
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
switch resp.StatusCode {
case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream authentication failed, please contact administrator"
case 403:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream access forbidden, please contact administrator"
case 429:
statusCode = http.StatusTooManyRequests
errType = "rate_limit_error"
errMsg = "Upstream rate limit exceeded, please retry later"
case 529:
statusCode = http.StatusServiceUnavailable
errType = "overloaded_error"
errMsg = "Upstream service overloaded, please retry later"
case 500, 502, 503, 504:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream service temporarily unavailable"
default:
statusCode = http.StatusBadGateway
errType = "upstream_error"
errMsg = "Upstream request failed"
}
// 返回自定义错误响应
c.JSON(statusCode, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
}
func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
statusCode := resp.StatusCode
// OAuth/Setup Token 账号的 403标记账号异常
if account.IsOAuth() && statusCode == 403 {
s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body)
log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetries, statusCode)
} else {
// API Key 未配置错误码:不标记账号状态
log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetries)
}
}
func (s *GatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
body, _ := io.ReadAll(resp.Body)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
// handleRetryExhaustedError 处理重试耗尽后的错误
// OAuth 403标记账号异常
// API Key 未配置错误码:仅返回错误,不标记账号
func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
s.handleRetryExhaustedSideEffects(ctx, resp, account)
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream request failed after retries",
},
})
return nil, fmt.Errorf("upstream error: %d (retries exhausted)", resp.StatusCode)
}
// streamingResult 流式响应结果
type streamingResult struct {
usage *ClaudeUsage
firstTokenMs *int
}
func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
// 设置SSE响应头
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
// 透传其他响应头
if v := resp.Header.Get("x-request-id"); v != "" {
c.Header("x-request-id", v)
}
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
usage := &ClaudeUsage{}
var firstTokenMs *int
scanner := bufio.NewScanner(resp.Body)
// 设置更大的buffer以处理长行
scanner.Buffer(make([]byte, 64*1024), 1024*1024)
needModelReplace := originalModel != mappedModel
for scanner.Scan() {
line := scanner.Text()
if line == "event: error" {
return nil, errors.New("have error in stream")
}
// Extract data from SSE line (supports both "data: " and "data:" formats)
if sseDataRe.MatchString(line) {
data := sseDataRe.ReplaceAllString(line, "")
// 如果有模型映射替换响应中的model字段
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
// 转发行
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
// 记录首字时间:第一个有效的 content_block_delta 或 message_start
if firstTokenMs == nil && data != "" && data != "[DONE]" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
s.parseSSEUsage(data, usage)
} else {
// 非 data 行直接转发
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
}
flusher.Flush()
}
}
if err := scanner.Err(); err != nil {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
// replaceModelInSSELine 替换SSE数据行中的model字段
func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !sseDataRe.MatchString(line) {
return line
}
data := sseDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
return line
}
// 只替换 message_start 事件中的 message.model
if event["type"] != "message_start" {
return line
}
msg, ok := event["message"].(map[string]any)
if !ok {
return line
}
model, ok := msg["model"].(string)
if !ok || model != fromModel {
return line
}
msg["model"] = toModel
newData, err := json.Marshal(event)
if err != nil {
return line
}
return "data: " + string(newData)
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// 解析message_start获取input tokens标准Claude API格式
var msgStart struct {
Type string `json:"type"`
Message struct {
Usage ClaudeUsage `json:"usage"`
} `json:"message"`
}
if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" {
usage.InputTokens = msgStart.Message.Usage.InputTokens
usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens
usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens
}
// 解析message_delta获取tokens兼容GLM等把所有usage放在delta中的API
var msgDelta struct {
Type string `json:"type"`
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"`
} `json:"usage"`
}
if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" {
// output_tokens 总是从 message_delta 获取
usage.OutputTokens = msgDelta.Usage.OutputTokens
// 如果 message_start 中没有值,则从 message_delta 获取兼容GLM等API
if usage.InputTokens == 0 {
usage.InputTokens = msgDelta.Usage.InputTokens
}
if usage.CacheCreationInputTokens == 0 {
usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens
}
if usage.CacheReadInputTokens == 0 {
usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens
}
}
}
func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
// 解析usage
var response struct {
Usage ClaudeUsage `json:"usage"`
}
if err := json.Unmarshal(body, &response); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
// 如果有模型映射替换响应中的model字段
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
// 透传响应头
for key, values := range resp.Header {
for _, value := range values {
c.Header(key, value)
}
}
// 写入响应
c.Data(resp.StatusCode, "application/json", body)
return &response.Usage, nil
}
// replaceModelInResponseBody 替换响应体中的model字段
func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte {
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
return body
}
model, ok := resp["model"].(string)
if !ok || model != fromModel {
return body
}
resp["model"] = toModel
newBody, err := json.Marshal(resp)
if err != nil {
return body
}
return newBody
}
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
}
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
// 计算费用
tokens := UsageTokens{
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
}
// 获取费率倍数
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
}
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
log.Printf("Calculate cost failed: %v", err)
// 使用默认费用继续
cost = &CostBreakdown{ActualCost: 0}
}
// 判断计费方式:订阅模式 vs 余额模式
isSubscriptionBilling := subscription != nil && apiKey.Group != nil && apiKey.Group.IsSubscriptionType()
billingType := BillingTypeBalance
if isSubscriptionBilling {
billingType = BillingTypeSubscription
}
// 创建使用日志
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
InputTokens: result.Usage.InputTokens,
OutputTokens: result.Usage.OutputTokens,
CacheCreationTokens: result.Usage.CacheCreationInputTokens,
CacheReadTokens: result.Usage.CacheReadInputTokens,
InputCost: cost.InputCost,
OutputCost: cost.OutputCost,
CacheCreationCost: cost.CacheCreationCost,
CacheReadCost: cost.CacheReadCost,
TotalCost: cost.TotalCost,
ActualCost: cost.ActualCost,
RateMultiplier: multiplier,
BillingType: billingType,
Stream: result.Stream,
DurationMs: &durationMs,
FirstTokenMs: result.FirstTokenMs,
CreatedAt: time.Now(),
}
// 添加分组和订阅关联
if apiKey.GroupID != nil {
usageLog.GroupID = apiKey.GroupID
}
if subscription != nil {
usageLog.SubscriptionID = &subscription.ID
}
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
// Schedule batch update for account last_used_at
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
}
}
}
// 获取凭证
token, tokenType, err := s.GetAccessToken(ctx, account)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token")
return err
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
}
// 获取代理URL
proxyURL := ""
if account.ProxyID != nil && account.Proxy != nil {
proxyURL = account.Proxy.URL()
}
// 发送请求
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
case 429:
errMsg = "Rate limit exceeded"
case 529:
errMsg = "Service overloaded"
}
s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg)
return fmt.Errorf("upstream error: %d", resp.StatusCode)
}
// 透传成功响应
c.Data(resp.StatusCode, "application/json", respBody)
return nil
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
// OAuth 账号:应用统一指纹和重写 userID
if account.IsOAuth() && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
if accountUUID != "" && fp.ClientID != "" {
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
body = newBody
}
}
}
}
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
}
// 设置认证头
if tokenType == "oauth" {
req.Header.Set("authorization", "Bearer "+token)
} else {
req.Header.Set("x-api-key", token)
}
// 白名单透传 headers
for key, values := range c.Request.Header {
lowerKey := strings.ToLower(key)
if allowedHeaders[lowerKey] {
for _, v := range values {
req.Header.Add(key, v)
}
}
}
// OAuth 账号:应用指纹到请求头
if account.IsOAuth() && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
}
}
// 确保必要的 headers 存在
if req.Header.Get("content-type") == "" {
req.Header.Set("content-type", "application/json")
}
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
}
// countTokensError 返回 count_tokens 错误响应
func (s *GatewayService) countTokensError(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": message,
},
})
}
// GetAvailableModels returns the list of models available for a group
// It aggregates model_mapping keys from all schedulable accounts in the group
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
var accounts []Account
var err error
if groupID != nil {
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil || len(accounts) == 0 {
return nil
}
// Filter by platform if specified
if platform != "" {
filtered := make([]Account, 0)
for _, acc := range accounts {
if acc.Platform == platform {
filtered = append(filtered, acc)
}
}
accounts = filtered
}
// Collect unique models from all accounts
modelSet := make(map[string]struct{})
hasAnyMapping := false
for _, acc := range accounts {
mapping := acc.GetModelMapping()
if len(mapping) > 0 {
hasAnyMapping = true
for model := range mapping {
modelSet[model] = struct{}{}
}
}
}
// If no account has model_mapping, return nil (use default)
if !hasAnyMapping {
return nil
}
// Convert to slice
models := make([]string, 0, len(modelSet))
for model := range modelSet {
models = append(models, model)
}
return models
}