// Tests covering isAntigravityInternalServerError body matching,
// applyInternal500Penalty tier escalation, handleInternal500RetryExhausted
// nil-safety and error handling, and resetInternal500Counter paths.

//go:build unit
|
||
|
||
package service
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/stretchr/testify/require"
|
||
)
|
||
|
||
// --- mock: Internal500CounterCache ---
|
||
|
||
// mockInternal500Cache is an in-memory test double for the INTERNAL 500
// counter cache. Each method returns the canned values configured on the
// struct and records the account ID it was called with, so tests can assert
// on the interaction afterwards.
type mockInternal500Cache struct {
	incrementCount int64 // count returned by IncrementInternal500Count
	incrementErr   error // error returned by IncrementInternal500Count
	resetErr       error // error returned by ResetInternal500Count

	incrementCalls []int64 // account IDs passed to IncrementInternal500Count, in call order
	resetCalls     []int64 // account IDs passed to ResetInternal500Count, in call order
}

// IncrementInternal500Count records accountID and returns the configured
// incrementCount and incrementErr unchanged. The context is ignored.
func (m *mockInternal500Cache) IncrementInternal500Count(_ context.Context, accountID int64) (int64, error) {
	m.incrementCalls = append(m.incrementCalls, accountID)
	return m.incrementCount, m.incrementErr
}

