fix(gateway): 完善 thinking block 重试和 cache nil 检查
- 使用 FilterThinkingBlocksForRetry 替代 FilterThinkingBlocks - count_tokens 增加 thinking block 签名错误重试 - cache nil 检查防止空指针 - shouldBill 逻辑修复避免重复扣费 - 移除 debug 日志
This commit is contained in:
@@ -541,7 +541,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
for _, item := range available {
|
for _, item := range available {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
@@ -577,7 +577,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
|||||||
for _, acc := range ordered {
|
for _, acc := range ordered {
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
@@ -708,7 +708,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
|||||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
@@ -781,7 +781,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
}
|
}
|
||||||
@@ -797,7 +797,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
preferOAuth := nativePlatform == PlatformGemini
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
|
|
||||||
// 1. 查询粘性会话
|
// 1. 查询粘性会话
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
if err == nil && accountID > 0 {
|
if err == nil && accountID > 0 {
|
||||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||||
@@ -873,7 +873,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4. 建立粘性绑定
|
// 4. 建立粘性绑定
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||||
}
|
}
|
||||||
@@ -1022,8 +1022,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
if s.isThinkingBlockSignatureError(respBody) {
|
if s.isThinkingBlockSignatureError(respBody) {
|
||||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||||
|
|
||||||
// 过滤thinking blocks并重试
|
// 过滤thinking blocks并重试(使用更激进的过滤)
|
||||||
filteredBody := FilterThinkingBlocks(body)
|
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||||
if buildErr == nil {
|
if buildErr == nil {
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
@@ -1303,10 +1303,13 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检测thinking block签名相关的错误
|
// 检测signature相关的错误(更宽松的匹配)
|
||||||
// 例如: "Invalid `signature` in `thinking` block"
|
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
|
||||||
return (strings.Contains(msg, "thinking") || strings.Contains(msg, "thought")) &&
|
if strings.Contains(msg, "signature") {
|
||||||
strings.Contains(msg, "signature")
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||||
@@ -1751,7 +1754,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
usageLog.SubscriptionID = &subscription.ID
|
usageLog.SubscriptionID = &subscription.ID
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||||
|
if err != nil {
|
||||||
log.Printf("Create usage log failed: %v", err)
|
log.Printf("Create usage log failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1761,10 +1765,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
shouldBill := inserted || err != nil
|
||||||
|
|
||||||
// 根据计费类型执行扣费
|
// 根据计费类型执行扣费
|
||||||
if isSubscriptionBilling {
|
if isSubscriptionBilling {
|
||||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||||
if cost.TotalCost > 0 {
|
if shouldBill && cost.TotalCost > 0 {
|
||||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||||
log.Printf("Increment subscription usage failed: %v", err)
|
log.Printf("Increment subscription usage failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1773,7 +1779,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||||
if cost.ActualCost > 0 {
|
if shouldBill && cost.ActualCost > 0 {
|
||||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||||
log.Printf("Deduct balance failed: %v", err)
|
log.Printf("Deduct balance failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1843,17 +1849,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
|||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||||
return fmt.Errorf("upstream request failed: %w", err)
|
return fmt.Errorf("upstream request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 读取响应体
|
// 读取响应体
|
||||||
respBody, err := io.ReadAll(resp.Body)
|
respBody, err := io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||||
|
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||||
|
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||||
|
|
||||||
|
filteredBody := FilterThinkingBlocks(body)
|
||||||
|
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||||
|
if buildErr == nil {
|
||||||
|
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if retryErr == nil {
|
||||||
|
resp = retryResp
|
||||||
|
respBody, err = io.ReadAll(resp.Body)
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 处理错误响应
|
// 处理错误响应
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
// 标记账号状态(429/529等)
|
// 标记账号状态(429/529等)
|
||||||
|
|||||||
Reference in New Issue
Block a user