perf(gateway): 优化负载感知调度

主要改进:
- 优化负载感知调度的准确性和响应速度
- 将 AccountUsageService 的包级缓存改为依赖注入
- 修复 SSE/JSON 转义和 nil 安全问题
- 恢复 Google One 功能兼容性
This commit is contained in:
ianshaw
2026-01-03 06:32:51 -08:00
parent 26106eb0ac
commit acb718d355
7 changed files with 369 additions and 310 deletions

View File

@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeAPIKey:
case AccountTypeApiKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
resp = retryResp
break
}
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
}
}
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
// We apply this for the main /v1/messages path as well as count_tokens.
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultAPIKeyBetaHeader(body []byte) string {
func defaultApiKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.APIKeyHaikuBetaHeader
return claude.ApiKeyHaikuBetaHeader
}
return claude.APIKeyBetaHeader
return claude.ApiKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block签名错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// 检测thinking block签名相关的错误
// 例如: "Invalid `signature` in `thinking` block"
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
strings.Contains(msg, "signature")
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
APIKey *APIKey
ApiKey *ApiKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.APIKey
apiKey := input.ApiKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
ApiKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算
// 参考 Antigravity-Manager 和 proxycast 实现
// Antigravity 账户不支持 count_tokens 转发,直接返回空
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
if account.Type == AccountTypeApiKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
}
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
body = FilterThinkingBlocks(body)
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
if err != nil {
return nil, err
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}