merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突

- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置
- useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-07 20:18:07 +08:00
156 changed files with 14550 additions and 2206 deletions

View File

@@ -2,6 +2,7 @@ package handler
import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
@@ -113,9 +114,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
@@ -126,6 +124,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
reqModel := parsedReq.Model
reqStream := parsedReq.Stream
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
c.Request = c.Request.WithContext(ctx)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
setOpsRequestContext(c, reqModel, reqStream, body)
// 验证 model 必填
@@ -137,6 +149,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c)
@@ -202,17 +219,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
sessionKey = "gemini:" + sessionHash
}
// 查询粘性会话绑定的账号 ID
var sessionBoundAccountID int64
if sessionKey != "" {
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
}
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
if platform == service.PlatformGemini {
maxAccountSwitches := h.maxAccountSwitchesGemini
switchCount := 0
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -227,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
@@ -260,12 +287,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
// Ensure the wait counter is decremented if we exit before acquiring the slot.
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -277,14 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -299,7 +324,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
} else {
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
}
@@ -311,6 +336,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
return
@@ -329,22 +357,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: apiKey,
User: apiKey.User,
Account: usedAccount,
Subscription: subscription,
UserAgent: ua,
IPAddress: clientIP,
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
}(result, account, userAgent, clientIP, forceCacheBilling)
return
}
}
@@ -363,13 +392,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
failedAccountIDs := make(map[int64]struct{})
var lastFailoverErr *service.UpstreamFailoverError
retryWithFallback := false
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
for {
// 选择支持该模型的账号
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
log.Printf("[Gateway] SelectAccount failed: %v", err)
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
return
}
if lastFailoverErr != nil {
@@ -384,7 +415,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 检查请求拦截预热请求、SUGGESTION MODE等
if account.IsInterceptWarmupEnabled() {
interceptType := detectInterceptType(body)
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
if interceptType != InterceptTypeNone {
if selection.Acquired && selection.ReleaseFunc != nil {
selection.ReleaseFunc()
@@ -417,11 +448,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if err == nil && canWait {
accountWaitCounted = true
}
defer func() {
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}()
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
@@ -433,13 +465,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
@@ -454,7 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
} else {
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
}
@@ -501,6 +532,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if errors.As(err, &failoverErr) {
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
if failoverErr.ForceCacheBilling {
forceCacheBilling = true
}
if switchCount >= maxAccountSwitches {
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
return
@@ -519,22 +553,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
clientIP := ip.GetClientIP(c)
// 异步记录使用量subscription已在函数开头获取
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Subscription: currentSubscription,
UserAgent: ua,
IPAddress: clientIP,
APIKeyService: h.apiKeyService,
Result: result,
APIKey: currentAPIKey,
User: currentAPIKey.User,
Account: usedAccount,
Subscription: currentSubscription,
UserAgent: ua,
IPAddress: clientIP,
ForceCacheBilling: fcb,
APIKeyService: h.apiKeyService,
}); err != nil {
log.Printf("Record usage failed: %v", err)
}
}(result, account, userAgent, clientIP)
}(result, account, userAgent, clientIP, forceCacheBilling)
return
}
if !retryWithFallback {
@@ -917,6 +952,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
// 验证 model 必填
if parsedReq.Model == "" {
@@ -943,7 +980,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
// 选择支持该模型的账号
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
if err != nil {
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
return
}
setOpsSelectedAccount(c, account.ID)
@@ -960,13 +998,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
type InterceptType int
const (
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
InterceptTypeNone InterceptType = iota
InterceptTypeWarmup // 预热请求(返回 "New Conversation"
InterceptTypeSuggestionMode // SUGGESTION MODE返回空字符串
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#"
)
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
func isHaikuModel(model string) bool {
return strings.Contains(strings.ToLower(model), "haiku")
}
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
// 这类请求用于 Claude Code 验证 API 连通性
// 条件max_tokens == 1 且 model 包含 "haiku" 且非流式请求
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
return maxTokens == 1 && isHaikuModel(model) && !isStream
}
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
func detectInterceptType(body []byte) InterceptType {
// 参数说明:
// - body: 请求体字节
// - model: 请求的模型名称
// - maxTokens: max_tokens 值
// - isStream: 是否为流式请求
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
return InterceptTypeMaxTokensOneHaiku
}
// 快速检查:如果不包含任何关键字,直接返回
bodyStr := string(body)
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
@@ -1116,9 +1178,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
}
}
// generateRealisticMsgID 生成仿真的消息 IDmsg_bdrk_XXXXXXX 格式)
// 格式与 Claude API 真实响应一致24 位随机字母数字
func generateRealisticMsgID() string {
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
const idLen = 24
randomBytes := make([]byte, idLen)
if _, err := rand.Read(randomBytes); err != nil {
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
}
b := make([]byte, idLen)
for i := range b {
b[i] = charset[int(randomBytes[i])%len(charset)]
}
return "msg_bdrk_" + string(b)
}
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
var msgID, text string
var msgID, text, stopReason string
var outputTokens int
switch interceptType {
@@ -1126,24 +1204,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
msgID = "msg_mock_suggestion"
text = ""
outputTokens = 1
stopReason = "end_turn"
case InterceptTypeMaxTokensOneHaiku:
msgID = generateRealisticMsgID()
text = "#"
outputTokens = 1
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
default: // InterceptTypeWarmup
msgID = "msg_mock_warmup"
text = "New Conversation"
outputTokens = 2
stopReason = "end_turn"
}
c.JSON(http.StatusOK, gin.H{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": "end_turn",
// 构建完整的响应格式(与 Claude API 响应格式一致)
response := gin.H{
"model": model,
"id": msgID,
"type": "message",
"role": "assistant",
"content": []gin.H{{"type": "text", "text": text}},
"stop_reason": stopReason,
"stop_sequence": nil,
"usage": gin.H{
"input_tokens": 10,
"input_tokens": 10,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
"cache_creation": gin.H{
"ephemeral_5m_input_tokens": 0,
"ephemeral_1h_input_tokens": 0,
},
"output_tokens": outputTokens,
"total_tokens": 10 + outputTokens,
},
})
}
c.JSON(http.StatusOK, response)
}
func billingErrorDetails(err error) (status int, code, message string) {
@@ -1156,7 +1252,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
}
msg := pkgerrors.Message(err)
if msg == "" {
msg = err.Error()
log.Printf("[Gateway] billing error details: %v", err)
msg = "Billing error"
}
return http.StatusForbidden, "billing_error", msg
}