This commit is contained in:
yangjianbo
2026-02-06 06:56:23 +08:00
94 changed files with 10861 additions and 858 deletions

View File

@@ -41,6 +41,7 @@ type UsageLogRepository interface {
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error)
GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error)
GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error)

View File

@@ -93,6 +93,9 @@ type UpdateUserInput struct {
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Status string
AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组"
// GroupRates 用户专属分组倍率配置
// map[groupID]*ratenil 表示删除该分组的专属倍率
GroupRates map[int64]*float64
}
type CreateGroupInput struct {
@@ -304,6 +307,7 @@ type adminServiceImpl struct {
proxyRepo ProxyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
userGroupRateRepo UserGroupRateRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
@@ -319,6 +323,7 @@ func NewAdminService(
proxyRepo ProxyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
userGroupRateRepo UserGroupRateRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
@@ -332,6 +337,7 @@ func NewAdminService(
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
userGroupRateRepo: userGroupRateRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
@@ -346,11 +352,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi
if err != nil {
return nil, 0, err
}
// 批量加载用户专属分组倍率
if s.userGroupRateRepo != nil && len(users) > 0 {
for i := range users {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err)
continue
}
users[i].GroupRates = rates
}
}
return users, result.Total, nil
}
func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) {
return s.userRepo.GetByID(ctx, id)
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 加载用户专属分组倍率
if s.userGroupRateRepo != nil {
rates, err := s.userGroupRateRepo.GetByUserID(ctx, id)
if err != nil {
log.Printf("failed to load user group rates: user_id=%d err=%v", id, err)
} else {
user.GroupRates = rates
}
}
return user, nil
}
func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) {
@@ -419,6 +449,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if err := s.userRepo.Update(ctx, user); err != nil {
return nil, err
}
// 同步用户专属分组倍率
if input.GroupRates != nil && s.userGroupRateRepo != nil {
if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil {
log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err)
}
}
if s.authCacheInvalidator != nil {
if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole {
s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID)
@@ -974,6 +1012,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
if err != nil {
return err
}
// 注意user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理
// 事务成功后,异步失效受影响用户的订阅缓存
if len(affectedUserIDs) > 0 && s.billingCacheService != nil {

View File

@@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
@@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
contentType := resp.Header.Get("Content-Type")
// 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body因此用内存副本重新包装。
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
@@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps}
}
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
contentType = "application/json"
}

View File

@@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct {
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
@@ -132,16 +133,18 @@ func NewAPIKeyService(
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
@@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword
return keys, nil
}
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {

View File

@@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
@@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
@@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
@@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
@@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
@@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
@@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg)
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}

View File

