merge: 合并 main 分支到 test,解决 config 和 modelWhitelist 冲突
- config.go: 保留 Sora 配置,合入 SubscriptionCache 配置 - useModelWhitelist.ts: 同时保留 soraModels 和 antigravityModels Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,6 +2,7 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -113,9 +114,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
|
||||
setOpsRequestContext(c, "", false, body)
|
||||
|
||||
parsedReq, err := service.ParseGatewayRequest(body)
|
||||
@@ -126,6 +124,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
reqModel := parsedReq.Model
|
||||
reqStream := parsedReq.Stream
|
||||
|
||||
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
|
||||
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
|
||||
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
// 检查是否为 Claude Code 客户端,设置到 context 中
|
||||
SetClaudeCodeClientContext(c, body)
|
||||
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
|
||||
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 验证 model 必填
|
||||
@@ -137,6 +149,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
// 获取订阅信息(可能为nil)- 提前获取用于后续检查
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
@@ -202,17 +219,27 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
// 查询粘性会话绑定的账号 ID
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
// 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号
|
||||
hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0
|
||||
|
||||
if platform == service.PlatformGemini {
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
for {
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -227,7 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
@@ -260,12 +287,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
// Ensure the wait counter is decremented if we exit before acquiring the slot.
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -277,14 +304,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
@@ -299,7 +324,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body)
|
||||
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
|
||||
} else {
|
||||
result, err = h.geminiCompatService.Forward(requestCtx, c, account, body)
|
||||
}
|
||||
@@ -311,6 +336,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted)
|
||||
return
|
||||
@@ -329,22 +357,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: subscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -363,13 +392,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
retryWithFallback := false
|
||||
var forceCacheBilling bool // 粘性会话切换时的缓存计费标记
|
||||
|
||||
for {
|
||||
// 选择支持该模型的账号
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
|
||||
if err != nil {
|
||||
if len(failedAccountIDs) == 0 {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
|
||||
log.Printf("[Gateway] SelectAccount failed: %v", err)
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
if lastFailoverErr != nil {
|
||||
@@ -384,7 +415,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
interceptType := detectInterceptType(body, reqModel, parsedReq.MaxTokens, reqStream, isClaudeCodeClient)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
@@ -417,11 +448,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
@@ -433,13 +465,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("Account concurrency acquire failed: %v", err)
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil {
|
||||
log.Printf("Bind sticky session failed: %v", err)
|
||||
}
|
||||
@@ -454,7 +485,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount)
|
||||
}
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body)
|
||||
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
|
||||
} else {
|
||||
result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq)
|
||||
}
|
||||
@@ -501,6 +532,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
if errors.As(err, &failoverErr) {
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if failoverErr.ForceCacheBilling {
|
||||
forceCacheBilling = true
|
||||
}
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted)
|
||||
return
|
||||
@@ -519,22 +553,23 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
// 异步记录使用量(subscription已在函数开头获取)
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string) {
|
||||
go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
Result: result,
|
||||
APIKey: currentAPIKey,
|
||||
User: currentAPIKey.User,
|
||||
Account: usedAccount,
|
||||
Subscription: currentSubscription,
|
||||
UserAgent: ua,
|
||||
IPAddress: clientIP,
|
||||
ForceCacheBilling: fcb,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
log.Printf("Record usage failed: %v", err)
|
||||
}
|
||||
}(result, account, userAgent, clientIP)
|
||||
}(result, account, userAgent, clientIP, forceCacheBilling)
|
||||
return
|
||||
}
|
||||
if !retryWithFallback {
|
||||
@@ -917,6 +952,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
|
||||
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
|
||||
|
||||
// 验证 model 必填
|
||||
if parsedReq.Model == "" {
|
||||
@@ -943,7 +980,8 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
// 选择支持该模型的账号
|
||||
account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model)
|
||||
if err != nil {
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error())
|
||||
log.Printf("[Gateway] SelectAccountForModel failed: %v", err)
|
||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable")
|
||||
return
|
||||
}
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
@@ -960,13 +998,37 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
type InterceptType int
|
||||
|
||||
const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
InterceptTypeMaxTokensOneHaiku // max_tokens=1 + haiku 探测请求(返回 "#")
|
||||
)
|
||||
|
||||
// isHaikuModel 检查模型名称是否包含 "haiku"(大小写不敏感)
|
||||
func isHaikuModel(model string) bool {
|
||||
return strings.Contains(strings.ToLower(model), "haiku")
|
||||
}
|
||||
|
||||
// isMaxTokensOneHaikuRequest 检查是否为 max_tokens=1 + haiku 模型的探测请求
|
||||
// 这类请求用于 Claude Code 验证 API 连通性
|
||||
// 条件:max_tokens == 1 且 model 包含 "haiku" 且非流式请求
|
||||
func isMaxTokensOneHaikuRequest(model string, maxTokens int, isStream bool) bool {
|
||||
return maxTokens == 1 && isHaikuModel(model) && !isStream
|
||||
}
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 参数说明:
|
||||
// - body: 请求体字节
|
||||
// - model: 请求的模型名称
|
||||
// - maxTokens: max_tokens 值
|
||||
// - isStream: 是否为流式请求
|
||||
// - isClaudeCodeClient: 是否已通过 Claude Code 客户端校验
|
||||
func detectInterceptType(body []byte, model string, maxTokens int, isStream bool, isClaudeCodeClient bool) InterceptType {
|
||||
// 优先检查 max_tokens=1 + haiku 探测请求(仅非流式)
|
||||
if isClaudeCodeClient && isMaxTokensOneHaikuRequest(model, maxTokens, isStream) {
|
||||
return InterceptTypeMaxTokensOneHaiku
|
||||
}
|
||||
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
@@ -1116,9 +1178,25 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
|
||||
}
|
||||
}
|
||||
|
||||
// generateRealisticMsgID 生成仿真的消息 ID(msg_bdrk_XXXXXXX 格式)
|
||||
// 格式与 Claude API 真实响应一致,24 位随机字母数字
|
||||
func generateRealisticMsgID() string {
|
||||
const charset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
const idLen = 24
|
||||
randomBytes := make([]byte, idLen)
|
||||
if _, err := rand.Read(randomBytes); err != nil {
|
||||
return fmt.Sprintf("msg_bdrk_%d", time.Now().UnixNano())
|
||||
}
|
||||
b := make([]byte, idLen)
|
||||
for i := range b {
|
||||
b[i] = charset[int(randomBytes[i])%len(charset)]
|
||||
}
|
||||
return "msg_bdrk_" + string(b)
|
||||
}
|
||||
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var msgID, text, stopReason string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
@@ -1126,24 +1204,42 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
stopReason = "end_turn"
|
||||
case InterceptTypeMaxTokensOneHaiku:
|
||||
msgID = generateRealisticMsgID()
|
||||
text = "#"
|
||||
outputTokens = 1
|
||||
stopReason = "max_tokens" // max_tokens=1 探测请求的 stop_reason 应为 max_tokens
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
stopReason = "end_turn"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
// 构建完整的响应格式(与 Claude API 响应格式一致)
|
||||
response := gin.H{
|
||||
"model": model,
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": stopReason,
|
||||
"stop_sequence": nil,
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"input_tokens": 10,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cache_creation": gin.H{
|
||||
"ephemeral_5m_input_tokens": 0,
|
||||
"ephemeral_1h_input_tokens": 0,
|
||||
},
|
||||
"output_tokens": outputTokens,
|
||||
"total_tokens": 10 + outputTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
}
|
||||
|
||||
func billingErrorDetails(err error) (status int, code, message string) {
|
||||
@@ -1156,7 +1252,8 @@ func billingErrorDetails(err error) (status int, code, message string) {
|
||||
}
|
||||
msg := pkgerrors.Message(err)
|
||||
if msg == "" {
|
||||
msg = err.Error()
|
||||
log.Printf("[Gateway] billing error details: %v", err)
|
||||
msg = "Billing error"
|
||||
}
|
||||
return http.StatusForbidden, "billing_error", msg
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user