From 7eda43c99ee14d2972263aeee0e21e664753ae49 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Sat, 3 Jan 2026 17:10:25 -0800 Subject: [PATCH] =?UTF-8?q?fix(gateway):=20=E5=AE=8C=E5=96=84=20thinking?= =?UTF-8?q?=20block=20=E9=87=8D=E8=AF=95=E5=92=8C=20cache=20nil=20?= =?UTF-8?q?=E6=A3=80=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 使用 FilterThinkingBlocksForRetry 替代 FilterThinkingBlocks - count_tokens 增加 thinking block 签名错误重试 - cache nil 检查防止空指针 - shouldBill 逻辑修复避免重复扣费 - 移除 debug 日志 --- backend/internal/service/gateway_service.go | 60 ++++++++++++++------- 1 file changed, 42 insertions(+), 18 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9a8ffd33..0a20fd88 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -541,7 +541,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) } return &AccountSelectionResult{ @@ -577,7 +577,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates for _, acc := range ordered { result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) } return &AccountSelectionResult{ @@ -708,7 +708,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { @@ -781,7 +781,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // 4. 建立粘性绑定 - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } @@ -797,7 +797,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g preferOAuth := nativePlatform == PlatformGemini // 1. 查询粘性会话 - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { @@ -873,7 +873,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // 4. 建立粘性绑定 - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } @@ -1022,8 +1022,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if s.isThinkingBlockSignatureError(respBody) { log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) - // 过滤thinking blocks并重试 - filteredBody := FilterThinkingBlocks(body) + // 过滤thinking blocks并重试(使用更激进的过滤) + filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) if buildErr == nil { retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) @@ -1303,10 +1303,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { return false } - // 检测thinking block签名相关的错误 - // 例如: "Invalid `signature` in `thinking` block" - return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) && - strings.Contains(msg, "signature") + // 检测signature相关的错误(更宽松的匹配) + // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 + if strings.Contains(msg, "signature") { + return true + } + + return false } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { @@ -1751,7 +1754,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - if err := s.usageLogRepo.Create(ctx, usageLog); err != nil { + inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { log.Printf("Create usage log failed: %v", err) } @@ -1761,10 +1765,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } + shouldBill := inserted || err != nil + // 根据计费类型执行扣费 if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) - if cost.TotalCost > 0 { + if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { log.Printf("Increment subscription usage failed: %v", err) } @@ -1773,7 +1779,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) - if cost.ActualCost > 0 { + if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { log.Printf("Deduct balance failed: %v", err) } @@ -1843,17 +1849,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") return fmt.Errorf("upstream request failed: %w", err) } - defer func() { - _ = resp.Body.Close() - }() // 读取响应体 respBody, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() if err != nil { s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } + // 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks) + if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { + log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + + filteredBody := FilterThinkingBlocks(body) + retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel) + if buildErr == nil { + retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + if retryErr == nil { + resp = retryResp + respBody, err = io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + } + } + } + // 处理错误响应 if resp.StatusCode >= 400 { // 标记账号状态(429/529等)