@@ -3,6 +3,7 @@ package service
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
@@ -25,8 +26,12 @@ var (
ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved")
ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token")
ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired")
ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired")
ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large")
ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked")
ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token")
ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired")
ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused")
ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required")
ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable")
@@ -37,6 +42,9 @@ var (
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
const maxTokenLength = 8192
// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens.
const refreshTokenPrefix = "rt_"
// JWTClaims JWT载荷数据
type JWTClaims struct {
UserID int64 `json:"user_id"`
@@ -50,6 +58,7 @@ type JWTClaims struct {
type AuthService struct {
userRepo UserRepository
redeemRepo RedeemCodeRepository
refreshTokenCache RefreshTokenCache
cfg *config.Config
settingService *SettingService
emailService *EmailService
@@ -62,6 +71,7 @@ type AuthService struct {
func NewAuthService(
userRepo UserRepository,
redeemRepo RedeemCodeRepository,
refreshTokenCache RefreshTokenCache,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,
@@ -72,6 +82,7 @@ func NewAuthService(
return &AuthService{
userRepo: userRepo,
redeemRepo: redeemRepo,
refreshTokenCache: refreshTokenCache,
cfg: cfg,
settingService: settingService,
emailService: emailService,
@@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
return token, user, nil
}
// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair
// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token
func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, nil, errors.New("refresh token cache not configured")
}
email = strings.TrimSpace(email)
if email == "" || len(email) > 255 {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
if _, err := mail.ParseAddress(email); err != nil {
return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email")
}
username = strings.TrimSpace(username)
if len([]rune(username)) > 100 {
username = string([]rune(username)[:100])
}
user, err := s.userRepo.GetByEmail(ctx, email)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// OAuth 首次登录视为注册
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
return nil, nil, ErrRegDisabled
}
randomPassword, err := randomHexString(32)
if err != nil {
log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err)
return nil, nil, ErrServiceUnavailable
}
hashedPassword, err := s.HashPassword(randomPassword)
if err != nil {
return nil, nil, fmt.Errorf("hash password: %w", err)
}
defaultBalance := s.cfg.Default.UserBalance
defaultConcurrency := s.cfg.Default.UserConcurrency
if s.settingService != nil {
defaultBalance = s.settingService.GetDefaultBalance(ctx)
defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx)
}
newUser := &User{
Email: email,
Username: username,
PasswordHash: hashedPassword,
Role: RoleUser,
Balance: defaultBalance,
Concurrency: defaultConcurrency,
Status: StatusActive,
}
if err := s.userRepo.Create(ctx, newUser); err != nil {
if errors.Is(err, ErrEmailExists) {
user, err = s.userRepo.GetByEmail(ctx, email)
if err != nil {
log.Printf("[Auth] Database error getting user after conflict: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
log.Printf("[Auth] Database error creating oauth user: %v", err)
return nil, nil, ErrServiceUnavailable
}
} else {
user = newUser
}
} else {
log.Printf("[Auth] Database error during oauth login: %v", err)
return nil, nil, ErrServiceUnavailable
}
}
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
if user.Username == "" && username != "" {
user.Username = username
if err := s.userRepo.Update(ctx, user); err != nil {
log.Printf("[Auth] Failed to update username after oauth login: %v", err)
}
}
tokenPair, err := s.GenerateTokenPair(ctx, user, "")
if err != nil {
return nil, nil, fmt.Errorf("generate token pair: %w", err)
}
return tokenPair, user, nil
}
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token降低 DoS 风险。
@@ -539,10 +644,17 @@ func isReservedEmail(email string) bool {
return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain)
}
// GenerateToken 生成JWT token
// GenerateToken 生成JWT access token
// 使用新的access_token_expire_minutes配置项如果配置了否则回退到expire_hour
func (s *AuthService) GenerateToken(user *User) (string, error) {
now := time.Now()
expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
var expiresAt time.Time
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute)
} else {
// 向后兼容使用旧的expire_hour配置
expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour)
}
claims := &JWTClaims{
UserID: user.ID,
@@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) {
return tokenString, nil
}
// GetAccessTokenExpiresIn 返回Access Token的有效期
// 用于前端设置刷新定时器
func (s *AuthService) GetAccessTokenExpiresIn() int {
if s.cfg.JWT.AccessTokenExpireMinutes > 0 {
return s.cfg.JWT.AccessTokenExpireMinutes * 60
}
return s.cfg.JWT.ExpireHour * 3600
}
// HashPassword 使用bcrypt加密密码
func (s *AuthService) HashPassword(password string) (string, error) {
hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
@@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo
return ErrServiceUnavailable
}
// Also revoke all refresh tokens for this user
if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil {
log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err)
// Don't return error - password was already changed successfully
}
log.Printf("[Auth] Password reset successful for user: %s", email)
return nil
}
// ==================== Refresh Token Methods ====================
// TokenPair 包含Access Token和Refresh Token
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"` // Access Token有效期
}
// GenerateTokenPair 生成Access Token和Refresh Token对
// familyID: 可选的Token家族ID用于Token轮转时保持家族关系
func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, errors.New("refresh token cache not configured")
}
// 生成Access Token
accessToken, err := s.GenerateToken(user)
if err != nil {
return nil, fmt.Errorf("generate access token: %w", err)
}
// 生成Refresh Token
refreshToken, err := s.generateRefreshToken(ctx, user, familyID)
if err != nil {
return nil, fmt.Errorf("generate refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessToken,
RefreshToken: refreshToken,
ExpiresIn: s.GetAccessTokenExpiresIn(),
}, nil
}
// generateRefreshToken 生成并存储Refresh Token
func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) {
// 生成随机Token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes)
// 计算Token哈希存储哈希而非原始Token
tokenHash := hashToken(rawToken)
// 如果没有提供familyID生成新的
if familyID == "" {
familyBytes := make([]byte, 16)
if _, err := rand.Read(familyBytes); err != nil {
return "", fmt.Errorf("generate family id: %w", err)
}
familyID = hex.EncodeToString(familyBytes)
}
now := time.Now()
ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour
data := &RefreshTokenData{
UserID: user.ID,
TokenVersion: user.TokenVersion,
FamilyID: familyID,
CreatedAt: now,
ExpiresAt: now.Add(ttl),
}
// 存储Token数据
if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil {
return "", fmt.Errorf("store refresh token: %w", err)
}
// 添加到用户Token集合
if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to user set: %v", err)
// 不影响主流程
}
// 添加到家族Token集合
if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil {
log.Printf("[Auth] Failed to add token to family set: %v", err)
// 不影响主流程
}
return rawToken, nil
}
// RefreshTokenPair 使用Refresh Token刷新Token对
// 实现Token轮转每次刷新都会生成新的Refresh Token旧Token立即失效
func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) {
// 检查 refreshTokenCache 是否可用
if s.refreshTokenCache == nil {
return nil, ErrRefreshTokenInvalid
}
// 验证Token格式
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return nil, ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
// 获取Token数据
data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash)
if err != nil {
if errors.Is(err, ErrRefreshTokenNotFound) {
// Token不存在可能是已被使用Token轮转或已过期
log.Printf("[Auth] Refresh token not found, possible reuse attack")
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Error getting refresh token: %v", err)
return nil, ErrServiceUnavailable
}
// 检查Token是否过期
if time.Now().After(data.ExpiresAt) {
// 删除过期Token
_ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
return nil, ErrRefreshTokenExpired
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, data.UserID)
if err != nil {
if errors.Is(err, ErrUserNotFound) {
// 用户已删除撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrRefreshTokenInvalid
}
log.Printf("[Auth] Database error getting user for token refresh: %v", err)
return nil, ErrServiceUnavailable
}
// 检查用户状态
if !user.IsActive() {
// 用户被禁用撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrUserNotActive
}
// 检查TokenVersion密码更改后所有Token失效
if data.TokenVersion != user.TokenVersion {
// TokenVersion不匹配撤销整个Token家族
_ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID)
return nil, ErrTokenRevoked
}
// Token轮转立即使旧Token失效
if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil {
log.Printf("[Auth] Failed to delete old refresh token: %v", err)
// 继续处理,不影响主流程
}
// 生成新的Token对保持同一个家族ID
return s.GenerateTokenPair(ctx, user, data.FamilyID)
}
// RevokeRefreshToken 撤销单个Refresh Token
func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
if !strings.HasPrefix(refreshToken, refreshTokenPrefix) {
return ErrRefreshTokenInvalid
}
tokenHash := hashToken(refreshToken)
return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash)
}
// RevokeAllUserSessions 撤销用户的所有会话所有Refresh Token
// 用于密码更改或用户主动登出所有设备
func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error {
if s.refreshTokenCache == nil {
return nil // No-op if cache not configured
}
return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID)
}
// hashToken 计算Token的SHA256哈希
func hashToken(token string) string {
hash := sha256.Sum256([]byte(token))
return hex.EncodeToString(hash[:])
}