// ResetInternal500Count records accountID and returns the configured
// resetErr unchanged. The context is ignored.
func (m *mockInternal500Cache) ResetInternal500Count(_ context.Context, accountID int64) error {
	m.resetCalls = append(m.resetCalls, accountID)
	return m.resetErr
}
|
||
|
||
// --- mock: AccountRepository stub dedicated to the internal500 penalty tests ---
|
||
|
||
type internal500AccountRepoStub struct {
|
||
AccountRepository // 嵌入接口,未实现的方法会 panic(不应被调用)
|
||
|
||
tempUnschedCalls []tempUnschedCall
|
||
setErrorCalls []setErrorCall
|
||
}
|
||
|
||
// tempUnschedCall captures the arguments of one SetTempUnschedulable call.
type tempUnschedCall struct {
	accountID int64     // id argument of the call
	until     time.Time // until argument of the call
	reason    string    // reason argument of the call
}
|
||
|
||
// setErrorCall captures the arguments of one SetError call.
type setErrorCall struct {
	accountID int64  // id argument of the call
	reason    string // errorMsg argument of the call
}
|
||
|
||
func (r *internal500AccountRepoStub) SetTempUnschedulable(_ context.Context, id int64, until time.Time, reason string) error {
|
||
r.tempUnschedCalls = append(r.tempUnschedCalls, tempUnschedCall{accountID: id, until: until, reason: reason})
|
||
return nil
|
||
}
|
||
|
||
func (r *internal500AccountRepoStub) SetError(_ context.Context, id int64, errorMsg string) error {
|
||
r.setErrorCalls = append(r.setErrorCalls, setErrorCall{accountID: id, reason: errorMsg})
|
||
return nil
|
||
}
|
||
|
||
// =============================================================================
|
||
// TestIsAntigravityInternalServerError
|
||
// =============================================================================
|
||
|
||
func TestIsAntigravityInternalServerError(t *testing.T) {
|
||
t.Run("匹配完整的 INTERNAL 500 body", func(t *testing.T) {
|
||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||
require.True(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
|
||
t.Run("statusCode 不是 500", func(t *testing.T) {
|
||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||
require.False(t, isAntigravityInternalServerError(429, body))
|
||
require.False(t, isAntigravityInternalServerError(503, body))
|
||
require.False(t, isAntigravityInternalServerError(200, body))
|
||
})
|
||
|
||
t.Run("body 中 message 不匹配", func(t *testing.T) {
|
||
body := []byte(`{"error":{"code":500,"message":"Some other error","status":"INTERNAL"}}`)
|
||
require.False(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
|
||
t.Run("body 中 status 不匹配", func(t *testing.T) {
|
||
body := []byte(`{"error":{"code":500,"message":"Internal error encountered.","status":"UNAVAILABLE"}}`)
|
||
require.False(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
|
||
t.Run("body 中 code 不匹配", func(t *testing.T) {
|
||
body := []byte(`{"error":{"code":503,"message":"Internal error encountered.","status":"INTERNAL"}}`)
|
||
require.False(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
|
||
t.Run("空 body", func(t *testing.T) {
|
||
require.False(t, isAntigravityInternalServerError(500, []byte{}))
|
||
require.False(t, isAntigravityInternalServerError(500, nil))
|
||
})
|
||
|
||
t.Run("其他 500 错误格式(纯文本)", func(t *testing.T) {
|
||
body := []byte(`Internal Server Error`)
|
||
require.False(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
|
||
t.Run("其他 500 错误格式(不同 JSON 结构)", func(t *testing.T) {
|
||
body := []byte(`{"message":"Internal Server Error","statusCode":500}`)
|
||
require.False(t, isAntigravityInternalServerError(500, body))
|
||
})
|
||
}
|
||
|
||
// =============================================================================
|
||
// TestApplyInternal500Penalty
|
||
// =============================================================================
|
||
|
||
func TestApplyInternal500Penalty(t *testing.T) {
|
||
t.Run("count=1 → SetTempUnschedulable 10 分钟", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||
account := &Account{ID: 1, Name: "acc-1"}
|
||
|
||
before := time.Now()
|
||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 1)
|
||
after := time.Now()
|
||
|
||
require.Len(t, repo.tempUnschedCalls, 1)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
|
||
call := repo.tempUnschedCalls[0]
|
||
require.Equal(t, int64(1), call.accountID)
|
||
require.Contains(t, call.reason, "INTERNAL 500")
|
||
// until 应在 [before+10m, after+10m] 范围内
|
||
require.True(t, call.until.After(before.Add(internal500PenaltyTier1Duration).Add(-time.Second)))
|
||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier1Duration).Add(time.Second)))
|
||
})
|
||
|
||
t.Run("count=2 → SetTempUnschedulable 10 小时", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||
account := &Account{ID: 2, Name: "acc-2"}
|
||
|
||
before := time.Now()
|
||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 2)
|
||
after := time.Now()
|
||
|
||
require.Len(t, repo.tempUnschedCalls, 1)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
|
||
call := repo.tempUnschedCalls[0]
|
||
require.Equal(t, int64(2), call.accountID)
|
||
require.Contains(t, call.reason, "INTERNAL 500")
|
||
require.True(t, call.until.After(before.Add(internal500PenaltyTier2Duration).Add(-time.Second)))
|
||
require.True(t, call.until.Before(after.Add(internal500PenaltyTier2Duration).Add(time.Second)))
|
||
})
|
||
|
||
t.Run("count=3 → SetError 永久禁用", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||
account := &Account{ID: 3, Name: "acc-3"}
|
||
|
||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 3)
|
||
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Len(t, repo.setErrorCalls, 1)
|
||
|
||
call := repo.setErrorCalls[0]
|
||
require.Equal(t, int64(3), call.accountID)
|
||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 3")
|
||
})
|
||
|
||
t.Run("count=5 → SetError 永久禁用(>=3 都走永久禁用)", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||
account := &Account{ID: 5, Name: "acc-5"}
|
||
|
||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 5)
|
||
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Len(t, repo.setErrorCalls, 1)
|
||
|
||
call := repo.setErrorCalls[0]
|
||
require.Equal(t, int64(5), call.accountID)
|
||
require.Contains(t, call.reason, "INTERNAL 500 consecutive failures: 5")
|
||
})
|
||
|
||
t.Run("count=0 → 不调用任何方法", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||
account := &Account{ID: 10, Name: "acc-10"}
|
||
|
||
svc.applyInternal500Penalty(context.Background(), "[test]", account, 0)
|
||
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
})
|
||
}
|
||
|
||
// =============================================================================
|
||
// TestHandleInternal500RetryExhausted
|
||
// =============================================================================
|
||
|
||
func TestHandleInternal500RetryExhausted(t *testing.T) {
|
||
t.Run("internal500Cache 为 nil → 不 panic,不调用任何方法", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
svc := &AntigravityGatewayService{
|
||
accountRepo: repo,
|
||
internal500Cache: nil,
|
||
}
|
||
account := &Account{ID: 1, Name: "acc-1"}
|
||
|
||
// 不应 panic
|
||
require.NotPanics(t, func() {
|
||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||
})
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
})
|
||
|
||
t.Run("IncrementInternal500Count 返回 error → 不调用惩罚方法", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
cache := &mockInternal500Cache{
|
||
incrementErr: errors.New("redis connection error"),
|
||
}
|
||
svc := &AntigravityGatewayService{
|
||
accountRepo: repo,
|
||
internal500Cache: cache,
|
||
}
|
||
account := &Account{ID: 2, Name: "acc-2"}
|
||
|
||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||
|
||
require.Len(t, cache.incrementCalls, 1)
|
||
require.Equal(t, int64(2), cache.incrementCalls[0])
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
})
|
||
|
||
t.Run("IncrementInternal500Count 返回 count=1 → 触发 tier1 惩罚", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
cache := &mockInternal500Cache{
|
||
incrementCount: 1,
|
||
}
|
||
svc := &AntigravityGatewayService{
|
||
accountRepo: repo,
|
||
internal500Cache: cache,
|
||
}
|
||
account := &Account{ID: 3, Name: "acc-3"}
|
||
|
||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||
|
||
require.Len(t, cache.incrementCalls, 1)
|
||
require.Equal(t, int64(3), cache.incrementCalls[0])
|
||
// tier1: SetTempUnschedulable
|
||
require.Len(t, repo.tempUnschedCalls, 1)
|
||
require.Equal(t, int64(3), repo.tempUnschedCalls[0].accountID)
|
||
require.Empty(t, repo.setErrorCalls)
|
||
})
|
||
|
||
t.Run("IncrementInternal500Count 返回 count=3 → 触发 tier3 永久禁用", func(t *testing.T) {
|
||
repo := &internal500AccountRepoStub{}
|
||
cache := &mockInternal500Cache{
|
||
incrementCount: 3,
|
||
}
|
||
svc := &AntigravityGatewayService{
|
||
accountRepo: repo,
|
||
internal500Cache: cache,
|
||
}
|
||
account := &Account{ID: 4, Name: "acc-4"}
|
||
|
||
svc.handleInternal500RetryExhausted(context.Background(), "[test]", account)
|
||
|
||
require.Len(t, cache.incrementCalls, 1)
|
||
require.Empty(t, repo.tempUnschedCalls)
|
||
require.Len(t, repo.setErrorCalls, 1)
|
||
require.Equal(t, int64(4), repo.setErrorCalls[0].accountID)
|
||
})
|
||
}
|
||
|
||
// =============================================================================
|
||
// TestResetInternal500Counter
|
||
// =============================================================================
|
||
|
||
func TestResetInternal500Counter(t *testing.T) {
|
||
t.Run("internal500Cache 为 nil → 不 panic", func(t *testing.T) {
|
||
svc := &AntigravityGatewayService{
|
||
internal500Cache: nil,
|
||
}
|
||
|
||
require.NotPanics(t, func() {
|
||
svc.resetInternal500Counter(context.Background(), "[test]", 1)
|
||
})
|
||
})
|
||
|
||
t.Run("ResetInternal500Count 返回 error → 不 panic(仅日志)", func(t *testing.T) {
|
||
cache := &mockInternal500Cache{
|
||
resetErr: errors.New("redis timeout"),
|
||
}
|
||
svc := &AntigravityGatewayService{
|
||
internal500Cache: cache,
|
||
}
|
||
|
||
require.NotPanics(t, func() {
|
||
svc.resetInternal500Counter(context.Background(), "[test]", 42)
|
||
})
|
||
require.Len(t, cache.resetCalls, 1)
|
||
require.Equal(t, int64(42), cache.resetCalls[0])
|
||
})
|
||
|
||
t.Run("正常调用 → 调用 ResetInternal500Count", func(t *testing.T) {
|
||
cache := &mockInternal500Cache{}
|
||
svc := &AntigravityGatewayService{
|
||
internal500Cache: cache,
|
||
}
|
||
|
||
svc.resetInternal500Counter(context.Background(), "[test]", 99)
|
||
|
||
require.Len(t, cache.resetCalls, 1)
|
||
require.Equal(t, int64(99), cache.resetCalls[0])
|
||
})
|
||
}
|