diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index b95e67c3..7e6b2f03 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -137,6 +137,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // 获取订阅信息(可能为nil)- 提前获取用于后续检查 subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index a43c5eb9..cc71e8e6 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -209,6 +209,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 1) user concurrency slot streamStarted := false + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { googleError(c, http.StatusTooManyRequests, err.Error()) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 1dcb163b..835297b8 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -149,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go new file mode 100644 index 00000000..65085d6f --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime.go @@ -0,0 +1,67 @@ +package service + +import "github.com/gin-gonic/gin" + +const errorPassthroughServiceContextKey = "error_passthrough_service" + +// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。 +func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) { + if c == nil || svc == nil { + return + } + c.Set(errorPassthroughServiceContextKey, svc) +} + +func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService { + if c == nil { + return nil + } + v, ok := c.Get(errorPassthroughServiceContextKey) + if !ok { + return nil + } + svc, ok := v.(*ErrorPassthroughService) + if !ok { + return nil + } + return svc +} + +// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。 +func applyErrorPassthroughRule( + c *gin.Context, + platform string, + upstreamStatus int, + responseBody []byte, + defaultStatus int, + defaultErrType string, + defaultErrMsg string, +) (status int, errType string, errMsg string, matched bool) { + status = defaultStatus + errType = defaultErrType + errMsg = defaultErrMsg + + svc := getBoundErrorPassthroughService(c) + if svc == nil { + return status, errType, errMsg, false + } + + rule := svc.MatchRule(platform, upstreamStatus, responseBody) + if rule == nil { + return status, errType, errMsg, false + } + + status = upstreamStatus + if !rule.PassthroughCode && rule.ResponseCode != nil { + status = *rule.ResponseCode + } + + errMsg = ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + errMsg = *rule.CustomMessage + } + + // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 + errType = "upstream_error" + return status, errType, errMsg, true +} diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go new file mode 100644 index 00000000..393e6e59 --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -0,0 +1,211 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusUnprocessableEntity, + []byte(`{"error":{"message":"invalid schema"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.False(t, matched) + assert.Equal(t, http.StatusBadGateway, status) + assert.Equal(t, "upstream_error", errType) + assert.Equal(t, "Upstream request failed", errMsg) +} + +func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-2", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "invalid_request_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "上游请求失败", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "OpenAI上游失败", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-1", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Gemini上游失败", errField["message"]) +} + +func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { + return &model.ErrorPassthroughRule{ + ID: 1, + Name: "non-failover-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{statusCode}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &respCode, + PassthroughBody: false, + CustomMessage: &customMessage, + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 7a029d49..250f7bff 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -2576,24 +2576,20 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 对于 Antigravity 平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { - // Antigravity 平台使用专门的模型支持检查 if strings.TrimSpace(requestedModel) == "" { return true } - if !IsAntigravityModelSupported(requestedModel) { + // 使用与转发阶段一致的映射逻辑:自定义映射优先 → 默认映射兜底 + mapped := mapAntigravityModel(account, requestedModel) + if mapped == "" { return false } - // 先用默认映射获取基础模型名,再应用 thinking 后缀 - defaultMapped, exists := domain.DefaultAntigravityModelMapping[requestedModel] - if !exists || defaultMapped == "" { - return false - } - finalModel := defaultMapped + // 应用 thinking 后缀后检查最终模型是否在账号映射中 if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { - finalModel = applyThinkingModelSuffix(finalModel, enabled) + finalModel := applyThinkingModelSuffix(mapped, enabled) + return account.IsModelSupported(finalModel) } - // 使用最终模型名检查 model_mapping 支持 - return account.IsModelSupported(finalModel) + return true } return s.isModelSupportedByAccount(account, requestedModel) } @@ -2601,15 +2597,10 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex // isModelSupportedByAccount 根据账户平台检查模型支持(无 context,用于非 Antigravity 平台) func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { - // Antigravity 应使用 isModelSupportedByAccountWithContext - // 这里作为兼容保留,使用原始模型名检查 if strings.TrimSpace(requestedModel) == "" { return true } - if !IsAntigravityModelSupported(requestedModel) { - return false - } - return account.IsModelSupported(requestedModel) + return mapAntigravityModel(account, requestedModel) != "" } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { @@ -3919,6 +3910,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ) } + // 非 failover 错误也支持错误透传规则匹配。 + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) var errType, errMsg string var statusCode int @@ -4050,6 +4069,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed after retries", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 返回统一的重试耗尽错误响应 c.JSON(http.StatusBadGateway, gin.H{ "type": "error", diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 964250d8..0f156c2e 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -362,7 +362,10 @@ func (s *GeminiMessagesCompatService) isBetterGeminiAccount(candidate, current * // isModelSupportedByAccount 根据账户平台检查模型支持 func (s *GeminiMessagesCompatService) isModelSupportedByAccount(account *Account, requestedModel string) bool { if account.Platform == PlatformAntigravity { - return IsAntigravityModelSupported(requestedModel) + if strings.TrimSpace(requestedModel) == "" { + return true + } + return mapAntigravityModel(account, requestedModel) != "" } return account.IsModelSupported(requestedModel) } @@ -1498,6 +1501,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg) + } + var statusCode int var errType, errMsg string diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index ae3106d2..fbe81cb4 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1087,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformOpenAI, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + // Check custom error codes if !account.ShouldHandleErrorCode(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{