View File

@@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
return NewAuthService(
repo,
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
emailService,

View File

@@ -0,0 +1,300 @@
package service
import (
"context"
"log"
"sort"
"strings"
"sync"
"github.com/Wei-Shaw/sub2api/internal/model"
)
// ErrorPassthroughRepository 定义错误透传规则的数据访问接口
type ErrorPassthroughRepository interface {
// List 获取所有规则
List(ctx context.Context) ([]*model.ErrorPassthroughRule, error)
// GetByID 根据 ID 获取规则
GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error)
// Create 创建规则
Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Update 更新规则
Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error)
// Delete 删除规则
Delete(ctx context.Context, id int64) error
}
// ErrorPassthroughCache 定义错误透传规则的缓存接口
type ErrorPassthroughCache interface {
// Get 从缓存获取规则列表
Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool)
// Set 设置缓存
Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error
// Invalidate 使缓存失效
Invalidate(ctx context.Context) error
// NotifyUpdate 通知其他实例刷新缓存
NotifyUpdate(ctx context.Context) error
// SubscribeUpdates 订阅缓存更新通知
SubscribeUpdates(ctx context.Context, handler func())
}
// ErrorPassthroughService 错误透传规则服务
type ErrorPassthroughService struct {
repo ErrorPassthroughRepository
cache ErrorPassthroughCache
// 本地内存缓存,用于快速匹配
localCache []*model.ErrorPassthroughRule
localCacheMu sync.RWMutex
}
// NewErrorPassthroughService 创建错误透传规则服务
func NewErrorPassthroughService(
repo ErrorPassthroughRepository,
cache ErrorPassthroughCache,
) *ErrorPassthroughService {
svc := &ErrorPassthroughService{
repo: repo,
cache: cache,
}
// 启动时加载规则到本地缓存
ctx := context.Background()
if err := svc.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err)
}
// 订阅缓存更新通知
if cache != nil {
cache.SubscribeUpdates(ctx, func() {
if err := svc.refreshLocalCache(context.Background()); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err)
}
})
}
return svc
}
// List 获取所有规则
func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return s.repo.List(ctx)
}
// GetByID 根据 ID 获取规则
func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
return s.repo.GetByID(ctx, id)
}
// Create 创建规则
func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
created, err := s.repo.Create(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return created, nil
}
// Update 更新规则
func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
if err := rule.Validate(); err != nil {
return nil, err
}
updated, err := s.repo.Update(ctx, rule)
if err != nil {
return nil, err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return updated, nil
}
// Delete 删除规则
func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error {
if err := s.repo.Delete(ctx, id); err != nil {
return err
}
// 刷新缓存
s.invalidateAndNotify(ctx)
return nil
}
// MatchRule 匹配透传规则
// 返回第一个匹配的规则,如果没有匹配则返回 nil
func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule {
rules := s.getCachedRules()
if len(rules) == 0 {
return nil
}
bodyStr := strings.ToLower(string(body))
for _, rule := range rules {
if !rule.Enabled {
continue
}
if !s.platformMatches(rule, platform) {
continue
}
if s.ruleMatches(rule, statusCode, bodyStr) {
return rule
}
}
return nil
}
// getCachedRules 获取缓存的规则列表(按优先级排序)
func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule {
s.localCacheMu.RLock()
rules := s.localCache
s.localCacheMu.RUnlock()
if rules != nil {
return rules
}
// 如果本地缓存为空,尝试刷新
ctx := context.Background()
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err)
return nil
}
s.localCacheMu.RLock()
defer s.localCacheMu.RUnlock()
return s.localCache
}
// refreshLocalCache 刷新本地缓存
func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error {
// 先尝试从 Redis 缓存获取
if s.cache != nil {
if rules, ok := s.cache.Get(ctx); ok {
s.setLocalCache(rules)
return nil
}
}
// 从数据库加载repo.List 已按 priority 排序)
rules, err := s.repo.List(ctx)
if err != nil {
return err
}
// 更新 Redis 缓存
if s.cache != nil {
if err := s.cache.Set(ctx, rules); err != nil {
log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err)
}
}
// 更新本地缓存setLocalCache 内部会确保排序)
s.setLocalCache(rules)
return nil
}
// setLocalCache 设置本地缓存
func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) {
// 按优先级排序
sorted := make([]*model.ErrorPassthroughRule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
s.localCacheMu.Lock()
s.localCache = sorted
s.localCacheMu.Unlock()
}
// invalidateAndNotify 使缓存失效并通知其他实例
func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) {
// 刷新本地缓存
if err := s.refreshLocalCache(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err)
}
// 通知其他实例
if s.cache != nil {
if err := s.cache.NotifyUpdate(ctx); err != nil {
log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err)
}
}
}
// platformMatches 检查平台是否匹配
func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool {
// 如果没有配置平台限制,则匹配所有平台
if len(rule.Platforms) == 0 {
return true
}
platform = strings.ToLower(platform)
for _, p := range rule.Platforms {
if strings.ToLower(p) == platform {
return true
}
}
return false
}
// ruleMatches 检查规则是否匹配
func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool {
hasErrorCodes := len(rule.ErrorCodes) > 0
hasKeywords := len(rule.Keywords) > 0
// 如果没有配置任何条件,不匹配
if !hasErrorCodes && !hasKeywords {
return false
}
codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode)
keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords)
if rule.MatchMode == model.MatchModeAll {
// "all" 模式:所有配置的条件都必须满足
return codeMatch && keywordMatch
}
// "any" 模式:任一条件满足即可
if hasErrorCodes && hasKeywords {
return codeMatch || keywordMatch
}
return codeMatch && keywordMatch
}
// containsInt 检查切片是否包含指定整数
func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool {
for _, v := range slice {
if v == val {
return true
}
}
return false
}
// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写)
func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool {
for _, kw := range keywords {
if strings.Contains(bodyLower, strings.ToLower(kw)) {
return true
}
}
return false
}

