修改403逻辑: 先临时冷却,再根据连续次数决定是否判坏号

This commit is contained in:
wx-11
2026-04-23 12:58:13 +08:00
parent eea6f38881
commit 11cf23da7d
11 changed files with 370 additions and 17 deletions

View File

@@ -0,0 +1,51 @@
package repository
import (
"context"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// openAI403CounterPrefix is the Redis key prefix for per-account 403 counters;
// the numeric account ID is appended to it to form the full key.
const openAI403CounterPrefix = "openai_403_count:account:"
// openAI403CounterIncrScript increments the counter stored at KEYS[1] and
// returns the post-increment value. The expiry (ARGV[1], in seconds) is set
// only when the increment produced 1, i.e. when the key was just created, so
// later increments do not extend the TTL: the window is anchored to the first
// hit rather than sliding.
var openAI403CounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
// openAI403CounterCache is the Redis-backed implementation of
// service.OpenAI403CounterCache.
type openAI403CounterCache struct {
// rdb is the Redis client used for the increment script and key deletion.
rdb *redis.Client
}
// NewOpenAI403CounterCache returns a service.OpenAI403CounterCache that stores
// its counters in the given Redis client.
func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache {
return &openAI403CounterCache{rdb: rdb}
}
// IncrementOpenAI403Count bumps the 403 counter of accountID by one and
// returns the new value.
//
// windowMinutes is converted to seconds and passed to the increment script as
// the TTL to apply when the counter is first created; values that would yield
// less than 60 seconds are raised to 60. Any script failure is wrapped and
// returned together with a zero count.
func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
	// Floor applied to the window so the counter never expires almost immediately.
	const minTTLSeconds = 60

	ttl := windowMinutes * 60
	if ttl < minTTLSeconds {
		ttl = minTTLSeconds
	}

	redisKey := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
	cmd := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{redisKey}, ttl)

	count, err := cmd.Int64()
	if err != nil {
		return 0, fmt.Errorf("increment openai 403 count: %w", err)
	}
	return count, nil
}
// ResetOpenAI403Count deletes the 403 counter key of accountID and returns
// the error reported by the Redis DEL command, if any.
func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error {
	redisKey := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID)
	delCmd := c.rdb.Del(ctx, redisKey)
	return delCmd.Err()
}

View File

@@ -96,6 +96,7 @@ var ProviderSet = wire.NewSet(
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
NewOpenAI403CounterCache,
NewInternal500CounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,

View File

@@ -0,0 +1,11 @@
package service
import "context"
// OpenAI403CounterCache tracks the number of consecutive 403 failures of an
// OpenAI account.
type OpenAI403CounterCache interface {
// IncrementOpenAI403Count atomically increments the 403 count for accountID
// and returns the current value. windowMinutes is the counting window in
// minutes (the Redis implementation uses it as the key TTL).
IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
// ResetOpenAI403Count clears the counter for accountID; it is meant to be
// called after a success.
ResetOpenAI403Count(ctx context.Context, accountID int64) error
}

View File

@@ -0,0 +1,39 @@
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
// openAI403CounterResetStub is a test double for OpenAI403CounterCache that
// only records which account IDs were reset.
type openAI403CounterResetStub struct {
	// resetCalls holds the account IDs passed to ResetOpenAI403Count, in call order.
	resetCalls []int64
}

// IncrementOpenAI403Count is a no-op that always reports a zero count and no error.
func (stub *openAI403CounterResetStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	var count int64
	return count, nil
}

// ResetOpenAI403Count records accountID and always succeeds.
func (stub *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	recorded := append(stub.resetCalls, accountID)
	stub.resetCalls = recorded
	return nil
}
// TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn
// checks that RecordUsage resets the account's 403 counter even when the
// forwarded result carries no usage at all (an empty OpenAIForwardResult).
func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) {
	const accountID int64 = 777

	resetStub := &openAI403CounterResetStub{}
	limiter := NewRateLimitService(nil, nil, nil, nil, nil)
	limiter.SetOpenAI403CounterCache(resetStub)

	gateway := &OpenAIGatewayService{rateLimitService: limiter}
	usageInput := &OpenAIRecordUsageInput{
		Result:  &OpenAIForwardResult{},
		Account: &Account{ID: accountID, Platform: PlatformOpenAI},
	}

	require.NoError(t, gateway.RecordUsage(context.Background(), usageInput))
	require.Equal(t, []int64{accountID}, resetStub.resetCalls)
}

View File

@@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI {
s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID)
}
// 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库
if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 &&

View File

@@ -1,8 +1,10 @@
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strconv"
@@ -23,6 +25,7 @@ type RateLimitService struct {
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
openAI403CounterCache OpenAI403CounterCache
settingService *SettingService
tokenCacheInvalidator TokenCacheInvalidator
usageCacheMu sync.RWMutex
@@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface {
const geminiPrecheckCacheTTL = time.Minute
// Tuning constants for the staged OpenAI 403 handling in handleOpenAI403.
const (
// openAI403CooldownMinutesDefault is the length, in minutes, of the temporary
// unschedulable period applied while the 403 count is below the threshold.
openAI403CooldownMinutesDefault = 10
// openAI403DisableThreshold is the 403 count at which the account stops being
// cooled down and is handed to handleAuthError instead.
openAI403DisableThreshold = 3
// openAI403CounterWindowMinutes is the window, in minutes, passed to the
// counter cache when incrementing the 403 count.
openAI403CounterWindowMinutes = 180
)
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
@@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache
}
// SetOpenAI403CounterCache sets the OpenAI consecutive-403 counter (optional dependency).
func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) {
s.openAI403CounterCache = cache
}
// SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
@@ -655,6 +669,30 @@ func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account
slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// buildForbiddenErrorMessage composes a 403 error message as "<prefix> <detail>".
//
// The prefix is whitespace-trimmed and, when non-empty, followed by a single
// space. The detail is chosen in this order:
//  1. the trimmed upstreamMsg, if it is non-empty;
//  2. the trimmed responseBody, compacted when it is valid JSON, passed
//     through truncateForLog with a limit of 512;
//  3. fallback, when both of the above are empty.
func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string {
	head := strings.TrimSpace(prefix)
	if head != "" {
		// A trimmed, non-empty prefix can never already end in a space.
		head += " "
	}

	if detail := strings.TrimSpace(upstreamMsg); detail != "" {
		return head + detail
	}

	body := bytes.TrimSpace(responseBody)
	if len(body) == 0 {
		return head + fallback
	}

	// json.Compact fails exactly when the body is not valid JSON, in which
	// case the raw (trimmed) body is used instead.
	var compacted bytes.Buffer
	if json.Compact(&compacted, body) == nil {
		return head + truncateForLog(compacted.Bytes(), 512)
	}
	return head + truncateForLog(body, 512)
}
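// handle403 handles 403 Forbidden errors.
// Antigravity distinguishes validation/violation/generic 403s, all of which SetError (permanent disable);
// OpenAI accounts are delegated to handleOpenAI403 (temporary cooldown first, permanent disable only once the 403 counter reaches the threshold);
// every other platform keeps the original SetError behavior.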
@@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst
if account.Platform == PlatformAntigravity {
return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody)
}
// 非 Antigravity 平台:保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
if account.Platform == PlatformOpenAI {
return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody)
}
// Platforms other than Antigravity and OpenAI: keep the original behavior (disable via handleAuthError).
msg := buildForbiddenErrorMessage(
"Access forbidden (403):",
upstreamMsg,
responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg)
return true
}
// handleOpenAI403 applies the staged 403 policy for OpenAI accounts.
//
// The 403 counter of the account is incremented; while the count stays below
// openAI403DisableThreshold the account is only made temporarily
// unschedulable for openAI403CooldownMinutesDefault minutes. The account is
// handed to handleAuthError instead when:
//   - no counter cache is configured,
//   - incrementing the counter fails,
//   - the count has reached the threshold (the message then carries a
//     "consecutive_403=<count>/<threshold>" suffix), or
//   - setting the temporary unschedulable state fails.
//
// Every path returns true.
func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) {
	forbiddenMsg := buildForbiddenErrorMessage(
		"Access forbidden (403):",
		upstreamMsg,
		responseBody,
		"account may be suspended or lack permissions",
	)

	// disable is the shared fallback: route the account through handleAuthError.
	disable := func(detail string) bool {
		s.handleAuthError(ctx, account, detail)
		return true
	}

	counter := s.openAI403CounterCache
	if counter == nil {
		return disable(forbiddenMsg)
	}

	hits, incrErr := counter.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes)
	switch {
	case incrErr != nil:
		slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", incrErr)
		return disable(forbiddenMsg)
	case hits >= openAI403DisableThreshold:
		return disable(fmt.Sprintf("%s | consecutive_403=%d/%d", forbiddenMsg, hits, openAI403DisableThreshold))
	}

	// Below the threshold: cool the account down instead of disabling it.
	cooldownEnd := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute)
	cooldownReason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", hits, openAI403DisableThreshold, forbiddenMsg)
	if setErr := s.accountRepo.SetTempUnschedulable(ctx, account.ID, cooldownEnd, cooldownReason); setErr != nil {
		slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", setErr)
		return disable(forbiddenMsg)
	}

	slog.Warn(
		"openai_403_temp_unschedulable",
		"account_id", account.ID,
		"until", cooldownEnd,
		"count", hits,
		"threshold", openAI403DisableThreshold,
	)
	return true
}
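// handleAntigravity403 handles 403 errors for the Antigravity platform.
// validation (verification required) → permanent SetError (restored manually after verifying with Google)
// violation (account banned for a violation) → permanent SetError (requires manual handling)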
@@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
switch fbType {
case forbiddenTypeValidation:
// VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复
msg := "Validation required (403): account needs Google verification"
if upstreamMsg != "" {
msg = "Validation required (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Validation required (403):",
upstreamMsg,
responseBody,
"account needs Google verification",
)
if validationURL := extractValidationURL(string(responseBody)); validationURL != "" {
msg += " | validation_url: " + validationURL
}
@@ -693,19 +782,23 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac
case forbiddenTypeViolation:
// 违规封号: 永久禁用,需人工处理
msg := "Account violation (403): terms of service violation"
if upstreamMsg != "" {
msg = "Account violation (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Account violation (403):",
upstreamMsg,
responseBody,
"terms of service violation",
)
s.handleAuthError(ctx, account, msg)
return true
default:
// 通用 403: 保持原有行为
msg := "Access forbidden (403): account may be suspended or lack permissions"
if upstreamMsg != "" {
msg = "Access forbidden (403): " + upstreamMsg
}
msg := buildForbiddenErrorMessage(
"Access forbidden (403):",
upstreamMsg,
responseBody,
"account may be suspended or lack permissions",
)
s.handleAuthError(ctx, account, msg)
return true
}
@@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
s.ResetOpenAI403Counter(ctx, accountID)
return nil
}
// ResetOpenAI403Counter clears the 403 counter of accountID on a best-effort
// basis. It does nothing when the receiver is nil, no counter cache is
// configured, or accountID is not positive. A failed reset is logged as a
// warning and not returned to the caller.
func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) {
	if s == nil {
		return
	}
	counter := s.openAI403CounterCache
	if counter == nil || accountID <= 0 {
		return
	}

	err := counter.ResetOpenAI403Count(ctx, accountID)
	if err == nil {
		return
	}
	slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err)
}
// RecoverAccountState 按需恢复账号的可恢复运行时状态。
func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) {
account, err := s.accountRepo.GetByID(ctx, accountID)
@@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in
}
result.ClearedRateLimit = true
}
if result.ClearedError || result.ClearedRateLimit {
s.ResetOpenAI403Counter(ctx, accountID)
}
return result, nil
}

View File

@@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct {
updateCredentialsCalls int
lastCredentials map[string]any
lastErrorMsg string
lastTempReason string
}
func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
@@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error
// SetTempUnschedulable records the call for later assertions: it stores the
// supplied reason, bumps the call counter, and always succeeds.
func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
	r.lastTempReason = reason
	r.tempCalls++
	return nil
}
@@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct {
err error
}
// openAI403CounterCacheStub is a scripted test double for OpenAI403CounterCache.
type openAI403CounterCacheStub struct {
	// counts is the queue of values returned by successive increments.
	counts []int64
	// resetCalls holds the account IDs passed to ResetOpenAI403Count, in call order.
	resetCalls []int64
	// err, when non-nil, is returned by every increment.
	err error
}

// IncrementOpenAI403Count returns the configured error if one is set;
// otherwise it pops and returns the next scripted count, or 1 once the queue
// is empty.
func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) {
	switch {
	case s.err != nil:
		return 0, s.err
	case len(s.counts) == 0:
		return 1, nil
	}
	next, rest := s.counts[0], s.counts[1:]
	s.counts = rest
	return next, nil
}

// ResetOpenAI403Count records accountID and always succeeds.
func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error {
	recorded := append(s.resetCalls, accountID)
	s.resetCalls = recorded
	return nil
}
func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
r.accounts = append(r.accounts, account)
return r.err

View File

@@ -0,0 +1,64 @@
//go:build unit
package service
import (
"context"
"net/http"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
// TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable
// checks that a first 403 (counter value 1) on an OpenAI account results in a
// temporary unschedulable state, not a SetError, and that the recorded reason
// carries both the upstream message and the "(1/3)" progress marker.
func TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) {
	accountRepo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
	svc.SetOpenAI403CounterCache(&openAI403CounterCacheStub{counts: []int64{1}})

	upstreamBody := []byte(`{"error":{"message":"temporary edge rejection"}}`)
	openAIAccount := &Account{ID: 301, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

	disabled := svc.HandleUpstreamError(context.Background(), openAIAccount, http.StatusForbidden, http.Header{}, upstreamBody)

	require.True(t, disabled)
	require.Equal(t, 0, accountRepo.setErrorCalls)
	require.Equal(t, 1, accountRepo.tempCalls)
	require.Contains(t, accountRepo.lastTempReason, "temporary edge rejection")
	require.Contains(t, accountRepo.lastTempReason, "(1/3)")
}
// TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables checks
// that a 403 whose counter value equals the threshold (3) results in a
// SetError, not a temporary cooldown, and that the stored error message
// carries the upstream message plus the "consecutive_403=3/3" suffix.
func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) {
	accountRepo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)
	svc.SetOpenAI403CounterCache(&openAI403CounterCacheStub{counts: []int64{3}})

	upstreamBody := []byte(`{"error":{"message":"workspace forbidden by policy"}}`)
	openAIAccount := &Account{ID: 302, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

	disabled := svc.HandleUpstreamError(context.Background(), openAIAccount, http.StatusForbidden, http.Header{}, upstreamBody)

	require.True(t, disabled)
	require.Equal(t, 1, accountRepo.setErrorCalls)
	require.Equal(t, 0, accountRepo.tempCalls)
	require.Contains(t, accountRepo.lastErrorMsg, "workspace forbidden by policy")
	require.Contains(t, accountRepo.lastErrorMsg, "consecutive_403=3/3")
}

View File

@@ -7,6 +7,9 @@ import (
"net/http"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
@@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
}
}
// TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage
// checks that, with no 403 counter cache configured, a 403 on an OpenAI
// account triggers a single SetError whose message contains the upstream
// error message instead of the generic fallback text.
func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) {
	accountRepo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)

	upstreamBody := []byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`)
	openAIAccount := &Account{ID: 201, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

	disabled := svc.HandleUpstreamError(context.Background(), openAIAccount, 403, http.Header{}, upstreamBody)

	require.True(t, disabled)
	require.Equal(t, 1, accountRepo.setErrorCalls)
	require.Contains(t, accountRepo.lastErrorMsg, "workspace forbidden by policy")
	require.NotContains(t, accountRepo.lastErrorMsg, "account may be suspended or lack permissions")
}
// TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody checks that
// when the 403 body has no error message field, the stored error message
// contains the raw body content ("access_denied", "ip_blocked") instead of
// the generic fallback text.
func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) {
	accountRepo := &rateLimitAccountRepoStub{}
	svc := NewRateLimitService(accountRepo, nil, &config.Config{}, nil, nil)

	upstreamBody := []byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`)
	openAIAccount := &Account{ID: 202, Platform: PlatformOpenAI, Type: AccountTypeOAuth}

	disabled := svc.HandleUpstreamError(context.Background(), openAIAccount, 403, http.Header{}, upstreamBody)

	require.True(t, disabled)
	require.Equal(t, 1, accountRepo.setErrorCalls)
	require.Contains(t, accountRepo.lastErrorMsg, `"access_denied"`)
	require.Contains(t, accountRepo.lastErrorMsg, `"ip_blocked"`)
	require.NotContains(t, accountRepo.lastErrorMsg, "account may be suspended or lack permissions")
}
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
// Test when only secondary has data, no window_minutes
sUsed := 60.0

View File

@@ -210,11 +210,13 @@ func ProvideRateLimitService(
geminiQuotaService *GeminiQuotaService,
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
openAI403CounterCache OpenAI403CounterCache,
settingService *SettingService,
tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetOpenAI403CounterCache(openAI403CounterCache)
svc.SetSettingService(settingService)
svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc