diff --git a/backend/internal/service/antigravity_internal500_penalty_test.go b/backend/internal/service/antigravity_internal500_penalty_test.go new file mode 100644 index 00000000..03831839 --- /dev/null +++ b/backend/internal/service/antigravity_internal500_penalty_test.go @@ -0,0 +1,321 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- mock: Internal500CounterCache --- + +type mockInternal500Cache struct { + incrementCount int64 + incrementErr error + resetErr error + + incrementCalls []int64 // 记录 IncrementInternal500Count 被调用时的 accountID + resetCalls []int64 // 记录 ResetInternal500Count 被调用时的 accountID +} + +func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) { + m.incrementCalls = append(m.incrementCalls, accountID) + return m.incrementCount, m.incrementErr +} + +func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error { + m.resetCalls = append(m.resetCalls, accountID) + return m.resetErr +} + +// --- mock: 专用于 internal500 惩罚测试的 AccountRepository --- + +type internal500AccountRepoStub struct { + AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用) + + tempUnschedCalls []tempUnschedCall + setErrorCalls []setErrorCall +} + +type tempUnschedCall struct { + accountID int64 + until time.Time + reason string +} + +type setErrorCall struct { + accountID int64 + reason string +} + +func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error { + r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason}) + return nil +} + +func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error { + r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg}) + return nil +} + +// ============================================================================= +// TestIsAntigravityInternalServerError +// ============================================================================= + +func TestIsAntigravityInternalServerError(t *testing.T) { + t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.True(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("statusCode 不是 500", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(429, body)) + require.False(t, isAntigravityInternalServerError(503, body)) + require.False(t, isAntigravityInternalServerError(200, body)) + }) + + t.Run("body 中 message 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 status 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("body 中 code 不匹配", func(t *testing.T) { + body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("空 body", func(t *testing.T) { + require.False(t, isAntigravityInternalServerError(500, []byte{})) + require.False(t, isAntigravityInternalServerError(500, nil)) + }) + + t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) { + body := []byte(`Internal Server Error`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) + + t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) { + body := []byte(`{"message":"Internal Server Error","statusCode":500}`) + require.False(t, isAntigravityInternalServerError(500, body)) + }) +} + +// ============================================================================= +// TestApplyInternal500Penalty +// ============================================================================= + +func TestApplyInternal500Penalty(t *testing.T) { + t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 1, Name: "acc-1"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 1) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(1), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + // until 应在 [before+10m, after+10m] 范围内 + require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second))) + }) + + t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 2, Name: "acc-2"} + + before := time.Now() + svc.applyInternal500Penalty(context.Background(), "[test]", account, 2) + after := time.Now() + + require.Len(t, repo.tempUnschedCalls, 1) + require.Empty(t, repo.setErrorCalls) + + call := repo.tempUnschedCalls[0] + require.Equal(t, int64(2), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500") + require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second))) + require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second))) + }) + + t.Run("count=3 → SetError 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 3, Name: "acc-3"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 3) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(3), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3") + }) + + t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 5, Name: "acc-5"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 5) + + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + + call := repo.setErrorCalls[0] + require.Equal(t, int64(5), call.accountID) + require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5") + }) + + t.Run("count=0 → 不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 10, Name: "acc-10"} + + svc.applyInternal500Penalty(context.Background(), "[test]", account, 0) + + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) +} + +// ============================================================================= +// TestHandleInternal500RetryExhausted +// ============================================================================= + +func TestHandleInternal500RetryExhausted(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: nil, + } + account := &Account{ID: 1, Name: "acc-1"} + + // 不应 panic + require.NotPanics(t, func() { + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + }) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementErr: errors.New("redis connection error"), + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 2, Name: "acc-2"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(2), cache.incrementCalls[0]) + require.Empty(t, repo.tempUnschedCalls) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 1, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 3, Name: "acc-3"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Equal(t, int64(3), cache.incrementCalls[0]) + // tier1: SetTempUnschedulable + require.Len(t, repo.tempUnschedCalls, 1) + require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID) + require.Empty(t, repo.setErrorCalls) + }) + + t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) { + repo := &internal500AccountRepoStub{} + cache := &mockInternal500Cache{ + incrementCount: 3, + } + svc := &AntigravityGatewayService{ + accountRepo: repo, + internal500Cache: cache, + } + account := &Account{ID: 4, Name: "acc-4"} + + svc.handleInternal500RetryExhausted(context.Background(), "[test]", account) + + require.Len(t, cache.incrementCalls, 1) + require.Empty(t, repo.tempUnschedCalls) + require.Len(t, repo.setErrorCalls, 1) + require.Equal(t, int64(4), repo.setErrorCalls[0].accountID) + }) +} + +// ============================================================================= +// TestResetInternal500Counter +// ============================================================================= + +func TestResetInternal500Counter(t *testing.T) { + t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) { + svc := &AntigravityGatewayService{ + internal500Cache: nil, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 1) + }) + }) + + t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) { + cache := &mockInternal500Cache{ + resetErr: errors.New("redis timeout"), + } + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + require.NotPanics(t, func() { + svc.resetInternal500Counter(context.Background(), "[test]", 42) + }) + require.Len(t, cache.resetCalls, 1) + require.Equal(t, int64(42), cache.resetCalls[0]) + }) + + t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) { + cache := &mockInternal500Cache{} + svc := &AntigravityGatewayService{ + internal500Cache: cache, + } + + svc.resetInternal500Counter(context.Background(), "[test]", 99) + + require.Len(t, cache.resetCalls, 1) + require.Equal(t, int64(99), cache.resetCalls[0]) + }) +}