perf(gateway): 优化负载感知调度
主要改进: - 优化负载感知调度的准确性和响应速度 - 将 AccountUsageService 的包级缓存改为依赖注入 - 修复 SSE/JSON 转义和 nil 安全问题 - 恢复 Google One 功能兼容性
This commit is contained in:
@@ -905,7 +905,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case AccountTypeAPIKey:
|
||||
case AccountTypeApiKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -976,7 +976,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
@@ -1013,8 +1013,37 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要重试
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
// 优先检测thinking block签名错误(400)并重试一次
|
||||
if resp.StatusCode == 400 {
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
// 过滤thinking blocks并重试
|
||||
filteredBody := FilterThinkingBlocks(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// 使用重试后的响应,继续后续处理
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
}
|
||||
// 重试失败,恢复原始响应体继续处理
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetries {
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||
@@ -1047,13 +1076,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
@@ -1110,7 +1139,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -1136,10 +1165,6 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
}
|
||||
|
||||
// Filter thinking blocks from request body (prevents 400 errors from missing/invalid signatures).
|
||||
// We apply this for the main /v1/messages path as well as count_tokens.
|
||||
body = FilterThinkingBlocks(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1178,10 +1203,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
@@ -1248,12 +1273,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.APIKeyHaikuBetaHeader
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.APIKeyBetaHeader
|
||||
return claude.ApiKeyBetaHeader
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
@@ -1270,6 +1295,20 @@ func truncateForLog(b []byte, maxBytes int) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// isThinkingBlockSignatureError 检测是否是thinking block签名错误
|
||||
// 这类错误可以通过过滤thinking blocks并重试来解决
|
||||
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检测thinking block签名相关的错误
|
||||
// 例如: "Invalid `signature` in `thinking` block"
|
||||
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
|
||||
strings.Contains(msg, "signature")
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
@@ -1318,7 +1357,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
||||
var errType, errMsg string
|
||||
@@ -1630,7 +1675,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
APIKey *APIKey
|
||||
ApiKey *ApiKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
@@ -1639,7 +1684,7 @@ type RecordUsageInput struct {
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.APIKey
|
||||
apiKey := input.ApiKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -1676,7 +1721,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1754,15 +1799,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
body := parsed.Body
|
||||
reqModel := parsed.Model
|
||||
|
||||
// Antigravity 账户不支持 count_tokens 转发,返回估算值
|
||||
// 参考 Antigravity-Manager 和 proxycast 实现
|
||||
// Antigravity 账户不支持 count_tokens 转发,直接返回空值
|
||||
if account.Platform == PlatformAntigravity {
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
|
||||
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
|
||||
return nil
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -1848,7 +1892,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if account.Type == AccountTypeApiKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
@@ -1866,9 +1910,6 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
}
|
||||
|
||||
// Filter thinking blocks from request body (prevents 400 errors from invalid signatures)
|
||||
body = FilterThinkingBlocks(body)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1910,10 +1951,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user