diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go new file mode 100644 index 00000000..2ce8793a --- /dev/null +++ b/backend/internal/service/gemini_error_policy_test.go @@ -0,0 +1,384 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestShouldFailoverGeminiUpstreamError — verifies the failover decision +// for the ErrorPolicyNone path (original logic preserved). +// --------------------------------------------------------------------------- + +func TestShouldFailoverGeminiUpstreamError(t *testing.T) { + svc := &GeminiMessagesCompatService{} + + tests := []struct { + name string + statusCode int + expected bool + }{ + {"401_failover", 401, true}, + {"403_failover", 403, true}, + {"429_failover", 429, true}, + {"529_failover", 529, true}, + {"500_failover", 500, true}, + {"502_failover", 502, true}, + {"503_failover", 503, true}, + {"400_no_failover", 400, false}, + {"404_no_failover", 404, false}, + {"422_no_failover", 422, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works +// correctly for Gemini platform accounts (API Key type). +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "gemini_apikey_custom_codes_hit", + account: &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 429, + body: []byte(`{"error":"rate limited"}`), + expected: ErrorPolicyMatched, + }, + { + name: "gemini_apikey_custom_codes_miss", + account: &Account{ + ID: 101, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicySkipped, + }, + { + name: "gemini_apikey_no_custom_codes_returns_none", + account: &Account{ + ID: 102, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicyNone, + }, + { + name: "gemini_apikey_temp_unschedulable_hit", + account: &Account{ + ID: 103, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "gemini_custom_codes_override_temp_unschedulable", + account: &Account{ + ID: 104, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling +// paths produce the correct behavior for each ErrorPolicyResult. +// +// These tests simulate the inline error policy switch in handleClaudeCompat +// and forwardNativeGemini by calling the same methods in the same order. +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicyIntegration(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + account *Account + statusCode int + respBody []byte + expectFailover bool // expect UpstreamFailoverError + expectHandleError bool // expect handleGeminiUpstreamError to be called + expectShouldFailover bool // for None path, whether shouldFailover triggers + }{ + { + name: "custom_codes_matched_429_failover", + account: &Account{ + ID: 200, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "custom_codes_skipped_500_no_failover", + account: &Account{ + ID: 201, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + respBody: []byte(`{"error":"internal"}`), + expectFailover: false, + expectHandleError: false, + }, + { + name: "temp_unschedulable_matched_failover", + account: &Account{ + ID: 202, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + respBody: []byte(`overloaded`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "no_policy_429_failover_via_shouldFailover", + account: &Account{ + ID: 203, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + expectShouldFailover: true, + }, + { + name: "no_policy_400_no_failover", + account: &Account{ + ID: 204, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 400, + respBody: []byte(`{"error":"bad request"}`), + expectFailover: false, + expectHandleError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &geminiErrorPolicyRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + rateLimitService: rlSvc, + } + + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // Simulate the Claude compat error handling path (same logic as native). + // This mirrors the inline switch in handleClaudeCompat. + var handleErrorCalled bool + var gotFailover bool + + ctx := context.Background() + statusCode := tt.statusCode + respBody := tt.respBody + account := tt.account + headers := http.Header{} + + if svc.rateLimitService != nil { + switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) { + case ErrorPolicySkipped: + // Skipped → return error directly (no handleGeminiUpstreamError, no failover) + gotFailover = false + handleErrorCalled = false + goto verify + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + gotFailover = true + goto verify + } + } + + // ErrorPolicyNone → original logic + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + if svc.shouldFailoverGeminiUpstreamError(statusCode) { + gotFailover = true + } + + verify: + require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch") + require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch") + + if tt.expectShouldFailover { + require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode), + "shouldFailoverGeminiUpstreamError should return true for status %d", statusCode) + } + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) { + svc := &GeminiMessagesCompatService{ + rateLimitService: nil, + } + + // When rateLimitService is nil, error policy is skipped → falls through to + // shouldFailoverGeminiUpstreamError (original logic). + // Verify this doesn't panic and follows expected behavior. + + ctx := context.Background() + account := &Account{ + ID: 300, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + } + + // The nil check should prevent CheckErrorPolicy from being called + if svc.rateLimitService != nil { + t.Fatal("rateLimitService should be nil for this test") + } + + // shouldFailoverGeminiUpstreamError still works + require.True(t, svc.shouldFailoverGeminiUpstreamError(429)) + require.False(t, svc.shouldFailoverGeminiUpstreamError(400)) + + // handleGeminiUpstreamError should not panic with nil rateLimitService + require.NotPanics(t, func() { + svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`)) + }) +} + +// --------------------------------------------------------------------------- +// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error +// policy tests. Embeds mockAccountRepoForGemini and adds tracking. +// --------------------------------------------------------------------------- + +type geminiErrorPolicyRepo struct { + mockAccountRepoForGemini + setErrorCalls int + setRateLimitedCalls int + setTempCalls int +} + +func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrorCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error { + r.setRateLimitedCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.setTempCalls++ + return nil +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 4e0442fd..d77f6f92 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -831,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false + // 统一错误策略:自定义错误码 + 临时不可调度 if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - if tempMatched { - upstreamReqID := resp.Header.Get(requestIDHeader) - if upstreamReqID == "" { - upstreamReqID = resp.Header.Get("x-goog-request-id") - } - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") } - upstreamDetail = truncateString(string(respBody), maxBytes) + return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: upstreamReqID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1249,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false - if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // This avoids Gemini SDKs failing hard during preflight token counting. + // Checked before error policy so it always works regardless of custom error codes. if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) @@ -1270,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } - if tempMatched { - evBody := unwrapIfNeeded(isOAuth, respBody) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" } - upstreamDetail = truncateString(string(evBody), maxBytes) + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody))