Files
sub2api/backend/internal/service/antigravity_internal500_penalty_test.go
erio d563eb2336 test: add unit tests for INTERNAL 500 progressive penalty
Cover isAntigravityInternalServerError body matching,
applyInternal500Penalty tier escalation, handleInternal500RetryExhausted
nil-safety and error handling, and resetInternal500Counter paths.
2026-03-27 20:11:24 +08:00

322 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// --- mock: Internal500CounterCache ---

// mockInternal500Cache is a recording test double for the INTERNAL 500 counter
// cache. Each method logs the account ID it received and answers with the
// canned values configured on the struct.
type mockInternal500Cache struct {
	// Canned replies.
	incrementCount int64 // count returned by IncrementInternal500Count
	incrementErr   error // error returned by IncrementInternal500Count
	resetErr       error // error returned by ResetInternal500Count

	// Call logs, in invocation order.
	incrementCalls []int64 // account IDs passed to IncrementInternal500Count
	resetCalls     []int64 // account IDs passed to ResetInternal500Count
}

// IncrementInternal500Count logs id, then returns the configured count and
// error unchanged.
func (c *mockInternal500Cache) IncrementInternal500Count(_ context.Context, id int64) (count int64, err error) {
	c.incrementCalls = append(c.incrementCalls, id)
	count, err = c.incrementCount, c.incrementErr
	return count, err
}

// ResetInternal500Count logs id, then returns the configured reset error
// (nil unless resetErr was set).
func (c *mockInternal500Cache) ResetInternal500Count(_ context.Context, id int64) error {
	c.resetCalls = append(c.resetCalls, id)
	return c.resetErr
}
// --- mock: AccountRepository dedicated to the internal500 penalty tests ---

// internal500AccountRepoStub records the two repository calls the penalty
// tests inspect: SetTempUnschedulable and SetError.
type internal500AccountRepoStub struct {
	// Embedded interface: methods not overridden below would panic if invoked
	// (the embedded value is left nil) and are not expected to be called.
	AccountRepository

	tempUnschedCalls []tempUnschedCall // one entry per SetTempUnschedulable call
	setErrorCalls    []setErrorCall    // one entry per SetError call
}

// tempUnschedCall captures the arguments of one SetTempUnschedulable call.
type tempUnschedCall struct {
	accountID int64
	until     time.Time
	reason    string
}

// setErrorCall captures the arguments of one SetError call; reason holds the
// error message that was passed in.
type setErrorCall struct {
	accountID int64
	reason    string
}

// SetTempUnschedulable appends the received arguments to tempUnschedCalls and
// always succeeds.
func (s *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, accountID int64, deadline time.Time, why string) error {
	record := tempUnschedCall{
		accountID: accountID,
		until:     deadline,
		reason:    why,
	}
	s.tempUnschedCalls = append(s.tempUnschedCalls, record)
	return nil
}

