diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index 99dc70e3..c3e0f630 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -6,6 +6,7 @@ import ( "sort" "strings" "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/model" ) @@ -60,8 +61,11 @@ func NewErrorPassthroughService( // 启动时加载规则到本地缓存 ctx := context.Background() - if err := svc.refreshLocalCache(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err) + if err := svc.reloadRulesFromDB(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + } } // 订阅缓存更新通知 @@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return created, nil } @@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return updated, nil } @@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { } // 刷新缓存 - s.invalidateAndNotify(ctx) + refreshCtx, cancel := s.newCacheRefreshContext() + defer cancel() + s.invalidateAndNotify(refreshCtx) return nil } @@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { } } - // 从数据库加载(repo.List 已按 priority 排序) + return s.reloadRulesFromDB(ctx) +} + +// 从数据库加载(repo.List 已按 priority 排序) +// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。 +func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { rules, err := s.repo.List(ctx) if err != nil { return err @@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR s.localCacheMu.Unlock() } +// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。 +func (s *ErrorPassthroughService) clearLocalCache() { + s.localCacheMu.Lock() + s.localCache = nil + s.localCacheMu.Unlock() +} + +// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。 +func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), 3*time.Second) +} + // invalidateAndNotify 使缓存失效并通知其他实例 func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { + // 先失效缓存,避免后续刷新读到陈旧规则。 + if s.cache != nil { + if err := s.cache.Invalidate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + } + } + // 刷新本地缓存 - if err := s.refreshLocalCache(ctx); err != nil { + if err := s.reloadRulesFromDB(ctx); err != nil { log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 + s.clearLocalCache() } // 通知其他实例 diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go index 205b4ec4..74c98d86 100644 --- a/backend/internal/service/error_passthrough_service_test.go +++ b/backend/internal/service/error_passthrough_service_test.go @@ -4,6 +4,7 @@ package service import ( "context" + "errors" "strings" "testing" @@ -14,14 +15,81 @@ import ( // mockErrorPassthroughRepo 用于测试的 mock repository type mockErrorPassthroughRepo struct { - rules []*model.ErrorPassthroughRule + rules []*model.ErrorPassthroughRule + listErr error + getErr error + createErr error + updateErr error + deleteErr error +} + +type mockErrorPassthroughCache struct { + rules []*model.ErrorPassthroughRule + hasData bool + getCalled int + setCalled int + invalidateCalled int + notifyCalled int +} + +func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache { + return &mockErrorPassthroughCache{ + rules: cloneRules(rules), + hasData: hasData, + } +} + +func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + m.getCalled++ + if !m.hasData { + return nil, false + } + return cloneRules(m.rules), true +} + +func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + m.setCalled++ + m.rules = cloneRules(rules) + m.hasData = true + return nil +} + +func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error { + m.invalidateCalled++ + m.rules = nil + m.hasData = false + return nil +} + +func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error { + m.notifyCalled++ + return nil +} + +func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + // 单测中无需订阅行为 +} + +func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule { + if rules == nil { + return nil + } + out := make([]*model.ErrorPassthroughRule, len(rules)) + copy(out, rules) + return out } func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + if m.listErr != nil { + return nil, m.listErr + } return m.rules, nil } func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + if m.getErr != nil { + return nil, m.getErr + } for _, r := range m.rules { if r.ID == id { return r, nil @@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode } func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.createErr != nil { + return nil, m.createErr + } rule.ID = int64(len(m.rules) + 1) m.rules = append(m.rules, rule) return rule, nil } func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if m.updateErr != nil { + return nil, m.updateErr + } for i, r := range m.rules { if r.ID == rule.ID { m.rules[i] = rule @@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error } func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { + if m.deleteErr != nil { + return m.deleteErr + } for i, r := range m.rules { if r.ID == id { m.rules = append(m.rules[:i], m.rules[i+1:]...) @@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) { } } +// ============================================================================= +// 测试写路径缓存刷新(Create/Update/Delete) +// ============================================================================= + +func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败") + created, err := svc.Create(ctx, newRule) + require.NoError(t, err) + require.NotNil(t, created) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + require.NotNil(t, matched) + assert.Equal(t, created.ID, matched.ID) + if assert.NotNil(t, matched.CustomMessage) { + assert.Equal(t, "上游请求失败", *matched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule}) + + updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息") + _, err := svc.Update(ctx, updatedRule) + require.NoError(t, err) + + oldBody := []byte(`{"message":"old keyword"}`) + oldMatched := svc.MatchRule("anthropic", 503, oldBody) + assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中") + + newBody := []byte(`{"message":"new keyword"}`) + newMatched := svc.MatchRule("anthropic", 503, newBody) + require.NotNil(t, newMatched) + if assert.NotNil(t, newMatched.CustomMessage) { + assert.Equal(t, "新消息", *newMatched.CustomMessage) + } + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) { + ctx := context.Background() + + rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息") + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{rule}) + + err := svc.Delete(ctx, 1) + require.NoError(t, err) + + body := []byte(`{"message":"to be deleted"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "删除后规则不应再命中") + + assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.invalidateCalled) + assert.Equal(t, 1, cache.setCalled) + assert.Equal(t, 1, cache.notifyCalled) +} + +func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) { + staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息") + latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息") + + repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}} + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := NewErrorPassthroughService(repo, cache) + + matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`)) + require.NotNil(t, matchedFresh) + assert.Equal(t, int64(1), matchedFresh.ID) + + matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`)) + assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存") + + assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get") + assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存") +} + +func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) { + ctx := context.Background() + + staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息") + repo := &mockErrorPassthroughRepo{ + rules: []*model.ErrorPassthroughRule{staleRule}, + listErr: errors.New("db list failed"), + } + cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true) + + svc := &ErrorPassthroughService{repo: repo, cache: cache} + svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule}) + + disabledRule := *staleRule + disabledRule.Enabled = false + _, err := svc.Update(ctx, &disabledRule) + require.NoError(t, err) + + body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`) + matched := svc.MatchRule("anthropic", 503, body) + assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则") + + svc.localCacheMu.RLock() + assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中") + svc.localCacheMu.RUnlock() +} + +func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule { + responseCode := 503 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: "write-path-cache-refresh", + Enabled: true, + Priority: 1, + ErrorCodes: []int{503}, + Keywords: []string{keyword}, + MatchMode: model.MatchModeAll, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + } + return rule +} + // Helper functions func testIntPtr(i int) *int { return &i } func testStrPtr(s string) *string { return &s }