Merge pull request #1349 from touwaeriol/feat/antigravity-internal500-penalty
feat(antigravity): progressive penalty for consecutive INTERNAL 500 errors
This commit is contained in:
@@ -137,7 +137,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
|
||||
schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
|
||||
antigravityTokenProvider := service.ProvideAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService, oauthRefreshAPI, tempUnschedCache)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
internal500CounterCache := repository.NewInternal500CounterCache(redisClient)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService, internal500CounterCache)
|
||||
tlsFingerprintProfileRepository := repository.NewTLSFingerprintProfileRepository(client)
|
||||
tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient)
|
||||
tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache)
|
||||
|
||||
55
backend/internal/repository/internal500_counter_cache.go
Normal file
55
backend/internal/repository/internal500_counter_cache.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
internal500CounterPrefix = "internal500_count:account:"
|
||||
internal500CounterTTLSeconds = 86400 // 24 小时兜底
|
||||
)
|
||||
|
||||
// internal500CounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
|
||||
// 如果 key 不存在,则创建并设置过期时间
|
||||
var internal500CounterIncrScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local count = redis.call('INCR', key)
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return count
|
||||
`)
|
||||
|
||||
type internal500CounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewInternal500CounterCache 创建 INTERNAL 500 连续失败计数器缓存实例
|
||||
func NewInternal500CounterCache(rdb *redis.Client) service.Internal500CounterCache {
|
||||
return &internal500CounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// IncrementInternal500Count 原子递增计数并返回当前值
|
||||
func (c *internal500CounterCache) IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
|
||||
result, err := internal500CounterIncrScript.Run(ctx, c.rdb, []string{key}, internal500CounterTTLSeconds).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment internal500 count: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ResetInternal500Count 清零计数器(成功响应时调用)
|
||||
func (c *internal500CounterCache) ResetInternal500Count(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", internal500CounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -81,6 +81,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
NewTimeoutCounterCache,
|
||||
NewInternal500CounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewRPMCache,
|
||||
|
||||
@@ -614,6 +614,7 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP
|
||||
urlFallbackLoop:
|
||||
for urlIdx, baseURL := range availableURLs {
|
||||
usedBaseURL = baseURL
|
||||
allAttemptsInternal500 := true // 追踪本轮所有 attempt 是否全部命中 INTERNAL 500
|
||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
@@ -766,10 +767,19 @@ urlFallbackLoop:
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix)
|
||||
return nil, p.ctx.Err()
|
||||
}
|
||||
// 追踪 INTERNAL 500:非匹配的 attempt 清除标记
|
||||
if !isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
allAttemptsInternal500 = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// INTERNAL 500 渐进惩罚:3 次重试全部命中特定 500 时递增计数器并惩罚
|
||||
if allAttemptsInternal500 && isAntigravityInternalServerError(resp.StatusCode, respBody) {
|
||||
s.handleInternal500RetryExhausted(p.ctx, p.prefix, p.account)
|
||||
}
|
||||
|
||||
// 其他 4xx 错误或重试用尽,直接返回
|
||||
resp = &http.Response{
|
||||
StatusCode: resp.StatusCode,
|
||||
@@ -788,6 +798,11 @@ urlFallbackLoop:
|
||||
antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
|
||||
}
|
||||
|
||||
// 成功响应时清零 INTERNAL 500 连续失败计数器(覆盖所有成功路径,含 smart retry)
|
||||
if resp != nil && resp.StatusCode < 400 {
|
||||
s.resetInternal500Counter(p.ctx, p.prefix, p.account.ID)
|
||||
}
|
||||
|
||||
return &antigravityRetryLoopResult{resp: resp}, nil
|
||||
}
|
||||
|
||||
@@ -862,6 +877,7 @@ type AntigravityGatewayService struct {
|
||||
settingService *SettingService
|
||||
cache GatewayCache // 用于模型级限流时清除粘性会话绑定
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
internal500Cache Internal500CounterCache // INTERNAL 500 渐进惩罚计数器
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -872,6 +888,7 @@ func NewAntigravityGatewayService(
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
internal500Cache Internal500CounterCache,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -881,6 +898,7 @@ func NewAntigravityGatewayService(
|
||||
settingService: settingService,
|
||||
cache: cache,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
internal500Cache: internal500Cache,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
97
backend/internal/service/antigravity_internal500_penalty.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// INTERNAL 500 渐进惩罚:连续多轮全部返回特定 500 错误时的惩罚时长
|
||||
const (
|
||||
internal500PenaltyTier1Duration = 30 * time.Minute // 第 1 轮:临时不可调度 30 分钟
|
||||
internal500PenaltyTier2Duration = 2 * time.Hour // 第 2 轮:临时不可调度 2 小时
|
||||
internal500PenaltyTier3Threshold = 3 // 第 3+ 轮:永久禁用
|
||||
)
|
||||
|
||||
// isAntigravityInternalServerError 检测特定的 INTERNAL 500 错误
|
||||
// 必须同时匹配 error.code==500, error.message=="Internal error encountered.", error.status=="INTERNAL"
|
||||
func isAntigravityInternalServerError(statusCode int, body []byte) bool {
|
||||
if statusCode != http.StatusInternalServerError {
|
||||
return false
|
||||
}
|
||||
return gjson.GetBytes(body, "error.code").Int() == 500 &&
|
||||
gjson.GetBytes(body, "error.message").String() == "Internal error encountered." &&
|
||||
gjson.GetBytes(body, "error.status").String() == "INTERNAL"
|
||||
}
|
||||
|
||||
// applyInternal500Penalty 根据连续 INTERNAL 500 轮次数应用渐进惩罚
|
||||
// count=1: temp_unschedulable 10 分钟
|
||||
// count=2: temp_unschedulable 10 小时
|
||||
// count>=3: SetError 永久禁用
|
||||
func (s *AntigravityGatewayService) applyInternal500Penalty(
|
||||
ctx context.Context, prefix string, account *Account, count int64,
|
||||
) {
|
||||
switch {
|
||||
case count >= int64(internal500PenaltyTier3Threshold):
|
||||
reason := fmt.Sprintf("INTERNAL 500 consecutive failures: %d rounds", count)
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, reason); err != nil {
|
||||
slog.Error("internal500_set_error_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_account_disabled",
|
||||
"account_id", account.ID, "account_name", account.Name, "consecutive_count", count)
|
||||
case count == 2:
|
||||
until := time.Now().Add(internal500PenaltyTier2Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier2Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier2Duration, "consecutive_count", count)
|
||||
case count == 1:
|
||||
until := time.Now().Add(internal500PenaltyTier1Duration)
|
||||
reason := fmt.Sprintf("INTERNAL 500 x%d (temp unsched %v)", count, internal500PenaltyTier1Duration)
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
slog.Error("internal500_temp_unsched_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("internal500_temp_unschedulable",
|
||||
"account_id", account.ID, "account_name", account.Name,
|
||||
"duration", internal500PenaltyTier1Duration, "consecutive_count", count)
|
||||
}
|
||||
}
|
||||
|
||||
// handleInternal500RetryExhausted 处理 INTERNAL 500 重试耗尽:递增计数器并应用惩罚
|
||||
func (s *AntigravityGatewayService) handleInternal500RetryExhausted(
|
||||
ctx context.Context, prefix string, account *Account,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
count, err := s.internal500Cache.IncrementInternal500Count(ctx, account.ID)
|
||||
if err != nil {
|
||||
slog.Error("internal500_counter_increment_failed",
|
||||
"prefix", prefix, "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
s.applyInternal500Penalty(ctx, prefix, account, count)
|
||||
}
|
||||
|
||||
// resetInternal500Counter 成功响应时清零 INTERNAL 500 计数器
|
||||
func (s *AntigravityGatewayService) resetInternal500Counter(
|
||||
ctx context.Context, prefix string, accountID int64,
|
||||
) {
|
||||
if s.internal500Cache == nil {
|
||||
return
|
||||
}
|
||||
if err := s.internal500Cache.ResetInternal500Count(ctx, accountID); err != nil {
|
||||
slog.Error("internal500_counter_reset_failed",
|
||||
"prefix", prefix, "account_id", accountID, "error", err)
|
||||
}
|
||||
}
|
||||
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
321
backend/internal/service/antigravity_internal500_penalty_test.go
Normal file
@@ -0,0 +1,321 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- mock: Internal500CounterCache ---
|
||||
|
||||
type mockInternal500Cache struct {
|
||||
incrementCount int64
|
||||
incrementErr error
|
||||
resetErr error
|
||||
|
||||
incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID
|
||||
resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID
|
||||
}
|
||||
|
||||
func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
|
||||
m.incrementCalls = append(m.incrementCalls, accountID)
|
||||
return m.incrementCount, m.incrementErr
|
||||
}
|
||||
|
||||
func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
|
||||
m.resetCalls = append(m.resetCalls, accountID)
|
||||
return m.resetErr
|
||||
}
|
||||
|
||||
// --- mock: 专用于 internal500 惩罚测试的 AccountRepository ---
|
||||
|
||||
type internal500AccountRepoStub struct {
|
||||
AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用)
|
||||
|
||||
tempUnschedCalls []tempUnschedCall
|
||||
setErrorCalls []setErrorCall
|
||||
}
|
||||
|
||||
type tempUnschedCall struct {
|
||||
accountID int64
|
||||
until time.Time
|
||||
reason string
|
||||
}
|
||||
|
||||
type setErrorCall struct {
|
||||
accountID int64
|
||||
reason string
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||||
r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
|
||||
r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg})
|
||||
return nil
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestIsAntigravityInternalServerError
|
||||
// =============================================================================
|
||||
|
||||
func TestIsAntigravityInternalServerError(t *testing.T) {
|
||||
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.True(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("statusCode 不是 500", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(429, body))
|
||||
require.False(t, isAntigravityInternalServerError(503, body))
|
||||
require.False(t, isAntigravityInternalServerError(200, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 message 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 status 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("body 中 code 不匹配", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("空 body", func(t *testing.T) {
|
||||
require.False(t, isAntigravityInternalServerError(500, []byte{}))
|
||||
require.False(t, isAntigravityInternalServerError(500, nil))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
|
||||
body := []byte(`Internal Server Error`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
|
||||
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
|
||||
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
|
||||
require.False(t, isAntigravityInternalServerError(500, body))
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestApplyInternal500Penalty
|
||||
// =============================================================================
|
||||
|
||||
func TestApplyInternal500Penalty(t *testing.T) {
|
||||
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(1), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
// until 应在 [before+10m, after+10m] 范围内
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
before := time.Now()
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
|
||||
after := time.Now()
|
||||
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
|
||||
call := repo.tempUnschedCalls[0]
|
||||
require.Equal(t, int64(2), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500")
|
||||
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
|
||||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
|
||||
})
|
||||
|
||||
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(3), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
|
||||
})
|
||||
|
||||
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 5, Name: "acc-5"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
|
||||
call := repo.setErrorCalls[0]
|
||||
require.Equal(t, int64(5), call.accountID)
|
||||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
|
||||
})
|
||||
|
||||
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10"}
|
||||
|
||||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
|
||||
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestHandleInternal500RetryExhausted
|
||||
// =============================================================================
|
||||
|
||||
func TestHandleInternal500RetryExhausted(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: nil,
|
||||
}
|
||||
account := &Account{ID: 1, Name: "acc-1"}
|
||||
|
||||
// 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
})
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementErr: errors.New("redis connection error"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 2, Name: "acc-2"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(2), cache.incrementCalls[0])
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 1,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 3, Name: "acc-3"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Equal(t, int64(3), cache.incrementCalls[0])
|
||||
// tier1: SetTempUnschedulable
|
||||
require.Len(t, repo.tempUnschedCalls, 1)
|
||||
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
|
||||
require.Empty(t, repo.setErrorCalls)
|
||||
})
|
||||
|
||||
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
|
||||
repo := &internal500AccountRepoStub{}
|
||||
cache := &mockInternal500Cache{
|
||||
incrementCount: 3,
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
accountRepo: repo,
|
||||
internal500Cache: cache,
|
||||
}
|
||||
account := &Account{ID: 4, Name: "acc-4"}
|
||||
|
||||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||||
|
||||
require.Len(t, cache.incrementCalls, 1)
|
||||
require.Empty(t, repo.tempUnschedCalls)
|
||||
require.Len(t, repo.setErrorCalls, 1)
|
||||
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
|
||||
})
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// TestResetInternal500Counter
|
||||
// =============================================================================
|
||||
|
||||
func TestResetInternal500Counter(t *testing.T) {
|
||||
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: nil,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 1)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{
|
||||
resetErr: errors.New("redis timeout"),
|
||||
}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
require.NotPanics(t, func() {
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 42)
|
||||
})
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(42), cache.resetCalls[0])
|
||||
})
|
||||
|
||||
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
|
||||
cache := &mockInternal500Cache{}
|
||||
svc := &AntigravityGatewayService{
|
||||
internal500Cache: cache,
|
||||
}
|
||||
|
||||
svc.resetInternal500Counter(context.Background(), "[test]", 99)
|
||||
|
||||
require.Len(t, cache.resetCalls, 1)
|
||||
require.Equal(t, int64(99), cache.resetCalls[0])
|
||||
})
|
||||
}
|
||||
11
backend/internal/service/internal500_counter.go
Normal file
11
backend/internal/service/internal500_counter.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// Internal500CounterCache 追踪 Antigravity 账号连续 INTERNAL 500 失败轮数
|
||||
type Internal500CounterCache interface {
|
||||
// IncrementInternal500Count 原子递增计数并返回当前值
|
||||
IncrementInternal500Count(ctx context.Context, accountID int64) (int64, error)
|
||||
// ResetInternal500Count 清零计数器(成功响应时调用)
|
||||
ResetInternal500Count(ctx context.Context, accountID int64) error
|
||||
}
|
||||
Reference in New Issue
Block a user