merge upstream/main
This commit is contained in:
@@ -94,6 +94,9 @@ type UpdateUserInput struct {
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]*rate,nil 表示删除该分组的专属倍率
|
||||
GroupRates map[int64]*float64
|
||||
}
|
||||
|
||||
type CreateGroupInput struct {
|
||||
@@ -296,6 +299,7 @@ type adminServiceImpl struct {
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
proxyLatencyCache ProxyLatencyCache
|
||||
@@ -310,6 +314,7 @@ func NewAdminService(
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
proxyLatencyCache ProxyLatencyCache,
|
||||
@@ -322,6 +327,7 @@ func NewAdminService(
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
proxyProber: proxyProber,
|
||||
proxyLatencyCache: proxyLatencyCache,
|
||||
@@ -336,11 +342,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
// 批量加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil && len(users) > 0 {
|
||||
for i := range users {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
|
||||
continue
|
||||
}
|
||||
users[i].GroupRates = rates
|
||||
}
|
||||
}
|
||||
return users, result.Total, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
user, err := s.userRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 加载用户专属分组倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
|
||||
if err != nil {
|
||||
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
|
||||
} else {
|
||||
user.GroupRates = rates
|
||||
}
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
|
||||
@@ -409,6 +439,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 同步用户专属分组倍率
|
||||
if input.GroupRates != nil && s.userGroupRateRepo != nil {
|
||||
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
|
||||
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCacheInvalidator != nil {
|
||||
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
|
||||
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
|
||||
@@ -944,6 +982,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
|
||||
|
||||
// 事务成功后,异步失效受影响用户的订阅缓存
|
||||
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {
|
||||
|
||||
@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
|
||||
}
|
||||
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = "application/json"
|
||||
}
|
||||
|
||||
@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
|
||||
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
@@ -132,16 +133,18 @@ func NewAPIKeyService(
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取用户的专属分组倍率配置
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user group rates: %w", err)
|
||||
}
|
||||
return rates, nil
|
||||
}
|
||||
|
||||
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||
// Returns nil if valid, error if invalid
|
||||
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||
|
||||
@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
300
backend/internal/service/error_passthrough_service.go
Normal file
300
backend/internal/service/error_passthrough_service.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
)
|
||||
|
||||
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
|
||||
type ErrorPassthroughRepository interface {
|
||||
// List 获取所有规则
|
||||
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
|
||||
// GetByID 根据 ID 获取规则
|
||||
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
|
||||
// Create 创建规则
|
||||
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Update 更新规则
|
||||
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
|
||||
// Delete 删除规则
|
||||
Delete(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// ErrorPassthroughCache 定义错误透传规则的缓存接口
|
||||
type ErrorPassthroughCache interface {
|
||||
// Get 从缓存获取规则列表
|
||||
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
|
||||
// Set 设置缓存
|
||||
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
|
||||
// Invalidate 使缓存失效
|
||||
Invalidate(ctx context.Context) error
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
NotifyUpdate(ctx context.Context) error
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
SubscribeUpdates(ctx context.Context, handler func())
|
||||
}
|
||||
|
||||
// ErrorPassthroughService 错误透传规则服务
|
||||
type ErrorPassthroughService struct {
|
||||
repo ErrorPassthroughRepository
|
||||
cache ErrorPassthroughCache
|
||||
|
||||
// 本地内存缓存,用于快速匹配
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localCacheMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughService 创建错误透传规则服务
|
||||
func NewErrorPassthroughService(
|
||||
repo ErrorPassthroughRepository,
|
||||
cache ErrorPassthroughCache,
|
||||
) *ErrorPassthroughService {
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
// 启动时加载规则到本地缓存
|
||||
ctx := context.Background()
|
||||
if err := svc.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
|
||||
}
|
||||
|
||||
// 订阅缓存更新通知
|
||||
if cache != nil {
|
||||
cache.SubscribeUpdates(ctx, func() {
|
||||
if err := svc.refreshLocalCache(context.Background()); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return svc
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.List(ctx)
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
created, err := s.repo.Create(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return created, nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
if err := rule.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
updated, err := s.repo.Update(ctx, rule)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.repo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 刷新缓存
|
||||
s.invalidateAndNotify(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchRule 匹配透传规则
|
||||
// 返回第一个匹配的规则,如果没有匹配则返回 nil
|
||||
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
|
||||
rules := s.getCachedRules()
|
||||
if len(rules) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
|
||||
for _, rule := range rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
if !s.platformMatches(rule, platform) {
|
||||
continue
|
||||
}
|
||||
if s.ruleMatches(rule, statusCode, bodyStr) {
|
||||
return rule
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getCachedRules 获取缓存的规则列表(按优先级排序)
|
||||
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
|
||||
s.localCacheMu.RLock()
|
||||
rules := s.localCache
|
||||
s.localCacheMu.RUnlock()
|
||||
|
||||
if rules != nil {
|
||||
return rules
|
||||
}
|
||||
|
||||
// 如果本地缓存为空,尝试刷新
|
||||
ctx := context.Background()
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.localCacheMu.RLock()
|
||||
defer s.localCacheMu.RUnlock()
|
||||
return s.localCache
|
||||
}
|
||||
|
||||
// refreshLocalCache 刷新本地缓存
|
||||
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
|
||||
// 先尝试从 Redis 缓存获取
|
||||
if s.cache != nil {
|
||||
if rules, ok := s.cache.Get(ctx); ok {
|
||||
s.setLocalCache(rules)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// 从数据库加载(repo.List 已按 priority 排序)
|
||||
rules, err := s.repo.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新 Redis 缓存
|
||||
if s.cache != nil {
|
||||
if err := s.cache.Set(ctx, rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新本地缓存(setLocalCache 内部会确保排序)
|
||||
s.setLocalCache(rules)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setLocalCache 设置本地缓存
|
||||
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
|
||||
// 按优先级排序
|
||||
sorted := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
copy(sorted, rules)
|
||||
sort.Slice(sorted, func(i, j int) bool {
|
||||
return sorted[i].Priority < sorted[j].Priority
|
||||
})
|
||||
|
||||
s.localCacheMu.Lock()
|
||||
s.localCache = sorted
|
||||
s.localCacheMu.Unlock()
|
||||
}
|
||||
|
||||
// invalidateAndNotify 使缓存失效并通知其他实例
|
||||
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
|
||||
// 刷新本地缓存
|
||||
if err := s.refreshLocalCache(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
|
||||
}
|
||||
|
||||
// 通知其他实例
|
||||
if s.cache != nil {
|
||||
if err := s.cache.NotifyUpdate(ctx); err != nil {
|
||||
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// platformMatches 检查平台是否匹配
|
||||
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
|
||||
// 如果没有配置平台限制,则匹配所有平台
|
||||
if len(rule.Platforms) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
platform = strings.ToLower(platform)
|
||||
for _, p := range rule.Platforms {
|
||||
if strings.ToLower(p) == platform {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// ruleMatches 检查规则是否匹配
|
||||
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
|
||||
hasErrorCodes := len(rule.ErrorCodes) > 0
|
||||
hasKeywords := len(rule.Keywords) > 0
|
||||
|
||||
// 如果没有配置任何条件,不匹配
|
||||
if !hasErrorCodes && !hasKeywords {
|
||||
return false
|
||||
}
|
||||
|
||||
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
|
||||
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
|
||||
|
||||
if rule.MatchMode == model.MatchModeAll {
|
||||
// "all" 模式:所有配置的条件都必须满足
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// "any" 模式:任一条件满足即可
|
||||
if hasErrorCodes && hasKeywords {
|
||||
return codeMatch || keywordMatch
|
||||
}
|
||||
return codeMatch && keywordMatch
|
||||
}
|
||||
|
||||
// containsInt 检查切片是否包含指定整数
|
||||
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
|
||||
for _, v := range slice {
|
||||
if v == val {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
|
||||
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
|
||||
for _, kw := range keywords {
|
||||
if strings.Contains(bodyLower, strings.ToLower(kw)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
755
backend/internal/service/error_passthrough_service_test.go
Normal file
755
backend/internal/service/error_passthrough_service_test.go
Normal file
@@ -0,0 +1,755 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// mockErrorPassthroughRepo 用于测试的 mock repository
|
||||
type mockErrorPassthroughRepo struct {
|
||||
rules []*model.ErrorPassthroughRule
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
return m.rules, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
for _, r := range m.rules {
|
||||
if r.ID == id {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
rule.ID = int64(len(m.rules) + 1)
|
||||
m.rules = append(m.rules, rule)
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == rule.ID {
|
||||
m.rules[i] = rule
|
||||
return rule, nil
|
||||
}
|
||||
}
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
|
||||
for i, r := range m.rules {
|
||||
if r.ID == id {
|
||||
m.rules = append(m.rules[:i], m.rules[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// newTestService 创建测试用的服务实例
|
||||
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
|
||||
repo := &mockErrorPassthroughRepo{rules: rules}
|
||||
svc := &ErrorPassthroughService{
|
||||
repo: repo,
|
||||
cache: nil, // 不使用缓存
|
||||
}
|
||||
// 直接设置本地缓存,避免调用 refreshLocalCache
|
||||
svc.setLocalCache(rules)
|
||||
return svc
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 ruleMatches 核心匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestRuleMatches_NoConditions(t *testing.T) {
|
||||
// 没有配置任何条件时,不应该匹配
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
|
||||
"没有配置条件时不应该匹配")
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"状态码匹配 422", 422, "any message", true},
|
||||
{"状态码匹配 400", 400, "any message", true},
|
||||
{"状态码不匹配 500", 500, "any message", false},
|
||||
{"状态码不匹配 429", 429, "any message", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{"context limit", "model not supported"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"关键词匹配 context limit", 500, "error: context limit reached", true},
|
||||
{"关键词匹配 model not supported", 400, "the model not supported here", true},
|
||||
{"关键词不匹配", 422, "some other error", false},
|
||||
// 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的
|
||||
// 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches
|
||||
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 模拟 MatchRule 的行为:先转换为小写
|
||||
bodyLower := strings.ToLower(tt.body)
|
||||
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
|
||||
// any 模式:错误码 OR 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: true,
|
||||
reason: "code matches, keyword doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: true,
|
||||
reason: "keyword matches, code doesn't - OR mode should match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
|
||||
// all 模式:错误码 AND 关键词
|
||||
svc := newTestService(nil)
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Enabled: true,
|
||||
ErrorCodes: []int{422, 400},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
body string
|
||||
expected bool
|
||||
reason string
|
||||
}{
|
||||
{
|
||||
name: "状态码和关键词都匹配",
|
||||
statusCode: 422,
|
||||
body: "context limit reached",
|
||||
expected: true,
|
||||
reason: "both match - AND mode should match",
|
||||
},
|
||||
{
|
||||
name: "只有状态码匹配",
|
||||
statusCode: 422,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "code matches but keyword doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "只有关键词匹配",
|
||||
statusCode: 500,
|
||||
body: "context limit exceeded",
|
||||
expected: false,
|
||||
reason: "keyword matches but code doesn't - AND mode should NOT match",
|
||||
},
|
||||
{
|
||||
name: "都不匹配",
|
||||
statusCode: 500,
|
||||
body: "some other error",
|
||||
expected: false,
|
||||
reason: "neither matches",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
|
||||
assert.Equal(t, tt.expected, result, tt.reason)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 platformMatches 平台匹配逻辑
|
||||
// =============================================================================
|
||||
|
||||
func TestPlatformMatches(t *testing.T) {
|
||||
svc := newTestService(nil)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rulePlatforms []string
|
||||
requestPlatform string
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "空平台列表匹配所有",
|
||||
rulePlatforms: []string{},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "nil平台列表匹配所有",
|
||||
rulePlatforms: nil,
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 anthropic",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "精确匹配 openai",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "openai",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "不匹配 gemini",
|
||||
rulePlatforms: []string{"anthropic", "openai"},
|
||||
requestPlatform: "gemini",
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "大小写不敏感",
|
||||
rulePlatforms: []string{"Anthropic", "OpenAI"},
|
||||
requestPlatform: "anthropic",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "匹配 antigravity",
|
||||
rulePlatforms: []string{"antigravity"},
|
||||
requestPlatform: "antigravity",
|
||||
expected: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
Platforms: tt.rulePlatforms,
|
||||
}
|
||||
result := svc.platformMatches(rule, tt.requestPlatform)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 MatchRule 完整匹配流程
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_Priority(t *testing.T) {
|
||||
// 测试规则按优先级排序,优先级小的先匹配
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Low Priority",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "High Priority",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
|
||||
assert.Equal(t, "High Priority", matched.Name)
|
||||
}
|
||||
|
||||
func TestMatchRule_DisabledRule(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Disabled Rule",
|
||||
Enabled: false,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "Enabled Rule",
|
||||
Enabled: true,
|
||||
Priority: 10,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
|
||||
}
|
||||
|
||||
func TestMatchRule_PlatformFilter(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Anthropic Only",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"anthropic"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
Name: "OpenAI Only",
|
||||
Enabled: true,
|
||||
Priority: 2,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{"openai"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
{
|
||||
ID: 3,
|
||||
Name: "All Platforms",
|
||||
Enabled: true,
|
||||
Priority: 3,
|
||||
ErrorCodes: []int{422},
|
||||
Platforms: []string{},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(1), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("openai", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(2), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("gemini", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
|
||||
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
|
||||
matched := svc.MatchRule("antigravity", 422, []byte("error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.Equal(t, int64(3), matched.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_NoMatch(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Rule for 422",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_EmptyRules(t *testing.T) {
|
||||
svc := newTestService([]*model.ErrorPassthroughRule{})
|
||||
matched := svc.MatchRule("anthropic", 422, []byte("error"))
|
||||
|
||||
assert.Nil(t, matched, "没有规则时应返回 nil")
|
||||
}
|
||||
|
||||
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
Keywords: []string{"Context Limit"},
|
||||
MatchMode: model.MatchModeAny,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected bool
|
||||
}{
|
||||
{"完全匹配", "Context Limit reached", true},
|
||||
{"小写匹配", "context limit reached", true},
|
||||
{"大写匹配", "CONTEXT LIMIT REACHED", true},
|
||||
{"混合大小写", "ConTeXt LiMiT error", true},
|
||||
{"不匹配", "some other error", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
|
||||
if tt.expected {
|
||||
assert.NotNil(t, matched)
|
||||
} else {
|
||||
assert.Nil(t, matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试真实场景
|
||||
// =============================================================================
|
||||
|
||||
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
|
||||
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Context Limit Passthrough",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{422},
|
||||
Keywords: []string{"context limit"},
|
||||
MatchMode: model.MatchModeAll, // 必须同时满足
|
||||
Platforms: []string{"anthropic", "antigravity"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
// 测试 Anthropic 平台
|
||||
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
assert.True(t, matched.PassthroughCode)
|
||||
assert.True(t, matched.PassthroughBody)
|
||||
})
|
||||
|
||||
// 测试 Antigravity 平台
|
||||
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("antigravity", 422, body)
|
||||
require.NotNil(t, matched)
|
||||
})
|
||||
|
||||
// 测试 OpenAI 平台(不在规则的平台列表中)
|
||||
t.Run("OpenAI should not match", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("openai", 422, body)
|
||||
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
|
||||
})
|
||||
|
||||
// 测试状态码不匹配
|
||||
t.Run("Wrong status code", func(t *testing.T) {
|
||||
body := []byte(`{"error":"context limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 400, body)
|
||||
assert.Nil(t, matched, "状态码不匹配")
|
||||
})
|
||||
|
||||
// 测试关键词不匹配
|
||||
t.Run("Wrong keyword", func(t *testing.T) {
|
||||
body := []byte(`{"error":"rate limit exceeded"}`)
|
||||
matched := svc.MatchRule("anthropic", 422, body)
|
||||
assert.Nil(t, matched, "关键词不匹配")
|
||||
})
|
||||
}
|
||||
|
||||
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
|
||||
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
|
||||
customMsg := "Service temporarily unavailable, please try again later"
|
||||
responseCode := 503
|
||||
rules := []*model.ErrorPassthroughRule{
|
||||
{
|
||||
ID: 1,
|
||||
Name: "Hide Internal Errors",
|
||||
Enabled: true,
|
||||
Priority: 1,
|
||||
ErrorCodes: []int{500, 502, 503},
|
||||
MatchMode: model.MatchModeAny,
|
||||
PassthroughCode: false,
|
||||
ResponseCode: &responseCode,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: &customMsg,
|
||||
},
|
||||
}
|
||||
|
||||
svc := newTestService(rules)
|
||||
|
||||
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
|
||||
require.NotNil(t, matched)
|
||||
assert.False(t, matched.PassthroughCode)
|
||||
assert.Equal(t, 503, *matched.ResponseCode)
|
||||
assert.False(t, matched.PassthroughBody)
|
||||
assert.Equal(t, customMsg, *matched.CustomMessage)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// 测试 model.Validate
|
||||
// =============================================================================
|
||||
|
||||
func TestErrorPassthroughRule_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rule *model.ErrorPassthroughRule
|
||||
expectError bool
|
||||
errorField string
|
||||
}{
|
||||
{
|
||||
name: "有效规则 - 透传模式(含错误码)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 透传模式(含关键词)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAny,
|
||||
Keywords: []string{"context limit"},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "有效规则 - 自定义响应",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Valid Rule",
|
||||
MatchMode: model.MatchModeAll,
|
||||
ErrorCodes: []int{500},
|
||||
Keywords: []string{"internal error"},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: testIntPtr(503),
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr("Custom error"),
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "缺少名称",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "name",
|
||||
},
|
||||
{
|
||||
name: "无效的匹配模式",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Invalid Mode",
|
||||
MatchMode: "invalid",
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "match_mode",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(错误码和关键词都为空)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "No Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{},
|
||||
Keywords: []string{},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "缺少匹配条件(nil切片)",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Nil Conditions",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: nil,
|
||||
Keywords: nil,
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "conditions",
|
||||
},
|
||||
{
|
||||
name: "自定义状态码但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Code",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: false,
|
||||
ResponseCode: nil,
|
||||
PassthroughBody: true,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "response_code",
|
||||
},
|
||||
{
|
||||
name: "自定义消息但未提供值",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Missing Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: nil,
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
{
|
||||
name: "自定义消息为空字符串",
|
||||
rule: &model.ErrorPassthroughRule{
|
||||
Name: "Empty Message",
|
||||
MatchMode: model.MatchModeAny,
|
||||
ErrorCodes: []int{422},
|
||||
PassthroughCode: true,
|
||||
PassthroughBody: false,
|
||||
CustomMessage: testStrPtr(""),
|
||||
},
|
||||
expectError: true,
|
||||
errorField: "custom_message",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.rule.Validate()
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
validationErr, ok := err.(*model.ValidationError)
|
||||
require.True(t, ok, "应该返回 ValidationError")
|
||||
assert.Equal(t, tt.errorField, validationErr.Field)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions
|
||||
func testIntPtr(i int) *int { return &i }
|
||||
func testStrPtr(s string) *string { return &s }
|
||||
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
288
backend/internal/service/gateway_cached_tokens_test.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ---------- reconcileCachedTokens 单元测试 ----------
|
||||
|
||||
func TestReconcileCachedTokens_NilUsage(t *testing.T) {
|
||||
assert.False(t, reconcileCachedTokens(nil))
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_AlreadyHasCacheRead(t *testing.T) {
|
||||
// 已有标准字段,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(100),
|
||||
"cached_tokens": float64(50),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(100), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_KimiStyle(t *testing.T) {
|
||||
// Kimi 风格:cache_read_input_tokens=0,cached_tokens>0
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(23),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(23),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_NoCachedTokens(t *testing.T) {
|
||||
// 无 cached_tokens 字段(原生 Claude)
|
||||
usage := map[string]any{
|
||||
"input_tokens": float64(100),
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cache_creation_input_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_CachedTokensZero(t *testing.T) {
|
||||
// cached_tokens 为 0,不应覆盖
|
||||
usage := map[string]any{
|
||||
"cache_read_input_tokens": float64(0),
|
||||
"cached_tokens": float64(0),
|
||||
}
|
||||
assert.False(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(0), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestReconcileCachedTokens_MissingCacheReadField(t *testing.T) {
|
||||
// cache_read_input_tokens 字段完全不存在,cached_tokens > 0
|
||||
usage := map[string]any{
|
||||
"cached_tokens": float64(42),
|
||||
}
|
||||
assert.True(t, reconcileCachedTokens(usage))
|
||||
assert.Equal(t, float64(42), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_start 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageStart(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_start SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_start", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 cache_read_input_tokens 已被填充
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(23), usage["cache_read_input_tokens"])
|
||||
|
||||
// 验证重新序列化后 JSON 也包含正确值
|
||||
data, err := json.Marshal(event)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(data, "message.usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageStart_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 不返回 cached_tokens,reconcile 不应改变任何值
|
||||
eventJSON := `{
|
||||
"type": "message_start",
|
||||
"message": {
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"cache_creation_input_tokens": 50,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
msg, ok := event["message"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
usage, ok := msg["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, float64(30), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
// ---------- 流式 message_delta 事件 reconcile 测试 ----------
|
||||
|
||||
func TestStreamingReconcile_MessageDelta(t *testing.T) {
|
||||
// 模拟 Kimi 返回的 message_delta SSE 事件
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 7,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 15
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
eventType, _ := event["type"].(string)
|
||||
require.Equal(t, "message_delta", eventType)
|
||||
|
||||
// 模拟 processSSEEvent 中的 reconcile 逻辑
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
assert.Equal(t, float64(15), usage["cache_read_input_tokens"])
|
||||
}
|
||||
|
||||
func TestStreamingReconcile_MessageDelta_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 的 message_delta 通常没有 cached_tokens
|
||||
eventJSON := `{
|
||||
"type": "message_delta",
|
||||
"usage": {
|
||||
"output_tokens": 50
|
||||
}
|
||||
}`
|
||||
|
||||
var event map[string]any
|
||||
require.NoError(t, json.Unmarshal([]byte(eventJSON), &event))
|
||||
|
||||
usage, ok := event["usage"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
reconcileCachedTokens(usage)
|
||||
_, hasCacheRead := usage["cache_read_input_tokens"]
|
||||
assert.False(t, hasCacheRead, "不应为原生 Claude 响应注入 cache_read_input_tokens")
|
||||
}
|
||||
|
||||
// ---------- 非流式响应 reconcile 测试 ----------
|
||||
|
||||
func TestNonStreamingReconcile_KimiResponse(t *testing.T) {
|
||||
// 模拟 Kimi 非流式响应
|
||||
body := []byte(`{
|
||||
"id": "msg_123",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "hello"}],
|
||||
"model": "kimi",
|
||||
"usage": {
|
||||
"input_tokens": 23,
|
||||
"output_tokens": 7,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0,
|
||||
"cached_tokens": 23,
|
||||
"prompt_tokens": 23,
|
||||
"completion_tokens": 7
|
||||
}
|
||||
}`)
|
||||
|
||||
// 模拟 handleNonStreamingResponse 中的逻辑
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// reconcile
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证内部 usage(计费用)
|
||||
assert.Equal(t, 23, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 23, response.Usage.InputTokens)
|
||||
assert.Equal(t, 7, response.Usage.OutputTokens)
|
||||
|
||||
// 验证返回给客户端的 JSON body
|
||||
assert.Equal(t, int64(23), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NativeClaude(t *testing.T) {
|
||||
// 原生 Claude 响应:cache_read_input_tokens 已有值
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 20,
|
||||
"cache_read_input_tokens": 30
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
// CacheReadInputTokens == 30,条件不成立,整个 reconcile 分支不会执行
|
||||
assert.NotZero(t, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, 30, response.Usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestNonStreamingReconcile_NoCachedTokens(t *testing.T) {
|
||||
// 没有 cached_tokens 字段
|
||||
body := []byte(`{
|
||||
"usage": {
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"cache_creation_input_tokens": 0,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}`)
|
||||
|
||||
var response struct {
|
||||
Usage ClaudeUsage `json:"usage"`
|
||||
}
|
||||
require.NoError(t, json.Unmarshal(body, &response))
|
||||
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cache_read_input_tokens 应保持为 0
|
||||
assert.Equal(t, 0, response.Usage.CacheReadInputTokens)
|
||||
assert.Equal(t, int64(0), gjson.GetBytes(body, "usage.cache_read_input_tokens").Int())
|
||||
}
|
||||
@@ -370,7 +370,8 @@ type ForwardResult struct {
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
type UpstreamFailoverError struct {
|
||||
StatusCode int
|
||||
StatusCode int
|
||||
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
|
||||
}
|
||||
|
||||
func (e *UpstreamFailoverError) Error() string {
|
||||
@@ -384,6 +385,7 @@ type GatewayService struct {
|
||||
usageLogRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache GatewayCache
|
||||
cfg *config.Config
|
||||
schedulerSnapshot *SchedulerSnapshotService
|
||||
@@ -405,6 +407,7 @@ func NewGatewayService(
|
||||
usageLogRepo UsageLogRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache GatewayCache,
|
||||
cfg *config.Config,
|
||||
schedulerSnapshot *SchedulerSnapshotService,
|
||||
@@ -424,6 +427,7 @@ func NewGatewayService(
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
schedulerSnapshot: schedulerSnapshot,
|
||||
@@ -3281,7 +3285,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleRetryExhaustedError(ctx, resp, c, account)
|
||||
}
|
||||
@@ -3311,10 +3315,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return ""
|
||||
}(),
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
|
||||
// 处理错误响应(不可重试的错误)
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
@@ -3358,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
log.Printf("Account %d: 400 error, attempting failover", account.ID)
|
||||
}
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -3755,6 +3757,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
|
||||
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
func ExtractUpstreamErrorMessage(body []byte) string {
|
||||
return extractUpstreamErrorMessage(body)
|
||||
}
|
||||
|
||||
func extractUpstreamErrorMessage(body []byte) string {
|
||||
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
|
||||
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
|
||||
@@ -3822,7 +3830,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
|
||||
@@ -4168,6 +4176,20 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
|
||||
eventName = eventType
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if eventType == "message_start" {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if u, ok := msg["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
}
|
||||
if eventType == "message_delta" {
|
||||
if u, ok := event["usage"].(map[string]any); ok {
|
||||
reconcileCachedTokens(u)
|
||||
}
|
||||
}
|
||||
|
||||
if needModelReplace {
|
||||
if msg, ok := event["message"].(map[string]any); ok {
|
||||
if model, ok := msg["model"].(string); ok && model == mappedModel {
|
||||
@@ -4518,6 +4540,17 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
return nil, fmt.Errorf("parse response: %w", err)
|
||||
}
|
||||
|
||||
// 兼容 Kimi cached_tokens → cache_read_input_tokens
|
||||
if response.Usage.CacheReadInputTokens == 0 {
|
||||
cachedTokens := gjson.GetBytes(body, "usage.cached_tokens").Int()
|
||||
if cachedTokens > 0 {
|
||||
response.Usage.CacheReadInputTokens = int(cachedTokens)
|
||||
if newBody, err := sjson.SetBytes(body, "usage.cache_read_input_tokens", cachedTokens); err == nil {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果有模型映射,替换响应中的model字段
|
||||
if originalModel != mappedModel {
|
||||
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
|
||||
@@ -4609,10 +4642,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -4773,10 +4813,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
|
||||
// 获取费率倍数
|
||||
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if apiKey.GroupID != nil && apiKey.Group != nil {
|
||||
multiplier = apiKey.Group.RateMultiplier
|
||||
|
||||
// 检查用户专属倍率
|
||||
if s.userGroupRateRepo != nil {
|
||||
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
|
||||
multiplier = *userRate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cost *CostBreakdown
|
||||
@@ -5289,3 +5336,21 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64,
|
||||
|
||||
return models
|
||||
}
|
||||
|
||||
// reconcileCachedTokens 兼容 Kimi 等上游:
|
||||
// 将 OpenAI 风格的 cached_tokens 映射到 Claude 标准的 cache_read_input_tokens
|
||||
func reconcileCachedTokens(usage map[string]any) bool {
|
||||
if usage == nil {
|
||||
return false
|
||||
}
|
||||
cacheRead, _ := usage["cache_read_input_tokens"].(float64)
|
||||
if cacheRead > 0 {
|
||||
return false // 已有标准字段,无需处理
|
||||
}
|
||||
cached, _ := usage["cached_tokens"].(float64)
|
||||
if cached <= 0 {
|
||||
return false
|
||||
}
|
||||
usage["cache_read_input_tokens"] = cached
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
upstreamReqID := resp.Header.Get(requestIDHeader)
|
||||
if upstreamReqID == "" {
|
||||
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
evBody := unwrapIfNeeded(isOAuth, respBody)
|
||||
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
Message: upstreamMsg,
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
|
||||
}
|
||||
|
||||
respBody = unwrapIfNeeded(isOAuth, respBody)
|
||||
|
||||
@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
|
||||
}
|
||||
|
||||
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
|
||||
// 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时:
|
||||
// - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT)
|
||||
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
|
||||
if loadResp != nil {
|
||||
registeredTierID := strings.TrimSpace(loadResp.GetTier())
|
||||
if registeredTierID != "" {
|
||||
// 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。
|
||||
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
|
||||
|
||||
// Try to get project from Cloud Resource Manager
|
||||
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
|
||||
if fbErr == nil && strings.TrimSpace(fallback) != "" {
|
||||
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
|
||||
return strings.TrimSpace(fallback), tierID, nil
|
||||
}
|
||||
|
||||
// No project found - user must provide project_id manually
|
||||
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
|
||||
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
|
||||
}
|
||||
}
|
||||
|
||||
// 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser
|
||||
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
|
||||
|
||||
req := &geminicli.OnboardUserRequest{
|
||||
TierID: tierID,
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
|
||||
@@ -21,6 +21,17 @@ const (
|
||||
var codexCLIInstructions string
|
||||
|
||||
var codexModelMap = map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-none": "gpt-5.3",
|
||||
"gpt-5.3-low": "gpt-5.3",
|
||||
"gpt-5.3-medium": "gpt-5.3",
|
||||
"gpt-5.3-high": "gpt-5.3",
|
||||
"gpt-5.3-xhigh": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-low": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-medium": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-high": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt-5.1-codex": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-low": "gpt-5.1-codex",
|
||||
"gpt-5.1-codex-medium": "gpt-5.1-codex",
|
||||
@@ -156,6 +167,12 @@ func normalizeCodexModel(model string) string {
|
||||
if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") {
|
||||
return "gpt-5.2"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3-codex") || strings.Contains(normalized, "gpt 5.3 codex") {
|
||||
return "gpt-5.3-codex"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") {
|
||||
return "gpt-5.3"
|
||||
}
|
||||
if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") {
|
||||
return "gpt-5.1-codex-max"
|
||||
}
|
||||
|
||||
@@ -176,6 +176,19 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
require.Len(t, input, 0)
|
||||
}
|
||||
|
||||
func TestNormalizeCodexModel_Gpt53(t *testing.T) {
|
||||
cases := map[string]string{
|
||||
"gpt-5.3": "gpt-5.3",
|
||||
"gpt-5.3-codex": "gpt-5.3-codex",
|
||||
"gpt-5.3-codex-xhigh": "gpt-5.3-codex",
|
||||
"gpt 5.3 codex": "gpt-5.3-codex",
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
require.Equal(t, expected, normalizeCodexModel(input))
|
||||
}
|
||||
}
|
||||
|
||||
func setupCodexCache(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
|
||||
@@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
})
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||
}
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
}
|
||||
@@ -1131,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
Detail: upstreamDetail,
|
||||
})
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
|
||||
@@ -579,6 +579,7 @@ func (s *PricingService) extractBaseName(model string) string {
|
||||
func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// Claude模型系列匹配规则
|
||||
familyPatterns := map[string][]string{
|
||||
"opus-4.6": {"claude-opus-4.6", "claude-opus-4-6"},
|
||||
"opus-4.5": {"claude-opus-4.5", "claude-opus-4-5"},
|
||||
"opus-4": {"claude-opus-4", "claude-3-opus"},
|
||||
"sonnet-4.5": {"claude-sonnet-4.5", "claude-sonnet-4-5"},
|
||||
@@ -651,7 +652,8 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing {
|
||||
// 回退顺序:
|
||||
// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等)
|
||||
// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号)
|
||||
// 3. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
// 3. gpt-5.3-codex -> gpt-5.2-codex
|
||||
// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex)
|
||||
func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
// 尝试的回退变体
|
||||
variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern)
|
||||
@@ -663,6 +665,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing {
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(model, "gpt-5.3-codex") {
|
||||
if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok {
|
||||
log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")
|
||||
return pricing
|
||||
}
|
||||
}
|
||||
|
||||
// 最终回退到 DefaultTestModel
|
||||
defaultModel := strings.ToLower(openai.DefaultTestModel)
|
||||
if pricing, ok := s.pricingData[defaultModel]; ok {
|
||||
|
||||
@@ -21,6 +21,10 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// GroupRates 用户专属分组倍率配置
|
||||
// map[groupID]rateMultiplier
|
||||
GroupRates map[int64]float64
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
|
||||
TotpEnabled bool // 是否启用 TOTP
|
||||
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
|
||||
|
||||
// CanBindGroup checks whether a user can bind to a given group.
|
||||
// For standard groups:
|
||||
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
|
||||
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
|
||||
// - Public groups (non-exclusive): all users can bind
|
||||
// - Exclusive groups: only users with the group in AllowedGroups can bind
|
||||
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
// 公开分组(非专属):所有用户都可以绑定
|
||||
if !isExclusive {
|
||||
return true
|
||||
}
|
||||
return !isExclusive
|
||||
// 专属分组:需要在 AllowedGroups 中
|
||||
for _, id := range u.AllowedGroups {
|
||||
if id == groupID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (u *User) SetPassword(password string) error {
|
||||
|
||||
25
backend/internal/service/user_group_rate.go
Normal file
25
backend/internal/service/user_group_rate.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// UserGroupRateRepository 用户专属分组倍率仓储接口
|
||||
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
|
||||
type UserGroupRateRepository interface {
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
// 如果未设置专属倍率,返回 nil
|
||||
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
// rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率
|
||||
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
|
||||
DeleteByGroupID(ctx context.Context, groupID int64) error
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
|
||||
DeleteByUserID(ctx context.Context, userID int64) error
|
||||
}
|
||||
@@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
NewTotpService,
|
||||
NewErrorPassthroughService,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user