diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index beaddbca..a8b7bd61 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -135,6 +135,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // 获取订阅信息(可能为nil)- 提前获取用于后续检查 subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index be634c0c..3e670378 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -207,6 +207,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 1) user concurrency slot streamStarted := false + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { googleError(c, http.StatusTooManyRequests, err.Error()) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 1dcb163b..835297b8 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -149,6 +149,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Track if we've started streaming (for error handling) streamStarted := false + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 + if h.errorPassthroughService != nil { + service.BindErrorPassthroughService(c, h.errorPassthroughService) + } + // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) diff --git a/backend/internal/service/error_passthrough_runtime.go b/backend/internal/service/error_passthrough_runtime.go new file mode 100644 index 00000000..65085d6f --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime.go @@ -0,0 +1,67 @@ +package service + +import "github.com/gin-gonic/gin" + +const errorPassthroughServiceContextKey = "error_passthrough_service" + +// BindErrorPassthroughService 将错误透传服务绑定到请求上下文,供 service 层在非 failover 场景下复用规则。 +func BindErrorPassthroughService(c *gin.Context, svc *ErrorPassthroughService) { + if c == nil || svc == nil { + return + } + c.Set(errorPassthroughServiceContextKey, svc) +} + +func getBoundErrorPassthroughService(c *gin.Context) *ErrorPassthroughService { + if c == nil { + return nil + } + v, ok := c.Get(errorPassthroughServiceContextKey) + if !ok { + return nil + } + svc, ok := v.(*ErrorPassthroughService) + if !ok { + return nil + } + return svc +} + +// applyErrorPassthroughRule 按规则改写错误响应;未命中时返回默认响应参数。 +func applyErrorPassthroughRule( + c *gin.Context, + platform string, + upstreamStatus int, + responseBody []byte, + defaultStatus int, + defaultErrType string, + defaultErrMsg string, +) (status int, errType string, errMsg string, matched bool) { + status = defaultStatus + errType = defaultErrType + errMsg = defaultErrMsg + + svc := getBoundErrorPassthroughService(c) + if svc == nil { + return status, errType, errMsg, false + } + + rule := svc.MatchRule(platform, upstreamStatus, responseBody) + if rule == nil { + return status, errType, errMsg, false + } + + status = upstreamStatus + if !rule.PassthroughCode && rule.ResponseCode != nil { + status = *rule.ResponseCode + } + + errMsg = ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + errMsg = *rule.CustomMessage + } + + // 与现有 failover 场景保持一致:命中规则时统一返回 upstream_error。 + errType = "upstream_error" + return status, errType, errMsg, true +} diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go new file mode 100644 index 00000000..393e6e59 --- /dev/null +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -0,0 +1,211 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestApplyErrorPassthroughRule_NoBoundService(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformAnthropic, + http.StatusUnprocessableEntity, + []byte(`{"error":{"message":"invalid schema"}}`), + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ) + + assert.False(t, matched) + assert.Equal(t, http.StatusBadGateway, status) + assert.Equal(t, "upstream_error", errType) + assert.Equal(t, "Upstream request failed", errMsg) +} + +func TestGatewayHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 11, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusBadGateway, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_NoRuleKeepsDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 13, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-2", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusBadRequest, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "invalid_request_error", errField["type"]) + assert.Equal(t, "Upstream request failed", errField["message"]) +} + +func TestGatewayHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "上游请求失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 1, Platform: PlatformAnthropic, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "上游请求失败", errField["message"]) +} + +func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "OpenAI上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &OpenAIGatewayService{} + respBody := []byte(`{"error":{"message":"Invalid schema for field messages"}}`) + resp := &http.Response{ + StatusCode: http.StatusUnprocessableEntity, + Body: io.NopCloser(bytes.NewReader(respBody)), + Header: http.Header{}, + } + account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "OpenAI上游失败", errField["message"]) +} + +func TestGeminiWriteGeminiMappedError_AppliesRuleFor422(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + ruleSvc := &ErrorPassthroughService{} + ruleSvc.setLocalCache([]*model.ErrorPassthroughRule{newNonFailoverPassthroughRule(http.StatusUnprocessableEntity, "invalid schema", http.StatusTeapot, "Gemini上游失败")}) + BindErrorPassthroughService(c, ruleSvc) + + svc := &GeminiMessagesCompatService{} + respBody := []byte(`{"error":{"code":422,"message":"Invalid schema for field messages","status":"INVALID_ARGUMENT"}}`) + account := &Account{ID: 3, Platform: PlatformGemini, Type: AccountTypeAPIKey} + + err := svc.writeGeminiMappedError(c, account, http.StatusUnprocessableEntity, "req-1", respBody) + require.Error(t, err) + assert.Equal(t, http.StatusTeapot, rec.Code) + + var payload map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &payload)) + errField, ok := payload["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errField["type"]) + assert.Equal(t, "Gemini上游失败", errField["message"]) +} + +func newNonFailoverPassthroughRule(statusCode int, keyword string, respCode int, customMessage string) *model.ErrorPassthroughRule { + return &model.ErrorPassthroughRule{ + ID: 1, + Name: "non-failover-rule", + Enabled: true, + Priority: 1, + ErrorCodes: []int{statusCode}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &respCode, + PassthroughBody: false, + CustomMessage: &customMessage, + } +} diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index 99dc70e3..c3e0f630 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -6,6 +6,7 @@ import ( "sort" "strings" "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/model" ) @@ -60,8 +61,11 @@ func NewErrorPassthroughService( // 启动时加载规则到本地缓存 ctx := context.Background() - if err := svc.refreshLocalCache(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err) + if err := svc.reloadRulesFromDB(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + } } // 订阅缓存更新通知 @@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return created, nil } @@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return updated, nil } @@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return nil } @@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { } } - // 从数据库加载(repo.List 已按 priority 排序) + return s.reloadRulesFromDB(ctx) +} + +// 从数据库加载(repo.List 已按 priority 排序) +// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。 +func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { rules, err := s.repo.List(ctx) if err != nil { return err @@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR s.localCacheMu.Unlock() } +// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。 +func (s *ErrorPassthroughService) clearLocalCache() { + s.localCacheMu.Lock() + s.localCache = nil + s.localCacheMu.Unlock() +} + +// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。 +func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 3*time.Second) +} + // invalidateAndNotify 使缓存失效并通知其他实例 func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { + // 先失效缓存,避免后续刷新读到陈旧规则。 + if s.cache != nil { + if err := s.cache.Invalidate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + } + } + // 刷新本地缓存 - if err := s.refreshLocalCache(ctx); err != nil { + if err := s.reloadRulesFromDB(ctx); err != nil { log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 + s.clearLocalCache() } // 通知其他实例 diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 205b4ec4..74c98d86 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -4,6 +4,7 @@ package service import ( "context" + "errors" "strings" "testing" @@ -14,14 +15,81 @@ import ( // mockErrorPassthroughRepo 用于测试的 mock repository type mockErrorPassthroughRepo struct { - rules []*model.ErrorPassthroughRule + rules []*model.ErrorPassthroughRule + listErr error + getErr error + createErr error + updateErr error + deleteErr error +} + +type mockErrorPassthroughCache struct { + rules []*model.ErrorPassthroughRule + hasData bool + getCalled int + setCalled int + invalidateCalled int + notifyCalled int +} + +func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache { + return &mockErrorPassthroughCache{ + rules: cloneRules(rules), + hasData: hasData, + } +} + +func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + m.getCalled++ + if !m.hasData { + return nil, false + } + return cloneRules(m.rules), true +} + +func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + m.setCalled++ + m.rules = cloneRules(rules) + m.hasData = true + return nil +} + +func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error { + m.invalidateCalled++ + m.rules = nil + m.hasData = false + return nil +} + +func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error { + m.notifyCalled++ + return nil +} + +func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + // 单测中无需订阅行为 +} + +func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule { + if rules == nil { + return nil + } + out := make([]*model.ErrorPassthroughRule, len(rules)) + copy(out, rules) + return out } func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + if m.listErr != nil { + return nil, m.listErr + } return m.rules, nil } func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + if m.getErr != nil { + return nil, m.getErr + } for _, r := range m.rules { if r.ID == id { return r, nil @@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode } func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.createErr != nil { + return nil, m.createErr + } rule.ID = int64(len(m.rules) + 1) m.rules = append(m.rules, rule) return rule, nil } func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.updateErr != nil { + return nil, m.updateErr + } for i, r := range m.rules { if r.ID == rule.ID { m.rules[i] = rule @@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error } func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { + if m.deleteErr != nil { + return m.deleteErr + } for i, r := range m.rules { if r.ID == id { m.rules = append(m.rules[:i], m.rules[i+1:]...) @@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { } } +// ============================================================================= +// 测试写路径缓存刷新(Create/Update/Delete) +// ============================================================================= + +func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败") + created, err := svc.Create(ctx, newRule) + require.NoError(t, err) + require.NotNil(t, created) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + require.NotNil(t, matched) + assert.Equal(t, created.ID, matched.ID) + if assert.NotNil(t, matched.CustomMessage) { + assert.Equal(t, "上游请求失败", *matched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule}) + + updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息") + _, err := svc.Update(ctx, updatedRule) + require.NoError(t, err) + + oldBody := []byte(`{"message":"old keyword"}`) + oldMatched := svc.MatchRule("anthropic", 503, oldBody) + assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中") + + newBody := []byte(`{"message":"new keyword"}`) + newMatched := svc.MatchRule("anthropic", 503, newBody) + require.NotNil(t, newMatched) + if assert.NotNil(t, newMatched.CustomMessage) { + assert.Equal(t, "新消息", *newMatched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + + err := svc.Delete(ctx, 1) + require.NoError(t, err) + + body := []byte(`{"message":"to be deleted"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "删除后规则不应再命中") + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) { + staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息") + latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息") + + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := NewErrorPassthroughService(repo, cache) + + matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`)) + require.NotNil(t, matchedFresh) + assert.Equal(t, int64(1), matchedFresh.ID) + + matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`)) + assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存") + + assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存") +} + +func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{ + rules: []*model.ErrorPassthroughRule{staleRule}, + listErr: errors.New("db list failed"), + } + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + disabledRule := *staleRule + disabledRule.Enabled = false + _, err := svc.Update(ctx, &disabledRule) + require.NoError(t, err) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则") + + svc.localCacheMu.RLock() + assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中") + svc.localCacheMu.RUnlock() +} + +func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule { + responseCode := 503 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: "write-path-cache-refresh", + Enabled: true, + Priority: 1, + ErrorCodes: []int{503}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + } + return rule +} + // Helper functions func testIntPtr(i int) *int { return &i } func testStrPtr(s string) *string { return &s } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 308f0f18..0256ac75 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -3563,6 +3563,34 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res ) } + // 非 failover 错误也支持错误透传规则匹配。 + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 根据状态码返回适当的自定义错误响应(不透传上游详细信息) var errType, errMsg string var statusCode int @@ -3694,6 +3722,33 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + account.Platform, + resp.StatusCode, + respBody, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed after retries", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + + summary := upstreamMsg + if summary == "" { + summary = errMsg + } + if summary == "" { + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (retries exhausted, passthrough rule matched) message=%s", resp.StatusCode, summary) + } + // 返回统一的重试耗尽错误响应 c.JSON(http.StatusBadGateway, gin.H{ "type": "error", diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index eecb88f6..75b69656 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -1498,6 +1498,28 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformGemini, + upstreamStatus, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "type": "error", + "error": gin.H{"type": errType, "message": errMsg}, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d (passthrough rule matched)", upstreamStatus) + } + return fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", upstreamStatus, upstreamMsg) + } + var statusCode int var errType, errMsg string diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 564ffa4d..52800f07 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1087,6 +1087,30 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht ) } + if status, errType, errMsg, matched := applyErrorPassthroughRule( + c, + PlatformOpenAI, + resp.StatusCode, + body, + http.StatusBadGateway, + "upstream_error", + "Upstream request failed", + ); matched { + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": errMsg, + }, + }) + if upstreamMsg == "" { + upstreamMsg = errMsg + } + if upstreamMsg == "" { + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched)", resp.StatusCode) + } + return nil, fmt.Errorf("upstream error: %d (passthrough rule matched) message=%s", resp.StatusCode, upstreamMsg) + } + // Check custom error codes if !account.ShouldHandleErrorCode(resp.StatusCode) { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{