fix(认证): 修复 OAuth token 缓存失效与 401 处理
新增 token 缓存失效接口并在刷新后清理 401 限流支持自定义规则与可配置冷却时间 补齐缓存失效与 401 处理测试 测试: make test
This commit is contained in:
353
backend/internal/service/ratelimit_service_401_test.go
Normal file
353
backend/internal/service/ratelimit_service_401_test.go
Normal file
@@ -0,0 +1,353 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type rateLimitAccountRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
tempCalls int
|
||||
tempUntil time.Time
|
||||
tempReason string
|
||||
setErrorCalls int
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempCalls++
|
||||
r.tempUntil = until
|
||||
r.tempReason = reason
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
type tokenCacheInvalidatorRecorder struct {
|
||||
accounts []*Account
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
|
||||
r.accounts = append(r.accounts, account)
|
||||
return r.err
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401TempUnschedulable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
platform string
|
||||
}{
|
||||
{name: "gemini", platform: PlatformGemini},
|
||||
{name: "antigravity", platform: PlatformAntigravity},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: tt.platform,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
require.WithinDuration(t, start.Add(5*time.Minute), repo.tempUntil, 10*time.Second)
|
||||
require.NotEmpty(t, repo.tempReason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_OAuth401CustomRule(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"temp_unschedulable_enabled": true,
|
||||
"temp_unschedulable_rules": []any{
|
||||
map[string]any{
|
||||
"error_code": 401,
|
||||
"keywords": []any{"unauthorized"},
|
||||
"duration_minutes": 30,
|
||||
"description": "custom rule",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Equal(t, 0, repo.setErrorCalls)
|
||||
require.Len(t, invalidator.accounts, 1)
|
||||
require.WithinDuration(t, start.Add(30*time.Minute), repo.tempUntil, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
|
||||
|
||||
require.True(t, shouldDisable)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
require.Equal(t, 1, repo.setErrorCalls)
|
||||
require.Empty(t, invalidator.accounts)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_NilAccount 测试 account 为 nil 的情况
|
||||
func TestRateLimitService_HandleOAuth401_NilAccount(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), nil, "error")
|
||||
|
||||
require.False(t, result)
|
||||
require.Equal(t, 0, repo.tempCalls)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_NilInvalidator 测试 tokenCacheInvalidator 为 nil 的情况
|
||||
func TestRateLimitService_HandleOAuth401_NilInvalidator(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
// 不设置 tokenCacheInvalidator
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||
|
||||
require.True(t, result)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed 测试 SetTempUnschedulable 失败的情况
|
||||
func TestRateLimitService_HandleOAuth401_SetTempUnschedulableFailed(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStubWithError{
|
||||
setTempErr: errors.New("db error"),
|
||||
}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 201,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||
|
||||
require.False(t, result) // 失败应返回 false
|
||||
require.Len(t, invalidator.accounts, 1) // 但 invalidator 仍然被调用
|
||||
}
|
||||
|
||||
// rateLimitAccountRepoStubWithError 支持返回错误的 stub
|
||||
type rateLimitAccountRepoStubWithError struct {
|
||||
mockAccountRepoForGemini
|
||||
setTempErr error
|
||||
setErrorCalls int
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStubWithError) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return r.setTempErr
|
||||
}
|
||||
|
||||
func (r *rateLimitAccountRepoStubWithError) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls++
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_WithTempUnschedCache 测试 tempUnschedCache 存在的情况
|
||||
func TestRateLimitService_HandleOAuth401_WithTempUnschedCache(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
tempCache := &tempUnschedCacheStub{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 202,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||
|
||||
require.True(t, result)
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
require.Equal(t, 1, tempCache.setCalls)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_TempUnschedCacheError 测试 tempUnschedCache 设置失败的情况
|
||||
func TestRateLimitService_HandleOAuth401_TempUnschedCacheError(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
invalidator := &tokenCacheInvalidatorRecorder{}
|
||||
tempCache := &tempUnschedCacheStub{setErr: errors.New("cache error")}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, tempCache)
|
||||
service.SetTokenCacheInvalidator(invalidator)
|
||||
account := &Account{
|
||||
ID: 203,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||
|
||||
require.True(t, result) // 缓存错误不影响主流程
|
||||
require.Equal(t, 1, repo.tempCalls)
|
||||
}
|
||||
|
||||
// tempUnschedCacheStub 用于测试的 TempUnschedCache stub
|
||||
type tempUnschedCacheStub struct {
|
||||
setCalls int
|
||||
setErr error
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheStub) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheStub) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error {
|
||||
c.setCalls++
|
||||
return c.setErr
|
||||
}
|
||||
|
||||
func (c *tempUnschedCacheStub) DeleteTempUnsched(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestRateLimitService_OAuth401Cooldown 测试 oauth401Cooldown 函数
|
||||
func TestRateLimitService_OAuth401Cooldown(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
expected time.Duration
|
||||
}{
|
||||
{
|
||||
name: "default_when_config_zero",
|
||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 0}},
|
||||
expected: 5 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "custom_cooldown_10_minutes",
|
||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 10}},
|
||||
expected: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "custom_cooldown_1_minute",
|
||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: 1}},
|
||||
expected: 1 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "negative_value_uses_default",
|
||||
cfg: &config.Config{RateLimit: config.RateLimitConfig{OAuth401CooldownMinutes: -5}},
|
||||
expected: 5 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
service := NewRateLimitService(nil, nil, tt.cfg, nil, nil)
|
||||
result := service.oauth401Cooldown()
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestRateLimitService_OAuth401Cooldown_NilConfig 测试 cfg 为 nil 的情况
|
||||
func TestRateLimitService_OAuth401Cooldown_NilConfig(t *testing.T) {
|
||||
service := &RateLimitService{cfg: nil}
|
||||
result := service.oauth401Cooldown()
|
||||
require.Equal(t, 5*time.Minute, result)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_WithCustomCooldown 测试自定义 cooldown 配置
|
||||
func TestRateLimitService_HandleOAuth401_WithCustomCooldown(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
cfg := &config.Config{
|
||||
RateLimit: config.RateLimitConfig{
|
||||
OAuth401CooldownMinutes: 15,
|
||||
},
|
||||
}
|
||||
service := NewRateLimitService(repo, nil, cfg, nil, nil)
|
||||
account := &Account{
|
||||
ID: 204,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "error")
|
||||
|
||||
require.True(t, result)
|
||||
require.WithinDuration(t, start.Add(15*time.Minute), repo.tempUntil, 10*time.Second)
|
||||
}
|
||||
|
||||
// TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg 测试 upstreamMsg 为空的情况
|
||||
func TestRateLimitService_HandleOAuth401_EmptyUpstreamMsg(t *testing.T) {
|
||||
repo := &rateLimitAccountRepoStub{}
|
||||
service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
|
||||
account := &Account{
|
||||
ID: 205,
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
result := service.handleOAuth401TempUnschedulable(context.Background(), account, "")
|
||||
|
||||
require.True(t, result)
|
||||
require.Contains(t, repo.tempReason, "Authentication failed (401)")
|
||||
}
|
||||
Reference in New Issue
Block a user