//go:build unit package service import ( "context" "errors" "net/http" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) type rateLimitAccountRepoStub struct { mockAccountRepoForGemini setErrorCalls int tempCalls int lastErrorMsg string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { r.setErrorCalls++ r.lastErrorMsg = errorMsg return nil } func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ return nil } type tokenCacheInvalidatorRecorder struct { accounts []*Account err error } func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { r.accounts = append(r.accounts, account) return r.err } func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) { tests := []struct { name string platform string }{ {name: "gemini", platform: PlatformGemini}, {name: "antigravity", platform: PlatformAntigravity}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 100, Platform: tt.platform, Type: AccountTypeOAuth, Credentials: map[string]any{ "temp_unschedulable_enabled": true, "temp_unschedulable_rules": []any{ map[string]any{ "error_code": 401, "keywords": []any{"unauthorized"}, "duration_minutes": 30, "description": "custom rule", }, }, }, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Equal(t, 0, repo.tempCalls) require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)") require.Len(t, invalidator.accounts, 1) }) } } func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 101, Platform: PlatformGemini, Type: AccountTypeOAuth, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Len(t, invalidator.accounts, 1) } func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) { repo := &rateLimitAccountRepoStub{} invalidator := &tokenCacheInvalidatorRecorder{} service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) service.SetTokenCacheInvalidator(invalidator) account := &Account{ ID: 102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, } shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized")) require.True(t, shouldDisable) require.Equal(t, 1, repo.setErrorCalls) require.Empty(t, invalidator.accounts) }