fix: restore non-failover error passthrough from 7b156489

This commit is contained in:
erio
2026-02-07 14:24:55 +08:00
parent 43a4840daf
commit edb0937024
8 changed files with 404 additions and 18 deletions

View File

@@ -137,6 +137,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// 获取订阅信息可能为nil- 提前获取用于后续检查
subscription, _ := middleware2.GetSubscriptionFromContext(c)

View File

@@ -209,6 +209,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
// 1) user concurrency slot
streamStarted := false
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted)
if err != nil {
googleError(c, http.StatusTooManyRequests, err.Error())

View File

@@ -149,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Track if we've started streaming (for error handling)
streamStarted := false
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
if h.errorPassthroughService != nil {
service.BindErrorPassthroughService(c, h.errorPassthroughService)
}
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)

View File

@@ -0,0 +1,67 @@
package service
import "github.com/gin-gonic/gin"
const errorPassthroughServiceContextKey = "error_passthrough_service"
// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。
func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) {
if c == nil || svc == nil {
return
}
c.Set(errorPassthroughServiceContextKey, svc)
}
func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService {
if c == nil {
return nil
}
v, ok := c.Get(errorPassthroughServiceContextKey)
if !ok {
return nil
}
svc, ok := v.(*ErrorPassthroughService)
if !ok {
return nil
}
return svc
}
// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。
func applyErrorPassthroughRule(
c *gin.Context,
platform string,
upstreamStatus int,
responseBody []byte,
defaultStatus int,
defaultErrType string,
defaultErrMsg string,
) (status int, errType string, errMsg string, matched bool) {
status = defaultStatus
errType = defaultErrType
errMsg = defaultErrMsg
svc := getBoundErrorPassthroughService(c)
if svc == nil {
return status, errType, errMsg, false
}
rule := svc.MatchRule(platform, upstreamStatus, responseBody)
if rule == nil {
return status, errType, errMsg, false
}
status = upstreamStatus
if !rule.PassthroughCode && rule.ResponseCode != nil {
status = *rule.ResponseCode
}
errMsg = ExtractUpstreamErrorMessage(responseBody)
if !rule.PassthroughBody && rule.CustomMessage != nil {
errMsg = *rule.CustomMessage
}
// 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。
errType = "upstream_error"
return status, errType, errMsg, true
}

View File

@@ -0,0 +1,211 @@
package service
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformAnthropic,
http.StatusUnprocessableEntity,
[]byte(`{"error":{"message":"invalid schema"}}`),
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
)
assert.False(t, matched)
assert.Equal(t, http.StatusBadGateway, status)
assert.Equal(t, "upstream_error", errType)
assert.Equal(t, "Upstream request failed", errMsg)
}
func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
svc := &GatewayService{}
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
resp := &http.Response{
StatusCode: http.StatusUnprocessableEntity,
Body: io.NopCloser(bytes.NewReader(respBody)),
Header: http.Header{},
}
account := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
require.Error(t, err)
assert.Equal(t, http.StatusBadGateway, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errField["type"])
assert.Equal(t, "Upstream request failed", errField["message"])
}
func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
svc := &OpenAIGatewayService{}
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
resp := &http.Response{
StatusCode: http.StatusUnprocessableEntity,
Body: io.NopCloser(bytes.NewReader(respBody)),
Header: http.Header{},
}
account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
require.Error(t, err)
assert.Equal(t, http.StatusBadGateway, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errField["type"])
assert.Equal(t, "Upstream request failed", errField["message"])
}
func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
svc := &GeminiMessagesCompatService{}
respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
account := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey}
err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-2", respBody)
require.Error(t, err)
assert.Equal(t, http.StatusBadRequest, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "invalid_request_error", errField["type"])
assert.Equal(t, "Upstream request failed", errField["message"])
}
func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")})
BindErrorPassthroughService(c, ruleSvc)
svc := &GatewayService{}
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
resp := &http.Response{
StatusCode: http.StatusUnprocessableEntity,
Body: io.NopCloser(bytes.NewReader(respBody)),
Header: http.Header{},
}
account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
require.Error(t, err)
assert.Equal(t, http.StatusTeapot, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errField["type"])
assert.Equal(t, "上游请求失败", errField["message"])
}
func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")})
BindErrorPassthroughService(c, ruleSvc)
svc := &OpenAIGatewayService{}
respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`)
resp := &http.Response{
StatusCode: http.StatusUnprocessableEntity,
Body: io.NopCloser(bytes.NewReader(respBody)),
Header: http.Header{},
}
account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
_, err := svc.handleErrorResponse(context.Background(), resp, c, account)
require.Error(t, err)
assert.Equal(t, http.StatusTeapot, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errField["type"])
assert.Equal(t, "OpenAI上游失败", errField["message"])
}
func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
ruleSvc := &ErrorPassthroughService{}
ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")})
BindErrorPassthroughService(c, ruleSvc)
svc := &GeminiMessagesCompatService{}
respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`)
account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey}
err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-1", respBody)
require.Error(t, err)
assert.Equal(t, http.StatusTeapot, rec.Code)
var payload map[string]any
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload))
errField, ok := payload["error"].(map[string]any)
require.True(t, ok)
assert.Equal(t, "upstream_error", errField["type"])
assert.Equal(t, "Gemini上游失败", errField["message"])
}
func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule {
return &model.ErrorPassthroughRule{
ID: 1,
Name: "non-failover-rule",
Enabled: true,
Priority: 1,
ErrorCodes: []int{statusCode},
Keywords: []string{keyword},
MatchMode: model.MatchModeAll,
PassthroughCode: false,
ResponseCode: &respCode,
PassthroughBody: false,
CustomMessage: &customMessage,
}
}