View File

@@ -0,0 +1,755 @@
//go:build unit
package service
import (
"context"
"strings"
"testing"
"github.com/Wei-Shaw/sub2api/internal/model"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// mockErrorPassthroughRepo 用于测试的 mock repository
type mockErrorPassthroughRepo struct {
rules []*model.ErrorPassthroughRule
}
func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
return m.rules, nil
}
func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
for _, r := range m.rules {
if r.ID == id {
return r, nil
}
}
return nil, nil
}
func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
rule.ID = int64(len(m.rules) + 1)
m.rules = append(m.rules, rule)
return rule, nil
}
func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
for i, r := range m.rules {
if r.ID == rule.ID {
m.rules[i] = rule
return rule, nil
}
}
return rule, nil
}
func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error {
for i, r := range m.rules {
if r.ID == id {
m.rules = append(m.rules[:i], m.rules[i+1:]...)
return nil
}
}
return nil
}
// newTestService 创建测试用的服务实例
func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService {
repo := &mockErrorPassthroughRepo{rules: rules}
svc := &ErrorPassthroughService{
repo: repo,
cache: nil, // 不使用缓存
}
// 直接设置本地缓存,避免调用 refreshLocalCache
svc.setLocalCache(rules)
return svc
}
// =============================================================================
// 测试 ruleMatches 核心匹配逻辑
// =============================================================================
func TestRuleMatches_NoConditions(t *testing.T) {
// 没有配置任何条件时,不应该匹配
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
assert.False(t, svc.ruleMatches(rule, 422, "some error message"),
"没有配置条件时不应该匹配")
}
func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"状态码匹配 422", 422, "any message", true},
{"状态码匹配 400", 400, "any message", true},
{"状态码不匹配 500", 500, "any message", false},
{"状态码不匹配 429", 429, "any message", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) {
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{},
Keywords: []string{"context limit", "model not supported"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
}{
{"关键词匹配 context limit", 500, "error: context limit reached", true},
{"关键词匹配 model not supported", 400, "the model not supported here", true},
{"关键词不匹配", 422, "some other error", false},
// 注意ruleMatches 接收的 body 参数应该是已经转换为小写的
// 实际使用时MatchRule 会先将 body 转换为小写再传给 ruleMatches
{"关键词大小写 - 输入已小写", 500, "context limit exceeded", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 模拟 MatchRule 的行为:先转换为小写
bodyLower := strings.ToLower(tt.body)
result := svc.ruleMatches(rule, tt.statusCode, bodyLower)
assert.Equal(t, tt.expected, result)
})
}
}
func TestRuleMatches_BothConditions_AnyMode(t *testing.T) {
// any 模式:错误码 OR 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAny,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: true,
reason: "code matches, keyword doesn't - OR mode should match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: true,
reason: "keyword matches, code doesn't - OR mode should match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
func TestRuleMatches_BothConditions_AllMode(t *testing.T) {
// all 模式:错误码 AND 关键词
svc := newTestService(nil)
rule := &model.ErrorPassthroughRule{
Enabled: true,
ErrorCodes: []int{422, 400},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll,
}
tests := []struct {
name string
statusCode int
body string
expected bool
reason string
}{
{
name: "状态码和关键词都匹配",
statusCode: 422,
body: "context limit reached",
expected: true,
reason: "both match - AND mode should match",
},
{
name: "只有状态码匹配",
statusCode: 422,
body: "some other error",
expected: false,
reason: "code matches but keyword doesn't - AND mode should NOT match",
},
{
name: "只有关键词匹配",
statusCode: 500,
body: "context limit exceeded",
expected: false,
reason: "keyword matches but code doesn't - AND mode should NOT match",
},
{
name: "都不匹配",
statusCode: 500,
body: "some other error",
expected: false,
reason: "neither matches",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := svc.ruleMatches(rule, tt.statusCode, tt.body)
assert.Equal(t, tt.expected, result, tt.reason)
})
}
}
// =============================================================================
// 测试 platformMatches 平台匹配逻辑
// =============================================================================
func TestPlatformMatches(t *testing.T) {
svc := newTestService(nil)
tests := []struct {
name string
rulePlatforms []string
requestPlatform string
expected bool
}{
{
name: "空平台列表匹配所有",
rulePlatforms: []string{},
requestPlatform: "anthropic",
expected: true,
},
{
name: "nil平台列表匹配所有",
rulePlatforms: nil,
requestPlatform: "openai",
expected: true,
},
{
name: "精确匹配 anthropic",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "精确匹配 openai",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "openai",
expected: true,
},
{
name: "不匹配 gemini",
rulePlatforms: []string{"anthropic", "openai"},
requestPlatform: "gemini",
expected: false,
},
{
name: "大小写不敏感",
rulePlatforms: []string{"Anthropic", "OpenAI"},
requestPlatform: "anthropic",
expected: true,
},
{
name: "匹配 antigravity",
rulePlatforms: []string{"antigravity"},
requestPlatform: "antigravity",
expected: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rule := &model.ErrorPassthroughRule{
Platforms: tt.rulePlatforms,
}
result := svc.platformMatches(rule, tt.requestPlatform)
assert.Equal(t, tt.expected, result)
})
}
}
// =============================================================================
// 测试 MatchRule 完整匹配流程
// =============================================================================
func TestMatchRule_Priority(t *testing.T) {
// 测试规则按优先级排序,优先级小的先匹配
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Low Priority",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "High Priority",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则")
assert.Equal(t, "High Priority", matched.Name)
}
func TestMatchRule_DisabledRule(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Disabled Rule",
Enabled: false,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "Enabled Rule",
Enabled: true,
Priority: 10,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则")
}
func TestMatchRule_PlatformFilter(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Anthropic Only",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Platforms: []string{"anthropic"},
MatchMode: model.MatchModeAny,
},
{
ID: 2,
Name: "OpenAI Only",
Enabled: true,
Priority: 2,
ErrorCodes: []int{422},
Platforms: []string{"openai"},
MatchMode: model.MatchModeAny,
},
{
ID: 3,
Name: "All Platforms",
Enabled: true,
Priority: 3,
ErrorCodes: []int{422},
Platforms: []string{},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) {
matched := svc.MatchRule("anthropic", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(1), matched.ID)
})
t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) {
matched := svc.MatchRule("openai", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(2), matched.ID)
})
t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("gemini", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) {
matched := svc.MatchRule("antigravity", 422, []byte("error"))
require.NotNil(t, matched)
assert.Equal(t, int64(3), matched.ID)
})
}
func TestMatchRule_NoMatch(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Rule for 422",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("error"))
assert.Nil(t, matched, "不匹配任何规则时应返回 nil")
}
func TestMatchRule_EmptyRules(t *testing.T) {
svc := newTestService([]*model.ErrorPassthroughRule{})
matched := svc.MatchRule("anthropic", 422, []byte("error"))
assert.Nil(t, matched, "没有规则时应返回 nil")
}
func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) {
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit",
Enabled: true,
Priority: 1,
Keywords: []string{"Context Limit"},
MatchMode: model.MatchModeAny,
},
}
svc := newTestService(rules)
tests := []struct {
name string
body string
expected bool
}{
{"完全匹配", "Context Limit reached", true},
{"小写匹配", "context limit reached", true},
{"大写匹配", "CONTEXT LIMIT REACHED", true},
{"混合大小写", "ConTeXt LiMiT error", true},
{"不匹配", "some other error", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matched := svc.MatchRule("anthropic", 500, []byte(tt.body))
if tt.expected {
assert.NotNil(t, matched)
} else {
assert.Nil(t, matched)
}
})
}
}
// =============================================================================
// 测试真实场景
// =============================================================================
func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) {
// 场景:上游返回 422 + "context limit has been reached",需要透传给客户端
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Context Limit Passthrough",
Enabled: true,
Priority: 1,
ErrorCodes: []int{422},
Keywords: []string{"context limit"},
MatchMode: model.MatchModeAll, // 必须同时满足
Platforms: []string{"anthropic", "antigravity"},
PassthroughCode: true,
PassthroughBody: true,
},
}
svc := newTestService(rules)
// 测试 Anthropic 平台
t.Run("Anthropic 422 with context limit", func(t *testing.T) {
body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`)
matched := svc.MatchRule("anthropic", 422, body)
require.NotNil(t, matched)
assert.True(t, matched.PassthroughCode)
assert.True(t, matched.PassthroughBody)
})
// 测试 Antigravity 平台
t.Run("Antigravity 422 with context limit", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("antigravity", 422, body)
require.NotNil(t, matched)
})
// 测试 OpenAI 平台(不在规则的平台列表中)
t.Run("OpenAI should not match", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("openai", 422, body)
assert.Nil(t, matched, "OpenAI 不在规则的平台列表中")
})
// 测试状态码不匹配
t.Run("Wrong status code", func(t *testing.T) {
body := []byte(`{"error":"context limit exceeded"}`)
matched := svc.MatchRule("anthropic", 400, body)
assert.Nil(t, matched, "状态码不匹配")
})
// 测试关键词不匹配
t.Run("Wrong keyword", func(t *testing.T) {
body := []byte(`{"error":"rate limit exceeded"}`)
matched := svc.MatchRule("anthropic", 422, body)
assert.Nil(t, matched, "关键词不匹配")
})
}
func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) {
// 场景:某些错误需要返回自定义消息,隐藏上游详细信息
customMsg := "Service temporarily unavailable, please try again later"
responseCode := 503
rules := []*model.ErrorPassthroughRule{
{
ID: 1,
Name: "Hide Internal Errors",
Enabled: true,
Priority: 1,
ErrorCodes: []int{500, 502, 503},
MatchMode: model.MatchModeAny,
PassthroughCode: false,
ResponseCode: &responseCode,
PassthroughBody: false,
CustomMessage: &customMsg,
},
}
svc := newTestService(rules)
matched := svc.MatchRule("anthropic", 500, []byte("internal server error"))
require.NotNil(t, matched)
assert.False(t, matched.PassthroughCode)
assert.Equal(t, 503, *matched.ResponseCode)
assert.False(t, matched.PassthroughBody)
assert.Equal(t, customMsg, *matched.CustomMessage)
}
// =============================================================================
// 测试 model.Validate
// =============================================================================
func TestErrorPassthroughRule_Validate(t *testing.T) {
tests := []struct {
name string
rule *model.ErrorPassthroughRule
expectError bool
errorField string
}{
{
name: "有效规则 - 透传模式(含错误码)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 透传模式(含关键词)",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAny,
Keywords: []string{"context limit"},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: false,
},
{
name: "有效规则 - 自定义响应",
rule: &model.ErrorPassthroughRule{
Name: "Valid Rule",
MatchMode: model.MatchModeAll,
ErrorCodes: []int{500},
Keywords: []string{"internal error"},
PassthroughCode: false,
ResponseCode: testIntPtr(503),
PassthroughBody: false,
CustomMessage: testStrPtr("Custom error"),
},
expectError: false,
},
{
name: "缺少名称",
rule: &model.ErrorPassthroughRule{
Name: "",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "name",
},
{
name: "无效的匹配模式",
rule: &model.ErrorPassthroughRule{
Name: "Invalid Mode",
MatchMode: "invalid",
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "match_mode",
},
{
name: "缺少匹配条件(错误码和关键词都为空)",
rule: &model.ErrorPassthroughRule{
Name: "No Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{},
Keywords: []string{},
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "缺少匹配条件nil切片",
rule: &model.ErrorPassthroughRule{
Name: "Nil Conditions",
MatchMode: model.MatchModeAny,
ErrorCodes: nil,
Keywords: nil,
PassthroughCode: true,
PassthroughBody: true,
},
expectError: true,
errorField: "conditions",
},
{
name: "自定义状态码但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Code",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: false,
ResponseCode: nil,
PassthroughBody: true,
},
expectError: true,
errorField: "response_code",
},
{
name: "自定义消息但未提供值",
rule: &model.ErrorPassthroughRule{
Name: "Missing Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: nil,
},
expectError: true,
errorField: "custom_message",
},
{
name: "自定义消息为空字符串",
rule: &model.ErrorPassthroughRule{
Name: "Empty Message",
MatchMode: model.MatchModeAny,
ErrorCodes: []int{422},
PassthroughCode: true,
PassthroughBody: false,
CustomMessage: testStrPtr(""),
},
expectError: true,
errorField: "custom_message",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.rule.Validate()
if tt.expectError {
require.Error(t, err)
validationErr, ok := err.(*model.ValidationError)
require.True(t, ok, "应该返回 ValidationError")
assert.Equal(t, tt.errorField, validationErr.Field)
} else {
assert.NoError(t, err)
}
})
}
}
// Helper functions
func testIntPtr(i int) *int { return &i }
func testStrPtr(s string) *string { return &s }

View File

@@ -20,7 +20,6 @@ import (
"strings"
"sync/atomic"
"time"
"unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
@@ -375,7 +374,8 @@ type ForwardResult struct {
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
type UpstreamFailoverError struct {
StatusCode int
StatusCode int
ResponseBody []byte // 上游响应体,用于错误透传规则匹配
}
func (e *UpstreamFailoverError) Error() string {
@@ -389,6 +389,7 @@ type GatewayService struct {
usageLogRepo UsageLogRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache GatewayCache
cfg *config.Config
schedulerSnapshot *SchedulerSnapshotService
@@ -410,6 +411,7 @@ func NewGatewayService(
usageLogRepo UsageLogRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache GatewayCache,
cfg *config.Config,
schedulerSnapshot *SchedulerSnapshotService,
@@ -429,6 +431,7 @@ func NewGatewayService(
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
schedulerSnapshot: schedulerSnapshot,
@@ -624,35 +627,6 @@ func stripToolPrefix(value string) string {
return toolPrefixRe.ReplaceAllString(value, "")
}
func toPascalCase(value string) string {
if value == "" {
return value
}
normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
tokens := make([]string, 0)
for _, token := range strings.Fields(normalized) {
expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
parts := strings.Fields(expanded)
if len(parts) > 0 {
tokens = append(tokens, parts...)
}
}
if len(tokens) == 0 {
return value
}
var builder strings.Builder
for _, token := range tokens {
lower := strings.ToLower(token)
if lower == "" {
continue
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
_, _ = builder.WriteString(string(runes))
}
return builder.String()
}
func toSnakeCase(value string) string {
if value == "" {
return value
@@ -668,16 +642,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
return name
}
stripped := stripToolPrefix(name)
// 只对已知的工具名进行映射,未知工具名保持原样
// 避免破坏 Anthropic 特殊工具(如 text_editor_20250728
mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
if !ok {
mapped = toPascalCase(stripped)
}
if mapped != "" && cache != nil && mapped != stripped {
cache[mapped] = stripped
}
if mapped == "" {
return stripped
}
if cache != nil && mapped != stripped {
cache[mapped] = stripped
}
return mapped
}
@@ -686,15 +659,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return name
}
stripped := stripToolPrefix(name)
// 优先从请求时建立的映射中查找
if cache != nil {
if mapped, ok := cache[stripped]; ok {
return mapped
}
}
// 已知工具名的硬编码映射
if mapped, ok := openCodeToolOverrides[stripped]; ok {
return mapped
}
return toSnakeCase(stripped)
// 未知工具名保持原样,避免破坏 Anthropic 特殊工具
return stripped
}
func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
@@ -3313,7 +3289,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleRetryExhaustedError(ctx, resp, c, account)
}
@@ -3343,10 +3319,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return ""
}(),
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
@@ -3390,7 +3364,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
@@ -3787,6 +3761,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
return false
}
// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息
// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}}
func ExtractUpstreamErrorMessage(body []byte) string {
return extractUpstreamErrorMessage(body)
}
func extractUpstreamErrorMessage(body []byte) string {
// Claude 风格:{"type":"error","error":{"type":"...","message":"..."}}
if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" {
@@ -3854,7 +3834,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端)
@@ -4641,10 +4621,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown
@@ -4827,10 +4814,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
account := input.Account
subscription := input.Subscription
// 获取费率倍数
// 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认)
multiplier := s.cfg.Default.RateMultiplier
if apiKey.GroupID != nil && apiKey.Group != nil {
multiplier = apiKey.Group.RateMultiplier
// 检查用户专属倍率
if s.userGroupRateRepo != nil {
if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil {
multiplier = *userRate
}
}
}
var cost *CostBreakdown

View File

@@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
upstreamReqID := resp.Header.Get(requestIDHeader)
@@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
upstreamReqID := resp.Header.Get(requestIDHeader)
if upstreamReqID == "" {
@@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
evBody := unwrapIfNeeded(isOAuth, respBody)
@@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
Message: upstreamMsg,
Detail: upstreamDetail,
})
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody}
}
respBody = unwrapIfNeeded(isOAuth, respBody)

View File

@@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
// 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。
// 当 LoadCodeAssist 返回了 currentTier / paidTier表示账号已注册但没有返回 cloudaicompanionProject 时:
// - 不要再调用 onboardUser通常不会再分配 project_id且可能触发 INVALID_ARGUMENT
// - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id
if loadResp != nil {
registeredTierID := strings.TrimSpace(loadResp.GetTier())
if registeredTierID != "" {
// 已注册但未返回 cloudaicompanionProject这在 Google One 用户中较常见:需要用户自行提供 project_id。
log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID)
// Try to get project from Cloud Resource Manager
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback)
return strings.TrimSpace(fallback), tierID, nil
}
// No project found - user must provide project_id manually
log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually")
return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID)
}
}
// 未检测到 currentTier/paidTier视为新用户继续调用 onboardUser
log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID)
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{

View File

@@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
}
// Remove prompt_cache_retention (not supported by upstream OpenAI API)
if _, has := reqBody["prompt_cache_retention"]; has {
delete(reqBody, "prompt_cache_retention")
bodyModified = true
// Remove unsupported fields (not supported by upstream OpenAI API)
for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} {
if _, has := reqBody[unsupportedField]; has {
delete(reqBody, unsupportedField)
bodyModified = true
}
}
}
@@ -938,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
})
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -1129,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
Detail: upstreamDetail,
})
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body}
}
// Return appropriate error response

View File

@@ -0,0 +1,73 @@
package service
import (
"context"
"errors"
"time"
)
// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache.
// This is used to abstract away the underlying cache implementation (e.g., redis.Nil).
var ErrRefreshTokenNotFound = errors.New("refresh token not found")
// RefreshTokenData 存储在Redis中的Refresh Token数据
type RefreshTokenData struct {
UserID int64 `json:"user_id"`
TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效
FamilyID string `json:"family_id"` // Token家族ID用于防重放攻击
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
// RefreshTokenCache 管理Refresh Token的Redis缓存
// 用于JWT Token刷新机制支持Token轮转和防重放攻击
//
// Key 格式:
// - refresh_token:{token_hash} -> RefreshTokenData (JSON)
// - user_refresh_tokens:{user_id} -> Set<token_hash>
// - token_family:{family_id} -> Set<token_hash>
type RefreshTokenCache interface {
// StoreRefreshToken 存储Refresh Token
// tokenHash: Token的SHA256哈希值不存储原始Token
// data: Token关联的数据
// ttl: Token过期时间
StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error
// GetRefreshToken 获取Refresh Token数据
// 返回 (data, nil) 如果Token存在
// 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在
// 返回 (nil, err) 如果发生其他错误
GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error)
// DeleteRefreshToken 删除单个Refresh Token
// 用于Token轮转时使旧Token失效
DeleteRefreshToken(ctx context.Context, tokenHash string) error
// DeleteUserRefreshTokens 删除用户的所有Refresh Token
// 用于密码更改或用户主动登出所有设备
DeleteUserRefreshTokens(ctx context.Context, userID int64) error
// DeleteTokenFamily 删除整个Token家族
// 用于检测到Token重放攻击时撤销整个会话链
DeleteTokenFamily(ctx context.Context, familyID string) error
// AddToUserTokenSet 将Token添加到用户的Token集合
// 用于跟踪用户的所有活跃Refresh Token
AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error
// AddToFamilyTokenSet 将Token添加到家族Token集合
// 用于跟踪同一登录会话的所有Token
AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error
// GetUserTokenHashes 获取用户的所有Token哈希
// 用于批量删除用户Token
GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error)
// GetFamilyTokenHashes 获取家族的所有Token哈希
// 用于批量删除家族Token
GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error)
// IsTokenInFamily 检查Token是否属于指定家族
// 用于验证Token家族关系
IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error)
}

View File

@@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64)
return stats, nil
}
// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key.
func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) {
stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID)
if err != nil {
return nil, fmt.Errorf("get api key dashboard stats: %w", err)
}
return stats, nil
}
// GetUserUsageTrendByUserID returns per-user usage trend.
func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity)

View File

@@ -21,6 +21,10 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
// GroupRates 用户专属分组倍率配置
// map[groupID]rateMultiplier
GroupRates map[int64]float64
// TOTP 双因素认证字段
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
TotpEnabled bool // 是否启用 TOTP
@@ -40,18 +44,20 @@ func (u *User) IsActive() bool {
// CanBindGroup checks whether a user can bind to a given group.
// For standard groups:
// - If AllowedGroups is non-empty, only allow binding to IDs in that list.
// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group.
// - Public groups (non-exclusive): all users can bind
// - Exclusive groups: only users with the group in AllowedGroups can bind
func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool {
if len(u.AllowedGroups) > 0 {
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
// 公开分组(非专属):所有用户都可以绑定
if !isExclusive {
return true
}
return !isExclusive
// 专属分组:需要在 AllowedGroups 中
for _, id := range u.AllowedGroups {
if id == groupID {
return true
}
}
return false
}
func (u *User) SetPassword(password string) error {

View File

@@ -0,0 +1,25 @@
package service
import "context"
// UserGroupRateRepository 用户专属分组倍率仓储接口
// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率
type UserGroupRateRepository interface {
// GetByUserID 获取用户的所有专属分组倍率
// 返回 map[groupID]rateMultiplier
GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error)
// GetByUserAndGroup 获取用户在特定分组的专属倍率
// 如果未设置专属倍率,返回 nil
GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error)
// SyncUserGroupRates 同步用户的分组专属倍率
// rates: map[groupID]*rateMultipliernil 表示删除该分组的专属倍率
SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error
// DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用)
DeleteByGroupID(ctx context.Context, groupID int64) error
// DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用)
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -294,4 +294,5 @@ var ProviderSet = wire.NewSet(
NewUserAttributeService,
NewUsageCache,
NewTotpService,
NewErrorPassthroughService,
)