// SetError appends the received arguments to setErrorCalls and always
// succeeds.
func (s *internal500AccountRepoStub) SetError(_ context.Context, accountID int64, message string) error {
	record := setErrorCall{
		accountID: accountID,
		reason:    message,
	}
	s.setErrorCalls = append(s.setErrorCalls, record)
	return nil
}
// =============================================================================
// TestIsAntigravityInternalServerError
// =============================================================================
func TestIsAntigravityInternalServerError(t *testing.T) {
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.True(t, isAntigravityInternalServerError(500, body))
})
t.Run("statusCode 不是 500", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(429, body))
require.False(t, isAntigravityInternalServerError(503, body))
require.False(t, isAntigravityInternalServerError(200, body))
})
t.Run("body 中 message 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 status 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("body 中 code 不匹配", func(t *testing.T) {
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("空 body", func(t *testing.T) {
require.False(t, isAntigravityInternalServerError(500, []byte{}))
require.False(t, isAntigravityInternalServerError(500, nil))
})
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
body := []byte(`Internal Server Error`)
require.False(t, isAntigravityInternalServerError(500, body))
})
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
require.False(t, isAntigravityInternalServerError(500, body))
})
}
// =============================================================================
// TestApplyInternal500Penalty
// =============================================================================
func TestApplyInternal500Penalty(t *testing.T) {
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 1, Name: "acc-1"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(1), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
// until 应在 [before+10m, after+10m] 范围内
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
})
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 2, Name: "acc-2"}
before := time.Now()
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
after := time.Now()
require.Len(t, repo.tempUnschedCalls, 1)
require.Empty(t, repo.setErrorCalls)
call := repo.tempUnschedCalls[0]
require.Equal(t, int64(2), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500")
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
})
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 3, Name: "acc-3"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(3), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
})
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 5, Name: "acc-5"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
require.Empty(t, repo.tempUnschedCalls)
require.Len(t, repo.setErrorCalls, 1)
call := repo.setErrorCalls[0]
require.Equal(t, int64(5), call.accountID)
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
})
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
repo := &internal500AccountRepoStub{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{ID: 10, Name: "acc-10"}
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
require.Empty(t, repo.tempUnschedCalls)
require.Empty(t, repo.setErrorCalls)
})
}
// =============================================================================
// TestHandleInternal500RetryExhausted
// =============================================================================
// TestHandleInternal500RetryExhausted covers handleInternal500RetryExhausted:
//   - a nil counter cache must neither panic nor touch the repository;
//   - an error from IncrementInternal500Count must suppress any penalty;
//   - a returned count of 1 must produce a temporary unschedule;
//   - a returned count of 3 must produce SetError.
//
// Every subtest that reaches the cache also checks that the increment was
// issued exactly once and for the account under test.
func TestHandleInternal500RetryExhausted(t *testing.T) {
	t.Run("internal500Cache 为 nil → 不 panic不调用任何方法", func(t *testing.T) {
		repo := &internal500AccountRepoStub{}
		svc := &AntigravityGatewayService{
			accountRepo:      repo,
			internal500Cache: nil,
		}
		account := &Account{ID: 1, Name: "acc-1"}

		// Must not panic when no cache is configured.
		require.NotPanics(t, func() {
			svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
		})

		require.Empty(t, repo.tempUnschedCalls)
		require.Empty(t, repo.setErrorCalls)
	})

	t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
		repo := &internal500AccountRepoStub{}
		cache := &mockInternal500Cache{
			incrementErr: errors.New("redis connection error"),
		}
		svc := &AntigravityGatewayService{
			accountRepo:      repo,
			internal500Cache: cache,
		}
		account := &Account{ID: 2, Name: "acc-2"}

		svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)

		require.Len(t, cache.incrementCalls, 1)
		require.Equal(t, int64(2), cache.incrementCalls[0])
		// The counter could not be read, so no penalty may be applied.
		require.Empty(t, repo.tempUnschedCalls)
		require.Empty(t, repo.setErrorCalls)
	})

	t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
		repo := &internal500AccountRepoStub{}
		cache := &mockInternal500Cache{
			incrementCount: 1,
		}
		svc := &AntigravityGatewayService{
			accountRepo:      repo,
			internal500Cache: cache,
		}
		account := &Account{ID: 3, Name: "acc-3"}

		svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)

		require.Len(t, cache.incrementCalls, 1)
		require.Equal(t, int64(3), cache.incrementCalls[0])
		// tier1: SetTempUnschedulable
		require.Len(t, repo.tempUnschedCalls, 1)
		require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
		require.Empty(t, repo.setErrorCalls)
	})

	t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
		repo := &internal500AccountRepoStub{}
		cache := &mockInternal500Cache{
			incrementCount: 3,
		}
		svc := &AntigravityGatewayService{
			accountRepo:      repo,
			internal500Cache: cache,
		}
		account := &Account{ID: 4, Name: "acc-4"}

		svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)

		require.Len(t, cache.incrementCalls, 1)
		// Like the sibling subtests, verify the increment targeted this
		// account and not merely that some increment happened.
		require.Equal(t, int64(4), cache.incrementCalls[0])
		// tier3: SetError, and no temporary unschedule.
		require.Empty(t, repo.tempUnschedCalls)
		require.Len(t, repo.setErrorCalls, 1)
		require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
	})
}
// =============================================================================
// TestResetInternal500Counter
// =============================================================================
func TestResetInternal500Counter(t *testing.T) {
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
svc := &AntigravityGatewayService{
internal500Cache: nil,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 1)
})
})
t.Run("ResetInternal500Count 返回 error → 不 panic仅日志", func(t *testing.T) {
cache := &mockInternal500Cache{
resetErr: errors.New("redis timeout"),
}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
require.NotPanics(t, func() {
svc.resetInternal500Counter(context.Background(), "[test]", 42)
})
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(42), cache.resetCalls[0])
})
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
cache := &mockInternal500Cache{}
svc := &AntigravityGatewayService{
internal500Cache: cache,
}
svc.resetInternal500Counter(context.Background(), "[test]", 99)
require.Len(t, cache.resetCalls, 1)
require.Equal(t, int64(99), cache.resetCalls[0])
})
}