View File

@@ -2576,24 +2576,20 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
// 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持
func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 平台使用专门的模型支持检查
if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
// 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底
mapped := mapAntigravityModel(account, requestedModel)
if mapped == "" {
return false
}
// 先用默认映射获取基础模型名,再应用 thinking 后缀
defaultMapped, exists := domain.DefaultAntigravityModelMapping[requestedModel]
if !exists || defaultMapped == "" {
return false
}
finalModel := defaultMapped
// 应用 thinking 后缀后检查最终模型是否在账号映射中
if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
finalModel = applyThinkingModelSuffix(finalModel, enabled)
finalModel := applyThinkingModelSuffix(mapped, enabled)
return account.IsModelSupported(finalModel)
}
// 使用最终模型名检查 model_mapping 支持
return account.IsModelSupported(finalModel)
return true
}
return s.isModelSupportedByAccount(account, requestedModel)
}
@@ -2601,15 +2597,10 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex
// isModelSupportedByAccount 根据账户平台检查模型支持(无 context用于非 Antigravity 平台)
func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
// Antigravity 应使用 isModelSupportedByAccountWithContext
// 这里作为兼容保留,使用原始模型名检查
if strings.TrimSpace(requestedModel) == "" {
return true
}
if !IsAntigravityModelSupported(requestedModel) {
return false
}
return account.IsModelSupported(requestedModel)
return mapAntigravityModel(account, requestedModel) != ""
}
// OAuth/SetupToken 账号使用 Anthropic 标准映射短ID → 长ID
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
@@ -3919,6 +3910,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
)
}
// 非 failover 错误也支持错误透传规则匹配。
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
var statusCode int
@@ -4050,6 +4069,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
account.Platform,
resp.StatusCode,
respBody,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed after retries",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
summary := upstreamMsg
if summary == "" {
summary = errMsg
}
if summary == "" {
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary)
}
// 返回统一的重试耗尽错误响应
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",

View File

@@ -362,7 +362,10 @@ func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current *
// isModelSupportedByAccount 根据账户平台检查模型支持
func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool {
if account.Platform == PlatformAntigravity {
return IsAntigravityModelSupported(requestedModel)
if strings.TrimSpace(requestedModel) == "" {
return true
}
return mapAntigravityModel(account, requestedModel) != ""
}
return account.IsModelSupported(requestedModel)
}
@@ -1498,6 +1501,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes))
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformGemini,
upstreamStatus,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"type": "error",
"error": gin.H{"type": errType, "message": errMsg},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus)
}
return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg)
}
var statusCode int
var errType, errMsg string

View File

@@ -1087,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
)
}
if status, errType, errMsg, matched := applyErrorPassthroughRule(
c,
PlatformOpenAI,
resp.StatusCode,
body,
http.StatusBadGateway,
"upstream_error",
"Upstream request failed",
); matched {
c.JSON(status, gin.H{
"error": gin.H{
"type": errType,
"message": errMsg,
},
})
if upstreamMsg == "" {
upstreamMsg = errMsg
}
if upstreamMsg == "" {
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode)
}
return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg)
}
// Check custom error codes
if !account.ShouldHandleErrorCode(resp.StatusCode) {
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{