Files
sub2api/backend/internal/service/ratelimit_service_401_test.go
erio 3fa5b8bca5 fix: flaky WebSocket test, usage request queue, and test improvements
- Fix flaky WebSocket passthrough test: allow StatusNormalClosure after
  client close instead of requiring NoError (race condition fix)
- Fix ratelimit 401 test: use PlatformOpenAI instead of PlatformGemini
  for OAuth token cache invalidation scenario (more accurate)
- Add usageLoadQueue: Anthropic OAuth/setup-token accounts sharing the
  same proxy exit are serialized with 1-2s jitter to prevent upstream 429
- AccountUsageCell: add module-level usage cache (5min TTL), unmounted
  safety guard, and integrate enqueueUsageRequest for throttled fetching
2026-04-14 20:13:59 +08:00

163 lines
5.0 KiB
Go

//go:build unit
package service
import (
"context"
"errors"
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// rateLimitAccountRepoStub is an account repository test double. It embeds
// mockAccountRepoForGemini and defines SetError, SetTempUnschedulable and
// UpdateCredentials so tests can count calls and inspect the last arguments.
type rateLimitAccountRepoStub struct {
	mockAccountRepoForGemini

	// Call counters for SetError, SetTempUnschedulable and UpdateCredentials.
	setErrorCalls, tempCalls, updateCredentialsCalls int

	// lastCredentials is the cloneCredentials result for the credentials most
	// recently passed to UpdateCredentials.
	lastCredentials map[string]any
	// lastErrorMsg is the message most recently passed to SetError.
	lastErrorMsg string
}
// SetError records the call: it stores errorMsg as the most recent error
// message, increments the SetError call counter, and always returns nil.
// ctx and id are ignored.
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
	r.lastErrorMsg = errorMsg
	r.setErrorCalls++
	return nil
}
// SetTempUnschedulable only counts the call and always returns nil; none of
// its arguments (ctx, id, until, reason) are recorded.
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	r.tempCalls++
	return nil
}
// UpdateCredentials counts the call and keeps the cloneCredentials result for
// the given credentials in lastCredentials. It always returns nil; ctx and id
// are ignored.
func (r *rateLimitAccountRepoStub) UpdateCredentials(ctx context.Context, id int64, credentials map[string]any) error {
	r.updateCredentialsCalls++

	snapshot := cloneCredentials(credentials)
	r.lastCredentials = snapshot
	return nil
}
// tokenCacheInvalidatorRecorder is a test double that remembers every account
// passed to InvalidateToken and answers with a preconfigured error.
type tokenCacheInvalidatorRecorder struct {
	// accounts lists the accounts passed to InvalidateToken, in call order.
	accounts []*Account
	// err is returned by every InvalidateToken call; nil means success.
	err error
}
// InvalidateToken appends account to the recorded list and returns the
// recorder's configured err (nil unless the test set one). ctx is ignored.
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
	r.accounts = append(r.accounts, account)
	return r.err
}
func TestRateLimitService_HandleUpstreamError_OAuth401SetsTempUnschedulable(t *testing.T) {
t.Run("gemini", func(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformGemini,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"temp_unschedulable_enabled": true,
"temp_unschedulable_rules": []any{
map[string]any{
"error_code": 401,
"keywords": []any{"unauthorized"},
"duration_minutes": 30,
"description": "custom rule",
},
},
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Len(t, invalidator.accounts, 1)
})
t.Run("antigravity_401_uses_SetError", func(t *testing.T) {
// Antigravity 401 由 applyErrorPolicy 的 temp_unschedulable_rules 控制,
// HandleUpstreamError 中走 SetError 路径。
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 100,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Equal(t, 0, repo.tempCalls)
require.Empty(t, invalidator.accounts)
})
}
// TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError
// OpenAI OAuth 401 缓存失效出错时仍走 temp_unschedulable
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 101,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 0, repo.setErrorCalls)
require.Equal(t, 1, repo.tempCalls)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.Len(t, invalidator.accounts, 1)
}
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
invalidator := &tokenCacheInvalidatorRecorder{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
service.SetTokenCacheInvalidator(invalidator)
account := &Account{
ID: 102,
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.setErrorCalls)
require.Empty(t, invalidator.accounts)
}
func TestRateLimitService_HandleUpstreamError_OAuth401UsesCredentialsUpdater(t *testing.T) {
repo := &rateLimitAccountRepoStub{}
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
account := &Account{
ID: 103,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "token",
},
}
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
require.True(t, shouldDisable)
require.Equal(t, 1, repo.updateCredentialsCalls)
require.NotEmpty(t, repo.lastCredentials["expires_at"])
}