//go:build unit package service import ( "context" "errors" "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) // mockErrorPassthroughRepo 用于测试的 mock repository type mockErrorPassthroughRepo struct { rules []*model.ErrorPassthroughRule listErr error getErr error createErr error updateErr error deleteErr error } type mockErrorPassthroughCache struct { rules []*model.ErrorPassthroughRule hasData bool getCalled int setCalled int invalidateCalled int notifyCalled int } func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache { return &mockErrorPassthroughCache{ rules: cloneRules(rules), hasData: hasData, } } func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { m.getCalled++ if !m.hasData { return nil, false } return cloneRules(m.rules), true } func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { m.setCalled++ m.rules = cloneRules(rules) m.hasData = true return nil } func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error { m.invalidateCalled++ m.rules = nil m.hasData = false return nil } func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error { m.notifyCalled++ return nil } func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { // 单测中无需订阅行为 } func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule { if rules == nil { return nil } out := make([]*model.ErrorPassthroughRule, len(rules)) copy(out, rules) return out } func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { if m.listErr != nil { return nil, m.listErr } return m.rules, nil } func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { if m.getErr != nil { return nil, m.getErr } for _, r := range m.rules { if r.ID == id { return r, nil } } return nil, nil } func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { if m.createErr != nil { return nil, m.createErr } rule.ID = int64(len(m.rules) + 1) m.rules = append(m.rules, rule) return rule, nil } func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { if m.updateErr != nil { return nil, m.updateErr } for i, r := range m.rules { if r.ID == rule.ID { m.rules[i] = rule return rule, nil } } return rule, nil } func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { if m.deleteErr != nil { return m.deleteErr } for i, r := range m.rules { if r.ID == id { m.rules = append(m.rules[:i], m.rules[i+1:]...) return nil } } return nil } // newTestService 创建测试用的服务实例 func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService { repo := &mockErrorPassthroughRepo{rules: rules} svc := &ErrorPassthroughService{ repo: repo, cache: nil, // 不使用缓存 } // 直接设置本地缓存,避免调用 refreshLocalCache svc.setLocalCache(rules) return svc } // ============================================================================= // 测试 ruleMatches 核心匹配逻辑 // ============================================================================= func TestRuleMatches_NoConditions(t *testing.T) { // 没有配置任何条件时,不应该匹配 svc := newTestService(nil) rule := &model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{}, MatchMode: model.MatchModeAny, } assert.False(t, svc.ruleMatches(rule, 422, "some error message"), "没有配置条件时不应该匹配") } func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { svc := newTestService(nil) rule := &model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{}, MatchMode: model.MatchModeAny, } tests := []struct { name string statusCode int body string expected bool }{ {"状态码匹配 422", 422, "any message", true}, {"状态码匹配 400", 400, "any message", true}, {"状态码不匹配 500", 500, "any message", false}, {"状态码不匹配 429", 429, "any message", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := svc.ruleMatches(rule, tt.statusCode, tt.body) assert.Equal(t, tt.expected, result) }) } } func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { svc := newTestService(nil) rule := &model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{}, Keywords: []string{"context limit", "model not supported"}, MatchMode: model.MatchModeAny, } tests := []struct { name string statusCode int body string expected bool }{ {"关键词匹配 context limit", 500, "error: context limit reached", true}, {"关键词匹配 model not supported", 400, "the model not supported here", true}, {"关键词不匹配", 422, "some other error", false}, // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 模拟 MatchRule 的行为:先转换为小写 bodyLower := strings.ToLower(tt.body) result := svc.ruleMatches(rule, tt.statusCode, bodyLower) assert.Equal(t, tt.expected, result) }) } } func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { // any 模式:错误码 OR 关键词 svc := newTestService(nil) rule := &model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAny, } tests := []struct { name string statusCode int body string expected bool reason string }{ { name: "状态码和关键词都匹配", statusCode: 422, body: "context limit reached", expected: true, reason: "both match", }, { name: "只有状态码匹配", statusCode: 422, body: "some other error", expected: true, reason: "code matches, keyword doesn't - OR mode should match", }, { name: "只有关键词匹配", statusCode: 500, body: "context limit exceeded", expected: true, reason: "keyword matches, code doesn't - OR mode should match", }, { name: "都不匹配", statusCode: 500, body: "some other error", expected: false, reason: "neither matches", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := svc.ruleMatches(rule, tt.statusCode, tt.body) assert.Equal(t, tt.expected, result, tt.reason) }) } } func TestRuleMatches_BothConditions_AllMode(t *testing.T) { // all 模式:错误码 AND 关键词 svc := newTestService(nil) rule := &model.ErrorPassthroughRule{ Enabled: true, ErrorCodes: []int{422, 400}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAll, } tests := []struct { name string statusCode int body string expected bool reason string }{ { name: "状态码和关键词都匹配", statusCode: 422, body: "context limit reached", expected: true, reason: "both match - AND mode should match", }, { name: "只有状态码匹配", statusCode: 422, body: "some other error", expected: false, reason: "code matches but keyword doesn't - AND mode should NOT match", }, { name: "只有关键词匹配", statusCode: 500, body: "context limit exceeded", expected: false, reason: "keyword matches but code doesn't - AND mode should NOT match", }, { name: "都不匹配", statusCode: 500, body: "some other error", expected: false, reason: "neither matches", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := svc.ruleMatches(rule, tt.statusCode, tt.body) assert.Equal(t, tt.expected, result, tt.reason) }) } } // ============================================================================= // 测试 platformMatches 平台匹配逻辑 // ============================================================================= func TestPlatformMatches(t *testing.T) { svc := newTestService(nil) tests := []struct { name string rulePlatforms []string requestPlatform string expected bool }{ { name: "空平台列表匹配所有", rulePlatforms: []string{}, requestPlatform: "anthropic", expected: true, }, { name: "nil平台列表匹配所有", rulePlatforms: nil, requestPlatform: "openai", expected: true, }, { name: "精确匹配 anthropic", rulePlatforms: []string{"anthropic", "openai"}, requestPlatform: "anthropic", expected: true, }, { name: "精确匹配 openai", rulePlatforms: []string{"anthropic", "openai"}, requestPlatform: "openai", expected: true, }, { name: "不匹配 gemini", rulePlatforms: []string{"anthropic", "openai"}, requestPlatform: "gemini", expected: false, }, { name: "大小写不敏感", rulePlatforms: []string{"Anthropic", "OpenAI"}, requestPlatform: "anthropic", expected: true, }, { name: "匹配 antigravity", rulePlatforms: []string{"antigravity"}, requestPlatform: "antigravity", expected: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { rule := &model.ErrorPassthroughRule{ Platforms: tt.rulePlatforms, } result := svc.platformMatches(rule, tt.requestPlatform) assert.Equal(t, tt.expected, result) }) } } // ============================================================================= // 测试 MatchRule 完整匹配流程 // ============================================================================= func TestMatchRule_Priority(t *testing.T) { // 测试规则按优先级排序,优先级小的先匹配 rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Low Priority", Enabled: true, Priority: 10, ErrorCodes: []int{422}, MatchMode: model.MatchModeAny, }, { ID: 2, Name: "High Priority", Enabled: true, Priority: 1, ErrorCodes: []int{422}, MatchMode: model.MatchModeAny, }, } svc := newTestService(rules) matched := svc.MatchRule("anthropic", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则") assert.Equal(t, "High Priority", matched.Name) } func TestMatchRule_DisabledRule(t *testing.T) { rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Disabled Rule", Enabled: false, Priority: 1, ErrorCodes: []int{422}, MatchMode: model.MatchModeAny, }, { ID: 2, Name: "Enabled Rule", Enabled: true, Priority: 10, ErrorCodes: []int{422}, MatchMode: model.MatchModeAny, }, } svc := newTestService(rules) matched := svc.MatchRule("anthropic", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则") } func TestMatchRule_PlatformFilter(t *testing.T) { rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Anthropic Only", Enabled: true, Priority: 1, ErrorCodes: []int{422}, Platforms: []string{"anthropic"}, MatchMode: model.MatchModeAny, }, { ID: 2, Name: "OpenAI Only", Enabled: true, Priority: 2, ErrorCodes: []int{422}, Platforms: []string{"openai"}, MatchMode: model.MatchModeAny, }, { ID: 3, Name: "All Platforms", Enabled: true, Priority: 3, ErrorCodes: []int{422}, Platforms: []string{}, MatchMode: model.MatchModeAny, }, } svc := newTestService(rules) t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) { matched := svc.MatchRule("anthropic", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(1), matched.ID) }) t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) { matched := svc.MatchRule("openai", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(2), matched.ID) }) t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) { matched := svc.MatchRule("gemini", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(3), matched.ID) }) t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) { matched := svc.MatchRule("antigravity", 422, []byte("error")) require.NotNil(t, matched) assert.Equal(t, int64(3), matched.ID) }) } func TestMatchRule_NoMatch(t *testing.T) { rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Rule for 422", Enabled: true, Priority: 1, ErrorCodes: []int{422}, MatchMode: model.MatchModeAny, }, } svc := newTestService(rules) matched := svc.MatchRule("anthropic", 500, []byte("error")) assert.Nil(t, matched, "不匹配任何规则时应返回 nil") } func TestMatchRule_EmptyRules(t *testing.T) { svc := newTestService([]*model.ErrorPassthroughRule{}) matched := svc.MatchRule("anthropic", 422, []byte("error")) assert.Nil(t, matched, "没有规则时应返回 nil") } func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) { rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Context Limit", Enabled: true, Priority: 1, Keywords: []string{"Context Limit"}, MatchMode: model.MatchModeAny, }, } svc := newTestService(rules) tests := []struct { name string body string expected bool }{ {"完全匹配", "Context Limit reached", true}, {"小写匹配", "context limit reached", true}, {"大写匹配", "CONTEXT LIMIT REACHED", true}, {"混合大小写", "ConTeXt LiMiT error", true}, {"不匹配", "some other error", false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matched := svc.MatchRule("anthropic", 500, []byte(tt.body)) if tt.expected { assert.NotNil(t, matched) } else { assert.Nil(t, matched) } }) } } // ============================================================================= // 测试真实场景 // ============================================================================= func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { // 场景:上游返回 422 + "context limit has been reached",需要透传给客户端 rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Context Limit Passthrough", Enabled: true, Priority: 1, ErrorCodes: []int{422}, Keywords: []string{"context limit"}, MatchMode: model.MatchModeAll, // 必须同时满足 Platforms: []string{"anthropic", "antigravity"}, PassthroughCode: true, PassthroughBody: true, }, } svc := newTestService(rules) // 测试 Anthropic 平台 t.Run("Anthropic 422 with context limit", func(t *testing.T) { body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`) matched := svc.MatchRule("anthropic", 422, body) require.NotNil(t, matched) assert.True(t, matched.PassthroughCode) assert.True(t, matched.PassthroughBody) }) // 测试 Antigravity 平台 t.Run("Antigravity 422 with context limit", func(t *testing.T) { body := []byte(`{"error":"context limit exceeded"}`) matched := svc.MatchRule("antigravity", 422, body) require.NotNil(t, matched) }) // 测试 OpenAI 平台(不在规则的平台列表中) t.Run("OpenAI should not match", func(t *testing.T) { body := []byte(`{"error":"context limit exceeded"}`) matched := svc.MatchRule("openai", 422, body) assert.Nil(t, matched, "OpenAI 不在规则的平台列表中") }) // 测试状态码不匹配 t.Run("Wrong status code", func(t *testing.T) { body := []byte(`{"error":"context limit exceeded"}`) matched := svc.MatchRule("anthropic", 400, body) assert.Nil(t, matched, "状态码不匹配") }) // 测试关键词不匹配 t.Run("Wrong keyword", func(t *testing.T) { body := []byte(`{"error":"rate limit exceeded"}`) matched := svc.MatchRule("anthropic", 422, body) assert.Nil(t, matched, "关键词不匹配") }) } func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) { // 场景:某些错误需要返回自定义消息,隐藏上游详细信息 customMsg := "Service temporarily unavailable, please try again later" responseCode := 503 rules := []*model.ErrorPassthroughRule{ { ID: 1, Name: "Hide Internal Errors", Enabled: true, Priority: 1, ErrorCodes: []int{500, 502, 503}, MatchMode: model.MatchModeAny, PassthroughCode: false, ResponseCode: &responseCode, PassthroughBody: false, CustomMessage: &customMsg, }, } svc := newTestService(rules) matched := svc.MatchRule("anthropic", 500, []byte("internal server error")) require.NotNil(t, matched) assert.False(t, matched.PassthroughCode) assert.Equal(t, 503, *matched.ResponseCode) assert.False(t, matched.PassthroughBody) assert.Equal(t, customMsg, *matched.CustomMessage) } // ============================================================================= // 测试 model.Validate // ============================================================================= func TestErrorPassthroughRule_Validate(t *testing.T) { tests := []struct { name string rule *model.ErrorPassthroughRule expectError bool errorField string }{ { name: "有效规则 - 透传模式(含错误码)", rule: &model.ErrorPassthroughRule{ Name: "Valid Rule", MatchMode: model.MatchModeAny, ErrorCodes: []int{422}, PassthroughCode: true, PassthroughBody: true, }, expectError: false, }, { name: "有效规则 - 透传模式(含关键词)", rule: &model.ErrorPassthroughRule{ Name: "Valid Rule", MatchMode: model.MatchModeAny, Keywords: []string{"context limit"}, PassthroughCode: true, PassthroughBody: true, }, expectError: false, }, { name: "有效规则 - 自定义响应", rule: &model.ErrorPassthroughRule{ Name: "Valid Rule", MatchMode: model.MatchModeAll, ErrorCodes: []int{500}, Keywords: []string{"internal error"}, PassthroughCode: false, ResponseCode: testIntPtr(503), PassthroughBody: false, CustomMessage: testStrPtr("Custom error"), }, expectError: false, }, { name: "缺少名称", rule: &model.ErrorPassthroughRule{ Name: "", MatchMode: model.MatchModeAny, ErrorCodes: []int{422}, PassthroughCode: true, PassthroughBody: true, }, expectError: true, errorField: "name", }, { name: "无效的匹配模式", rule: &model.ErrorPassthroughRule{ Name: "Invalid Mode", MatchMode: "invalid", ErrorCodes: []int{422}, PassthroughCode: true, PassthroughBody: true, }, expectError: true, errorField: "match_mode", }, { name: "缺少匹配条件(错误码和关键词都为空)", rule: &model.ErrorPassthroughRule{ Name: "No Conditions", MatchMode: model.MatchModeAny, ErrorCodes: []int{}, Keywords: []string{}, PassthroughCode: true, PassthroughBody: true, }, expectError: true, errorField: "conditions", }, { name: "缺少匹配条件(nil切片)", rule: &model.ErrorPassthroughRule{ Name: "Nil Conditions", MatchMode: model.MatchModeAny, ErrorCodes: nil, Keywords: nil, PassthroughCode: true, PassthroughBody: true, }, expectError: true, errorField: "conditions", }, { name: "自定义状态码但未提供值", rule: &model.ErrorPassthroughRule{ Name: "Missing Code", MatchMode: model.MatchModeAny, ErrorCodes: []int{422}, PassthroughCode: false, ResponseCode: nil, PassthroughBody: true, }, expectError: true, errorField: "response_code", }, { name: "自定义消息但未提供值", rule: &model.ErrorPassthroughRule{ Name: "Missing Message", MatchMode: model.MatchModeAny, ErrorCodes: []int{422}, PassthroughCode: true, PassthroughBody: false, CustomMessage: nil, }, expectError: true, errorField: "custom_message", }, { name: "自定义消息为空字符串", rule: &model.ErrorPassthroughRule{ Name: "Empty Message", MatchMode: model.MatchModeAny, ErrorCodes: []int{422}, PassthroughCode: true, PassthroughBody: false, CustomMessage: testStrPtr(""), }, expectError: true, errorField: "custom_message", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.rule.Validate() if tt.expectError { require.Error(t, err) validationErr, ok := err.(*model.ValidationError) require.True(t, ok, "应该返回 ValidationError") assert.Equal(t, tt.errorField, validationErr.Field) } else { assert.NoError(t, err) } }) } } // ============================================================================= // 测试写路径缓存刷新(Create/Update/Delete) // ============================================================================= func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) { ctx := context.Background() staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息") repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}} cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) svc := &ErrorPassthroughService{repo: repo, cache: cache} svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败") created, err := svc.Create(ctx, newRule) require.NoError(t, err) require.NotNil(t, created) body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) matched := svc.MatchRule("anthropic", 503, body) require.NotNil(t, matched) assert.Equal(t, created.ID, matched.ID) if assert.NotNil(t, matched.CustomMessage) { assert.Equal(t, "上游请求失败", *matched.CustomMessage) } assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") assert.Equal(t, 1, cache.invalidateCalled) assert.Equal(t, 1, cache.setCalled) assert.Equal(t, 1, cache.notifyCalled) } func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) { ctx := context.Background() originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息") repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}} cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true) svc := &ErrorPassthroughService{repo: repo, cache: cache} svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule}) updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息") _, err := svc.Update(ctx, updatedRule) require.NoError(t, err) oldBody := []byte(`{"message":"old keyword"}`) oldMatched := svc.MatchRule("anthropic", 503, oldBody) assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中") newBody := []byte(`{"message":"new keyword"}`) newMatched := svc.MatchRule("anthropic", 503, newBody) require.NotNil(t, newMatched) if assert.NotNil(t, newMatched.CustomMessage) { assert.Equal(t, "新消息", *newMatched.CustomMessage) } assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") assert.Equal(t, 1, cache.invalidateCalled) assert.Equal(t, 1, cache.setCalled) assert.Equal(t, 1, cache.notifyCalled) } func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) { ctx := context.Background() rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息") repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}} cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true) svc := &ErrorPassthroughService{repo: repo, cache: cache} svc.setLocalCache([]*model.ErrorPassthroughRule{rule}) err := svc.Delete(ctx, 1) require.NoError(t, err) body := []byte(`{"message":"to be deleted"}`) matched := svc.MatchRule("anthropic", 503, body) assert.Nil(t, matched, "删除后规则不应再命中") assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") assert.Equal(t, 1, cache.invalidateCalled) assert.Equal(t, 1, cache.setCalled) assert.Equal(t, 1, cache.notifyCalled) } func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) { staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息") latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息") repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}} cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) svc := NewErrorPassthroughService(repo, cache) matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`)) require.NotNil(t, matchedFresh) assert.Equal(t, int64(1), matchedFresh.ID) matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`)) assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存") assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get") assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存") } func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) { ctx := context.Background() staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息") repo := &mockErrorPassthroughRepo{ rules: []*model.ErrorPassthroughRule{staleRule}, listErr: errors.New("db list failed"), } cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) svc := &ErrorPassthroughService{repo: repo, cache: cache} svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) disabledRule := *staleRule disabledRule.Enabled = false _, err := svc.Update(ctx, &disabledRule) require.NoError(t, err) body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) matched := svc.MatchRule("anthropic", 503, body) assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则") svc.localCacheMu.RLock() assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中") svc.localCacheMu.RUnlock() } func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule { responseCode := 503 rule := &model.ErrorPassthroughRule{ ID: id, Name: "write-path-cache-refresh", Enabled: true, Priority: 1, ErrorCodes: []int{503}, Keywords: []string{keyword}, MatchMode: model.MatchModeAll, PassthroughCode: false, ResponseCode: &responseCode, PassthroughBody: false, CustomMessage: &customMsg, } return rule } // Helper functions func testIntPtr(i int) *int { return &i } func testStrPtr(s string) *string { return &s }