fix: restore error passthrough service improvements from 7b156489
This commit is contained in:
@@ -6,6 +6,7 @@ import (
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
@@ -60,8 +61,11 @@ func NewErrorPassthroughService(
|
||||
|
||||
// 启动时加载规则到本地缓存
|
||||
ctx := context.Background()
|
||||
if err := svc.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
|
||||
if err := svc.reloadRulesFromDB(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err)
|
||||
if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr)
|
||||
}
|
||||
}
|
||||
|
||||
// 订阅缓存更新通知
|
||||
@@ -98,7 +102,9 @@ func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorP
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
refreshCtx, cancel := s.newCacheRefreshContext()
|
||||
defer cancel()
|
||||
s.invalidateAndNotify(refreshCtx)
|
||||
|
||||
return created, nil
|
||||
}
|
||||
@@ -115,7 +121,9 @@ func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorP
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
refreshCtx, cancel := s.newCacheRefreshContext()
|
||||
defer cancel()
|
||||
s.invalidateAndNotify(refreshCtx)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
@@ -127,7 +135,9 @@ func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
refreshCtx, cancel := s.newCacheRefreshContext()
|
||||
defer cancel()
|
||||
s.invalidateAndNotify(refreshCtx)
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -189,7 +199,12 @@ func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库加载(repo.List 已按 priority 排序)
|
||||
return s.reloadRulesFromDB(ctx)
|
||||
}
|
||||
|
||||
// 从数据库加载(repo.List 已按 priority 排序)
|
||||
// 注意:该方法会绕过 cache.Get,确保拿到数据库最新值。
|
||||
func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error {
|
||||
rules, err := s.repo.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -222,11 +237,32 @@ func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughR
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// clearLocalCache 清空本地缓存,避免刷新失败时继续命中陈旧规则。
|
||||
func (s *ErrorPassthroughService) clearLocalCache() {
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = nil
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// newCacheRefreshContext 为写路径缓存同步创建独立上下文,避免受请求取消影响。
|
||||
func (s *ErrorPassthroughService) newCacheRefreshContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), 3*time.Second)
|
||||
}
|
||||
|
||||
// invalidateAndNotify 使缓存失效并通知其他实例
|
||||
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 先失效缓存,避免后续刷新读到陈旧规则。
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Invalidate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新本地缓存
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
if err := s.reloadRulesFromDB(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
// 刷新失败时清空本地缓存,避免继续使用陈旧规则。
|
||||
s.clearLocalCache()
|
||||
}
|
||||
|
||||
// 通知其他实例
|
||||
|
||||
@@ -4,6 +4,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -15,13 +16,80 @@ import (
|
||||
// mockErrorPassthroughRepo 用于测试的 mock repository
|
||||
type mockErrorPassthroughRepo struct {
|
||||
rules []*model.ErrorPassthroughRule
|
||||
listErr error
|
||||
getErr error
|
||||
createErr error
|
||||
updateErr error
|
||||
deleteErr error
|
||||
}
|
||||
|
||||
type mockErrorPassthroughCache struct {
|
||||
rules []*model.ErrorPassthroughRule
|
||||
hasData bool
|
||||
getCalled int
|
||||
setCalled int
|
||||
invalidateCalled int
|
||||
notifyCalled int
|
||||
}
|
||||
|
||||
func newMockErrorPassthroughCache(rules []*model.ErrorPassthroughRule, hasData bool) *mockErrorPassthroughCache {
|
||||
return &mockErrorPassthroughCache{
|
||||
rules: cloneRules(rules),
|
||||
hasData: hasData,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
|
||||
m.getCalled++
|
||||
if !m.hasData {
|
||||
return nil, false
|
||||
}
|
||||
return cloneRules(m.rules), true
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
|
||||
m.setCalled++
|
||||
m.rules = cloneRules(rules)
|
||||
m.hasData = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughCache) Invalidate(ctx context.Context) error {
|
||||
m.invalidateCalled++
|
||||
m.rules = nil
|
||||
m.hasData = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughCache) NotifyUpdate(ctx context.Context) error {
|
||||
m.notifyCalled++
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
|
||||
// 单测中无需订阅行为
|
||||
}
|
||||
|
||||
func cloneRules(rules []*model.ErrorPassthroughRule) []*model.ErrorPassthroughRule {
|
||||
if rules == nil {
|
||||
return nil
|
||||
}
|
||||
out := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(out, rules)
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
if m.listErr != nil {
|
||||
return nil, m.listErr
|
||||
}
|
||||
return m.rules, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
if m.getErr != nil {
|
||||
return nil, m.getErr
|
||||
}
|
||||
for _, r := range m.rules {
|
||||
if r.ID == id {
|
||||
return r, nil
|
||||
@@ -31,12 +99,18 @@ func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*mode
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if m.createErr != nil {
|
||||
return nil, m.createErr
|
||||
}
|
||||
rule.ID = int64(len(m.rules) + 1)
|
||||
m.rules = append(m.rules, rule)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if m.updateErr != nil {
|
||||
return nil, m.updateErr
|
||||
}
|
||||
for i, r := range m.rules {
|
||||
if r.ID == rule.ID {
|
||||
m.rules[i] = rule
|
||||
@@ -47,6 +121,9 @@ func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.Error
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
|
||||
if m.deleteErr != nil {
|
||||
return m.deleteErr
|
||||
}
|
||||
for i, r := range m.rules {
|
||||
if r.ID == id {
|
||||
m.rules = append(m.rules[:i], m.rules[i+1:]...)
|
||||
@@ -750,6 +827,158 @@ func TestErrorPassthroughRule_Validate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试写路径缓存刷新(Create/Update/Delete)
|
||||
// =============================================================================
|
||||
|
||||
func TestCreate_ForceRefreshCacheAfterWrite(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
staleRule := newPassthroughRuleForWritePathTest(99, "service temporarily unavailable after multiple", "旧缓存消息")
|
||||
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{}}
|
||||
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
|
||||
|
||||
svc := &ErrorPassthroughService{repo: repo, cache: cache}
|
||||
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
|
||||
|
||||
newRule := newPassthroughRuleForWritePathTest(0, "service temporarily unavailable after multiple", "上游请求失败")
|
||||
created, err := svc.Create(ctx, newRule)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, created)
|
||||
|
||||
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
|
||||
matched := svc.MatchRule("anthropic", 503, body)
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, created.ID, matched.ID)
|
||||
if assert.NotNil(t, matched.CustomMessage) {
|
||||
assert.Equal(t, "上游请求失败", *matched.CustomMessage)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
|
||||
assert.Equal(t, 1, cache.invalidateCalled)
|
||||
assert.Equal(t, 1, cache.setCalled)
|
||||
assert.Equal(t, 1, cache.notifyCalled)
|
||||
}
|
||||
|
||||
func TestUpdate_ForceRefreshCacheAfterWrite(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
originalRule := newPassthroughRuleForWritePathTest(1, "old keyword", "旧消息")
|
||||
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{originalRule}}
|
||||
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{originalRule}, true)
|
||||
|
||||
svc := &ErrorPassthroughService{repo: repo, cache: cache}
|
||||
svc.setLocalCache([]*model.ErrorPassthroughRule{originalRule})
|
||||
|
||||
updatedRule := newPassthroughRuleForWritePathTest(1, "new keyword", "新消息")
|
||||
_, err := svc.Update(ctx, updatedRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
oldBody := []byte(`{"message":"old keyword"}`)
|
||||
oldMatched := svc.MatchRule("anthropic", 503, oldBody)
|
||||
assert.Nil(t, oldMatched, "更新后旧关键词不应继续命中")
|
||||
|
||||
newBody := []byte(`{"message":"new keyword"}`)
|
||||
newMatched := svc.MatchRule("anthropic", 503, newBody)
|
||||
require.NotNil(t, newMatched)
|
||||
if assert.NotNil(t, newMatched.CustomMessage) {
|
||||
assert.Equal(t, "新消息", *newMatched.CustomMessage)
|
||||
}
|
||||
|
||||
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
|
||||
assert.Equal(t, 1, cache.invalidateCalled)
|
||||
assert.Equal(t, 1, cache.setCalled)
|
||||
assert.Equal(t, 1, cache.notifyCalled)
|
||||
}
|
||||
|
||||
func TestDelete_ForceRefreshCacheAfterWrite(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
rule := newPassthroughRuleForWritePathTest(1, "to be deleted", "删除前消息")
|
||||
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{rule}}
|
||||
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{rule}, true)
|
||||
|
||||
svc := &ErrorPassthroughService{repo: repo, cache: cache}
|
||||
svc.setLocalCache([]*model.ErrorPassthroughRule{rule})
|
||||
|
||||
err := svc.Delete(ctx, 1)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := []byte(`{"message":"to be deleted"}`)
|
||||
matched := svc.MatchRule("anthropic", 503, body)
|
||||
assert.Nil(t, matched, "删除后规则不应再命中")
|
||||
|
||||
assert.Equal(t, 0, cache.getCalled, "写路径刷新不应依赖 cache.Get")
|
||||
assert.Equal(t, 1, cache.invalidateCalled)
|
||||
assert.Equal(t, 1, cache.setCalled)
|
||||
assert.Equal(t, 1, cache.notifyCalled)
|
||||
}
|
||||
|
||||
func TestNewService_StartupReloadFromDBToHealStaleCache(t *testing.T) {
|
||||
staleRule := newPassthroughRuleForWritePathTest(99, "stale keyword", "旧缓存消息")
|
||||
latestRule := newPassthroughRuleForWritePathTest(1, "fresh keyword", "最新消息")
|
||||
|
||||
repo := &mockErrorPassthroughRepo{rules: []*model.ErrorPassthroughRule{latestRule}}
|
||||
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
|
||||
|
||||
svc := NewErrorPassthroughService(repo, cache)
|
||||
|
||||
matchedFresh := svc.MatchRule("anthropic", 503, []byte(`{"message":"fresh keyword"}`))
|
||||
require.NotNil(t, matchedFresh)
|
||||
assert.Equal(t, int64(1), matchedFresh.ID)
|
||||
|
||||
matchedStale := svc.MatchRule("anthropic", 503, []byte(`{"message":"stale keyword"}`))
|
||||
assert.Nil(t, matchedStale, "启动后应以 DB 最新规则覆盖旧缓存")
|
||||
|
||||
assert.Equal(t, 0, cache.getCalled, "启动强制 DB 刷新不应依赖 cache.Get")
|
||||
assert.Equal(t, 1, cache.setCalled, "启动后应回写缓存,覆盖陈旧缓存")
|
||||
}
|
||||
|
||||
func TestUpdate_RefreshFailureShouldNotKeepStaleEnabledRule(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
staleRule := newPassthroughRuleForWritePathTest(1, "service temporarily unavailable after multiple", "旧缓存消息")
|
||||
repo := &mockErrorPassthroughRepo{
|
||||
rules: []*model.ErrorPassthroughRule{staleRule},
|
||||
listErr: errors.New("db list failed"),
|
||||
}
|
||||
cache := newMockErrorPassthroughCache([]*model.ErrorPassthroughRule{staleRule}, true)
|
||||
|
||||
svc := &ErrorPassthroughService{repo: repo, cache: cache}
|
||||
svc.setLocalCache([]*model.ErrorPassthroughRule{staleRule})
|
||||
|
||||
disabledRule := *staleRule
|
||||
disabledRule.Enabled = false
|
||||
_, err := svc.Update(ctx, &disabledRule)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := []byte(`{"message":"Service temporarily unavailable after multiple retries, please try again later"}`)
|
||||
matched := svc.MatchRule("anthropic", 503, body)
|
||||
assert.Nil(t, matched, "刷新失败时不应继续命中旧的启用规则")
|
||||
|
||||
svc.localCacheMu.RLock()
|
||||
assert.Nil(t, svc.localCache, "刷新失败后应清空本地缓存,避免误命中")
|
||||
svc.localCacheMu.RUnlock()
|
||||
}
|
||||
|
||||
func newPassthroughRuleForWritePathTest(id int64, keyword, customMsg string) *model.ErrorPassthroughRule {
|
||||
responseCode := 503
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: id,
|
||||
Name: "write-path-cache-refresh",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{503},
|
||||
Keywords: []string{keyword},
|
||||
MatchMode: model.MatchModeAll,
|
||||
PassthroughCode: false,
|
||||
ResponseCode: &responseCode,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: &customMsg,
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func testIntPtr(i int) *int { return &i }
|
||||
func testStrPtr(s string) *string { return &s }
|
||||
|
||||
Reference in New Issue
Block a user