This commit is contained in:
yangjianbo
2026-01-04 21:06:12 +08:00
183 changed files with 8275 additions and 3879 deletions

View File

@@ -547,7 +547,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -583,7 +583,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -714,7 +714,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -787,7 +787,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -803,7 +803,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -879,7 +879,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -911,7 +911,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -1049,7 +1049,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1086,8 +1086,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试使用更激进的过滤
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
}
resp = retryResp
break
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1100,6 +1137,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
log.Printf("[DEBUG] %s: %v", k, v)
}
}
break
}
defer func() { _ = resp.Body.Close() }()
@@ -1123,7 +1167,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
@@ -1183,7 +1227,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
@@ -1253,10 +1297,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1323,12 +1367,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1345,6 +1389,41 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误更宽松的匹配
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
log.Printf("[SignatureCheck] Detected signature error")
return true
}
// 检测 thinking block 顺序/类型错误
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block type error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
log.Printf("[SignatureCheck] Detected empty content error")
return true
}
return false
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1393,7 +1472,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1783,7 +1868,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1792,7 +1877,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1829,7 +1914,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1859,7 +1944,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1869,10 +1955,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
@@ -1881,7 +1969,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
@@ -1914,7 +2002,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1951,17 +2039,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 检测 thinking block 签名错误400并重试一次过滤 thinking blocks
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
}
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
@@ -2000,7 +2106,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
if baseURL != "" {
validatedURL, err := s.validateUpstreamBaseURL(baseURL)
@@ -2065,10 +2171,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}