Merge upstream/main: v0.1.61 updates with TOTP & password reset
Major features: - feat(auth): add TOTP (2FA) support for enhanced security - feat(auth): add password reset functionality - feat(openai): improve rate limit handling with reset time tracking - feat(subscription): add expiry notification service - fix(urlvalidator): remove trailing slash from validated URLs - docs: update demo site domain Resolved conflicts: - deploy/docker-compose.yml: added TOTP_ENCRYPTION_KEY config - setting_service.go: keep TianShuAPI site name while adding new features Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -93,6 +93,18 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
|
||||
@@ -142,12 +142,13 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(部分账户类型可能没有)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败: %v\n", err)
|
||||
} else if loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
// 获取 project_id(部分账户类型可能没有),失败时重试
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
result.ProjectIDMissing = true
|
||||
} else {
|
||||
result.ProjectID = projectID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@@ -237,21 +238,60 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
|
||||
if loadErr != nil {
|
||||
// LoadCodeAssist 失败,保留原有 project_id
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
|
||||
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
|
||||
if existingProjectID == "" {
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
}
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// loadProjectIDWithRetry 带重试机制获取 project_id
|
||||
// 返回 project_id 和错误,失败时会重试指定次数
|
||||
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// 指数退避:1s, 2s, 4s
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 8*time.Second {
|
||||
backoff = 8 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if loadResp == nil {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
|
||||
} else {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
|
||||
@@ -3,6 +3,8 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -55,15 +57,32 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,只记录警告,不返回错误
|
||||
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
|
||||
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
if tokenInfo.ProjectID != "" {
|
||||
// 有旧的 project_id,本次获取失败,保留旧值
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
|
||||
} else {
|
||||
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
|
||||
@@ -580,3 +580,149 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) (
|
||||
// 生成新token
|
||||
return s.GenerateToken(user)
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证且 SMTP 配置正确
|
||||
func (s *AuthService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
if s.settingService == nil {
|
||||
return false
|
||||
}
|
||||
// Must have email verification enabled and SMTP configured
|
||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
return s.settingService.IsPasswordResetEnabled(ctx)
|
||||
}
|
||||
|
||||
// preparePasswordReset validates the password reset request and returns necessary data
|
||||
// Returns (siteName, resetURL, shouldProceed)
|
||||
// shouldProceed is false when we should silently return success (to prevent enumeration)
|
||||
func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendBaseURL string) (string, string, bool) {
|
||||
// Check if user exists (but don't reveal this to the caller)
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
// Security: Log but don't reveal that user doesn't exist
|
||||
log.Printf("[Auth] Password reset requested for non-existent email: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
log.Printf("[Auth] Database error checking email for password reset: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
log.Printf("[Auth] Password reset requested for inactive user: %s", email)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
// Get site name
|
||||
siteName := "Sub2API"
|
||||
if s.settingService != nil {
|
||||
siteName = s.settingService.GetSiteName(ctx)
|
||||
}
|
||||
|
||||
// Build reset URL base
|
||||
resetURL := fmt.Sprintf("%s/reset-password", strings.TrimSuffix(frontendBaseURL, "/"))
|
||||
|
||||
return siteName, resetURL, true
|
||||
}
|
||||
|
||||
// RequestPasswordReset 请求密码重置(同步发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email sent to: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestPasswordResetAsync 异步请求密码重置(队列发送)
|
||||
// Security: Returns the same response regardless of whether the email exists (prevent user enumeration)
|
||||
func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, frontendBaseURL string) error {
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
if s.emailQueueService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
siteName, resetURL, shouldProceed := s.preparePasswordReset(ctx, email, frontendBaseURL)
|
||||
if !shouldProceed {
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil {
|
||||
log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err)
|
||||
return nil // Silent success to prevent enumeration
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset email enqueued for: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// Security: Increments TokenVersion to invalidate all existing JWT tokens
|
||||
func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPassword string) error {
|
||||
// Check if password reset is enabled
|
||||
if !s.IsPasswordResetEnabled(ctx) {
|
||||
return infraerrors.Forbidden("PASSWORD_RESET_DISABLED", "password reset is not enabled")
|
||||
}
|
||||
|
||||
if s.emailService == nil {
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Verify and consume the reset token (one-time use)
|
||||
if err := s.emailService.ConsumePasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUserNotFound) {
|
||||
return ErrInvalidResetToken // Token was valid but user was deleted
|
||||
}
|
||||
log.Printf("[Auth] Database error getting user for password reset: %v", err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// Check if user is active
|
||||
if !user.IsActive() {
|
||||
return ErrUserNotActive
|
||||
}
|
||||
|
||||
// Hash new password
|
||||
hashedPassword, err := s.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
// Update password and increment TokenVersion
|
||||
user.PasswordHash = hashedPassword
|
||||
user.TokenVersion++ // Invalidate all existing tokens
|
||||
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
log.Printf("[Auth] Database error updating password for user %d: %v", user.ID, err)
|
||||
return ErrServiceUnavailable
|
||||
}
|
||||
|
||||
log.Printf("[Auth] Password reset successful for user: %s", email)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -71,6 +71,26 @@ func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email strin
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
|
||||
@@ -69,9 +69,10 @@ const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyRegistrationEnabled = "registration_enabled" // 是否开放注册
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
SettingKeyPromoCodeEnabled = "promo_code_enabled" // 是否启用优惠码功能
|
||||
SettingKeyPasswordResetEnabled = "password_reset_enabled" // 是否启用忘记密码功能(需要先开启邮件验证)
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
@@ -87,6 +88,9 @@ const (
|
||||
SettingKeyTurnstileSiteKey = "turnstile_site_key" // Turnstile Site Key
|
||||
SettingKeyTurnstileSecretKey = "turnstile_secret_key" // Turnstile Secret Key
|
||||
|
||||
// TOTP 双因素认证设置
|
||||
SettingKeyTotpEnabled = "totp_enabled" // 是否启用 TOTP 2FA 功能
|
||||
|
||||
// LinuxDo Connect OAuth 登录设置
|
||||
SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled"
|
||||
SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id"
|
||||
|
||||
@@ -8,11 +8,18 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// Task type constants
|
||||
const (
|
||||
TaskTypeVerifyCode = "verify_code"
|
||||
TaskTypePasswordReset = "password_reset"
|
||||
)
|
||||
|
||||
// EmailTask 邮件发送任务
|
||||
type EmailTask struct {
|
||||
Email string
|
||||
SiteName string
|
||||
TaskType string // "verify_code"
|
||||
TaskType string // "verify_code" or "password_reset"
|
||||
ResetURL string // Only used for password_reset task type
|
||||
}
|
||||
|
||||
// EmailQueueService 异步邮件队列服务
|
||||
@@ -73,12 +80,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) {
|
||||
defer cancel()
|
||||
|
||||
switch task.TaskType {
|
||||
case "verify_code":
|
||||
case TaskTypeVerifyCode:
|
||||
if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email)
|
||||
}
|
||||
case TaskTypePasswordReset:
|
||||
if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil {
|
||||
log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err)
|
||||
} else {
|
||||
log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email)
|
||||
}
|
||||
default:
|
||||
log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType)
|
||||
}
|
||||
@@ -89,7 +102,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: "verify_code",
|
||||
TaskType: TaskTypeVerifyCode,
|
||||
}
|
||||
|
||||
select {
|
||||
@@ -101,6 +114,24 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// EnqueuePasswordReset 将密码重置邮件任务加入队列
|
||||
func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL string) error {
|
||||
task := EmailTask{
|
||||
Email: email,
|
||||
SiteName: siteName,
|
||||
TaskType: TaskTypePasswordReset,
|
||||
ResetURL: resetURL,
|
||||
}
|
||||
|
||||
select {
|
||||
case s.taskChan <- task:
|
||||
log.Printf("[EmailQueue] Enqueued password reset task for %s", email)
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("email queue is full")
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止队列服务
|
||||
func (s *EmailQueueService) Stop() {
|
||||
close(s.stopChan)
|
||||
|
||||
@@ -3,11 +3,14 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"crypto/tls"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/big"
|
||||
"net/smtp"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +22,9 @@ var (
|
||||
ErrInvalidVerifyCode = infraerrors.BadRequest("INVALID_VERIFY_CODE", "invalid or expired verification code")
|
||||
ErrVerifyCodeTooFrequent = infraerrors.TooManyRequests("VERIFY_CODE_TOO_FREQUENT", "please wait before requesting a new code")
|
||||
ErrVerifyCodeMaxAttempts = infraerrors.TooManyRequests("VERIFY_CODE_MAX_ATTEMPTS", "too many failed attempts, please request a new code")
|
||||
|
||||
// Password reset errors
|
||||
ErrInvalidResetToken = infraerrors.BadRequest("INVALID_RESET_TOKEN", "invalid or expired password reset token")
|
||||
)
|
||||
|
||||
// EmailCache defines cache operations for email service
|
||||
@@ -26,6 +32,16 @@ type EmailCache interface {
|
||||
GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error)
|
||||
SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error
|
||||
DeleteVerificationCode(ctx context.Context, email string) error
|
||||
|
||||
// Password reset token methods
|
||||
GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error)
|
||||
SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error
|
||||
DeletePasswordResetToken(ctx context.Context, email string) error
|
||||
|
||||
// Password reset email cooldown methods
|
||||
// Returns true if in cooldown period (email was sent recently)
|
||||
IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool
|
||||
SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// VerificationCodeData represents verification code data
|
||||
@@ -35,10 +51,22 @@ type VerificationCodeData struct {
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// PasswordResetTokenData represents password reset token data
|
||||
type PasswordResetTokenData struct {
|
||||
Token string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
|
||||
// Password reset token settings
|
||||
passwordResetTokenTTL = 30 * time.Minute
|
||||
|
||||
// Password reset email cooldown (prevent email bombing)
|
||||
passwordResetEmailCooldown = 30 * time.Second
|
||||
)
|
||||
|
||||
// SMTPConfig SMTP配置
|
||||
@@ -254,8 +282,8 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
||||
return ErrVerifyCodeMaxAttempts
|
||||
}
|
||||
|
||||
// 验证码不匹配
|
||||
if data.Code != code {
|
||||
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||
data.Attempts++
|
||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
||||
@@ -357,3 +385,157 @@ func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
|
||||
return client.Quit()
|
||||
}
|
||||
|
||||
// GeneratePasswordResetToken generates a secure 32-byte random token (64 hex characters)
|
||||
func (s *EmailService) GeneratePasswordResetToken() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmail sends a password reset email with a reset link
|
||||
func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteName, resetURL string) error {
|
||||
var token string
|
||||
var needSaveToken bool
|
||||
|
||||
// Check if token already exists
|
||||
existing, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err == nil && existing != nil {
|
||||
// Token exists, reuse it (allows resending email without generating new token)
|
||||
token = existing.Token
|
||||
needSaveToken = false
|
||||
} else {
|
||||
// Generate new token
|
||||
token, err = s.GeneratePasswordResetToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate token: %w", err)
|
||||
}
|
||||
needSaveToken = true
|
||||
}
|
||||
|
||||
// Save token to Redis (only if new token generated)
|
||||
if needSaveToken {
|
||||
data := &PasswordResetTokenData{
|
||||
Token: token,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
if err := s.cache.SetPasswordResetToken(ctx, email, data, passwordResetTokenTTL); err != nil {
|
||||
return fmt.Errorf("save reset token: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Build full reset URL with URL-encoded token and email
|
||||
fullResetURL := fmt.Sprintf("%s?email=%s&token=%s", resetURL, url.QueryEscape(email), url.QueryEscape(token))
|
||||
|
||||
// Build email content
|
||||
subject := fmt.Sprintf("[%s] 密码重置请求", siteName)
|
||||
body := s.buildPasswordResetEmailBody(fullResetURL, siteName)
|
||||
|
||||
// Send email
|
||||
if err := s.SendEmail(ctx, email, subject, body); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendPasswordResetEmailWithCooldown sends password reset email with cooldown check (called by queue worker)
|
||||
// This method wraps SendPasswordResetEmail with email cooldown to prevent email bombing
|
||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||
// Check email cooldown to prevent email bombing
|
||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
|
||||
return nil // Silent success to prevent revealing cooldown to attackers
|
||||
}
|
||||
|
||||
// Send email using core method
|
||||
if err := s.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Set cooldown marker (Redis TTL handles expiration)
|
||||
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
||||
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyPasswordResetToken verifies the password reset token without consuming it
|
||||
func (s *EmailService) VerifyPasswordResetToken(ctx context.Context, email, token string) error {
|
||||
data, err := s.cache.GetPasswordResetToken(ctx, email)
|
||||
if err != nil || data == nil {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
// Use constant-time comparison to prevent timing attacks
|
||||
if subtle.ConstantTimeCompare([]byte(data.Token), []byte(token)) != 1 {
|
||||
return ErrInvalidResetToken
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ConsumePasswordResetToken verifies and deletes the token (one-time use)
|
||||
func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, token string) error {
|
||||
// Verify first
|
||||
if err := s.VerifyPasswordResetToken(ctx, email, token); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete after verification (one-time use)
|
||||
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
||||
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildPasswordResetEmailBody builds the HTML content for password reset email
|
||||
func (s *EmailService) buildPasswordResetEmailBody(resetURL, siteName string) string {
|
||||
return fmt.Sprintf(`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<style>
|
||||
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif; background-color: #f5f5f5; margin: 0; padding: 20px; }
|
||||
.container { max-width: 600px; margin: 0 auto; background-color: #ffffff; border-radius: 8px; overflow: hidden; box-shadow: 0 2px 8px rgba(0,0,0,0.1); }
|
||||
.header { background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 30px; text-align: center; }
|
||||
.header h1 { margin: 0; font-size: 24px; }
|
||||
.content { padding: 40px 30px; text-align: center; }
|
||||
.button { display: inline-block; background: linear-gradient(135deg, #667eea 0%%, #764ba2 100%%); color: white; padding: 14px 32px; text-decoration: none; border-radius: 8px; font-size: 16px; font-weight: 600; margin: 20px 0; }
|
||||
.button:hover { opacity: 0.9; }
|
||||
.info { color: #666; font-size: 14px; line-height: 1.6; margin-top: 20px; }
|
||||
.link-fallback { color: #666; font-size: 12px; word-break: break-all; margin-top: 20px; padding: 15px; background-color: #f8f9fa; border-radius: 4px; }
|
||||
.footer { background-color: #f8f9fa; padding: 20px; text-align: center; color: #999; font-size: 12px; }
|
||||
.warning { color: #e74c3c; font-weight: 500; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>%s</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p style="font-size: 18px; color: #333;">密码重置请求</p>
|
||||
<p style="color: #666;">您已请求重置密码。请点击下方按钮设置新密码:</p>
|
||||
<a href="%s" class="button">重置密码</a>
|
||||
<div class="info">
|
||||
<p>此链接将在 <strong>30 分钟</strong>后失效。</p>
|
||||
<p class="warning">如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。</p>
|
||||
</div>
|
||||
<div class="link-fallback">
|
||||
<p>如果按钮无法点击,请复制以下链接到浏览器中打开:</p>
|
||||
<p>%s</p>
|
||||
</div>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>这是一封自动发送的邮件,请勿回复。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`, siteName, resetURL, resetURL)
|
||||
}
|
||||
|
||||
@@ -305,6 +305,19 @@ func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64,
|
||||
return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL)
|
||||
}
|
||||
|
||||
// GetCachedSessionAccountID retrieves the account ID bound to a sticky session.
|
||||
// Returns 0 if no binding exists or on error.
|
||||
func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) {
|
||||
if sessionHash == "" || s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
|
||||
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
72
backend/internal/service/gemini_native_signature_cleaner.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// CleanGeminiNativeThoughtSignatures 从 Gemini 原生 API 请求中移除 thoughtSignature 字段,
|
||||
// 以避免跨账号签名验证错误。
|
||||
//
|
||||
// 当粘性会话切换账号时(例如原账号异常、不可调度等),旧账号返回的 thoughtSignature
|
||||
// 会导致新账号的签名验证失败。通过移除这些签名,让新账号重新生成有效的签名。
|
||||
//
|
||||
// CleanGeminiNativeThoughtSignatures removes thoughtSignature fields from Gemini native API requests
|
||||
// to avoid cross-account signature validation errors.
|
||||
//
|
||||
// When sticky session switches accounts (e.g., original account becomes unavailable),
|
||||
// thoughtSignatures from the old account will cause validation failures on the new account.
|
||||
// By removing these signatures, we allow the new account to generate valid signatures.
|
||||
func CleanGeminiNativeThoughtSignatures(body []byte) []byte {
|
||||
if len(body) == 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 解析 JSON
|
||||
var data any
|
||||
if err := json.Unmarshal(body, &data); err != nil {
|
||||
// 如果解析失败,返回原始 body(可能不是 JSON 或格式不正确)
|
||||
return body
|
||||
}
|
||||
|
||||
// 递归清理 thoughtSignature
|
||||
cleaned := cleanThoughtSignaturesRecursive(data)
|
||||
|
||||
// 重新序列化
|
||||
result, err := json.Marshal(cleaned)
|
||||
if err != nil {
|
||||
// 如果序列化失败,返回原始 body
|
||||
return body
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// cleanThoughtSignaturesRecursive 递归遍历数据结构,移除所有 thoughtSignature 字段
|
||||
func cleanThoughtSignaturesRecursive(data any) any {
|
||||
switch v := data.(type) {
|
||||
case map[string]any:
|
||||
// 创建新的 map,移除 thoughtSignature
|
||||
result := make(map[string]any, len(v))
|
||||
for key, value := range v {
|
||||
// 跳过 thoughtSignature 字段
|
||||
if key == "thoughtSignature" {
|
||||
continue
|
||||
}
|
||||
// 递归处理嵌套结构
|
||||
result[key] = cleanThoughtSignaturesRecursive(value)
|
||||
}
|
||||
return result
|
||||
|
||||
case []any:
|
||||
// 递归处理数组中的每个元素
|
||||
result := make([]any, len(v))
|
||||
for i, item := range v {
|
||||
result[i] = cleanThoughtSignaturesRecursive(item)
|
||||
}
|
||||
return result
|
||||
|
||||
default:
|
||||
// 基本类型(string, number, bool, null)直接返回
|
||||
return v
|
||||
}
|
||||
}
|
||||
@@ -60,6 +60,92 @@ type OpenAICodexUsageSnapshot struct {
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
// NormalizedCodexLimits contains normalized 5h/7d rate limit data
|
||||
type NormalizedCodexLimits struct {
|
||||
Used5hPercent *float64
|
||||
Reset5hSeconds *int
|
||||
Window5hMinutes *int
|
||||
Used7dPercent *float64
|
||||
Reset7dSeconds *int
|
||||
Window7dMinutes *int
|
||||
}
|
||||
|
||||
// Normalize converts primary/secondary fields to canonical 5h/7d fields.
|
||||
// Strategy: Compare window_minutes to determine which is 5h vs 7d.
|
||||
// Returns nil if snapshot is nil or has no useful data.
|
||||
func (s *OpenAICodexUsageSnapshot) Normalize() *NormalizedCodexLimits {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &NormalizedCodexLimits{}
|
||||
|
||||
primaryMins := 0
|
||||
secondaryMins := 0
|
||||
hasPrimaryWindow := false
|
||||
hasSecondaryWindow := false
|
||||
|
||||
if s.PrimaryWindowMinutes != nil {
|
||||
primaryMins = *s.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
if s.SecondaryWindowMinutes != nil {
|
||||
secondaryMins = *s.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine mapping based on window_minutes
|
||||
use5hFromPrimary := false
|
||||
use7dFromPrimary := false
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both known: smaller window is 5h, larger is 7d
|
||||
if primaryMins < secondaryMins {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary known: classify by threshold (<=360 min = 6h -> 5h window)
|
||||
if primaryMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary known: classify by threshold
|
||||
if secondaryMins <= 360 {
|
||||
// 5h from secondary, so primary (if any data) is 7d
|
||||
use7dFromPrimary = true
|
||||
} else {
|
||||
// 7d from secondary, so primary (if any data) is 5h
|
||||
use5hFromPrimary = true
|
||||
}
|
||||
} else {
|
||||
// No window_minutes: fall back to legacy assumption (primary=7d, secondary=5h)
|
||||
use7dFromPrimary = true
|
||||
}
|
||||
|
||||
// Assign values
|
||||
if use5hFromPrimary {
|
||||
result.Used5hPercent = s.PrimaryUsedPercent
|
||||
result.Reset5hSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.PrimaryWindowMinutes
|
||||
result.Used7dPercent = s.SecondaryUsedPercent
|
||||
result.Reset7dSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.SecondaryWindowMinutes
|
||||
} else if use7dFromPrimary {
|
||||
result.Used7dPercent = s.PrimaryUsedPercent
|
||||
result.Reset7dSeconds = s.PrimaryResetAfterSeconds
|
||||
result.Window7dMinutes = s.PrimaryWindowMinutes
|
||||
result.Used5hPercent = s.SecondaryUsedPercent
|
||||
result.Reset5hSeconds = s.SecondaryResetAfterSeconds
|
||||
result.Window5hMinutes = s.SecondaryWindowMinutes
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
// OpenAIUsage represents OpenAI API response usage
|
||||
type OpenAIUsage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
@@ -867,7 +953,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
|
||||
// Extract and save Codex usage snapshot from response headers (for OAuth accounts)
|
||||
if account.Type == AccountTypeOAuth {
|
||||
if snapshot := extractCodexUsageHeaders(resp.Header); snapshot != nil {
|
||||
if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil {
|
||||
s.updateCodexUsageSnapshot(ctx, account.ID, snapshot)
|
||||
}
|
||||
}
|
||||
@@ -1665,8 +1751,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractCodexUsageHeaders extracts Codex usage limits from response headers
|
||||
func extractCodexUsageHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
// ParseCodexRateLimitHeaders extracts Codex usage limits from response headers.
|
||||
// Exported for use in ratelimit_service when handling OpenAI 429 responses.
|
||||
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot {
|
||||
snapshot := &OpenAICodexUsageSnapshot{}
|
||||
hasData := false
|
||||
|
||||
@@ -1740,6 +1827,8 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
|
||||
// Convert snapshot to map for merging into Extra
|
||||
updates := make(map[string]any)
|
||||
|
||||
// Save raw primary/secondary fields for debugging/tracing
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
@@ -1763,109 +1852,25 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
|
||||
}
|
||||
updates["codex_usage_updated_at"] = snapshot.UpdatedAt
|
||||
|
||||
// Normalize to canonical 5h/7d fields based on window_minutes
|
||||
// This fixes the issue where OpenAI's primary/secondary naming is reversed
|
||||
// Strategy: Compare the two windows and assign the smaller one to 5h, larger one to 7d
|
||||
|
||||
// IMPORTANT: We can only reliably determine window type from window_minutes field
|
||||
// The reset_after_seconds is remaining time, not window size, so it cannot be used for comparison
|
||||
|
||||
var primaryWindowMins, secondaryWindowMins int
|
||||
var hasPrimaryWindow, hasSecondaryWindow bool
|
||||
|
||||
// Only use window_minutes for reliable window size comparison
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
primaryWindowMins = *snapshot.PrimaryWindowMinutes
|
||||
hasPrimaryWindow = true
|
||||
}
|
||||
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
secondaryWindowMins = *snapshot.SecondaryWindowMinutes
|
||||
hasSecondaryWindow = true
|
||||
}
|
||||
|
||||
// Determine which is 5h and which is 7d
|
||||
var use5hFromPrimary, use7dFromPrimary bool
|
||||
var use5hFromSecondary, use7dFromSecondary bool
|
||||
|
||||
if hasPrimaryWindow && hasSecondaryWindow {
|
||||
// Both window sizes known: compare and assign smaller to 5h, larger to 7d
|
||||
if primaryWindowMins < secondaryWindowMins {
|
||||
use5hFromPrimary = true
|
||||
use7dFromSecondary = true
|
||||
} else {
|
||||
use5hFromSecondary = true
|
||||
use7dFromPrimary = true
|
||||
// Normalize to canonical 5h/7d fields
|
||||
if normalized := snapshot.Normalize(); normalized != nil {
|
||||
if normalized.Used5hPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *normalized.Used5hPercent
|
||||
}
|
||||
} else if hasPrimaryWindow {
|
||||
// Only primary window size known: classify by absolute threshold
|
||||
if primaryWindowMins <= 360 {
|
||||
use5hFromPrimary = true
|
||||
} else {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *normalized.Reset5hSeconds
|
||||
}
|
||||
} else if hasSecondaryWindow {
|
||||
// Only secondary window size known: classify by absolute threshold
|
||||
if secondaryWindowMins <= 360 {
|
||||
use5hFromSecondary = true
|
||||
} else {
|
||||
use7dFromSecondary = true
|
||||
if normalized.Window5hMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *normalized.Window5hMinutes
|
||||
}
|
||||
} else {
|
||||
// No window_minutes available: cannot reliably determine window types
|
||||
// Fall back to legacy assumption (may be incorrect)
|
||||
// Assume primary=7d, secondary=5h based on historical observation
|
||||
if snapshot.SecondaryUsedPercent != nil || snapshot.SecondaryResetAfterSeconds != nil || snapshot.SecondaryWindowMinutes != nil {
|
||||
use5hFromSecondary = true
|
||||
if normalized.Used7dPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *normalized.Used7dPercent
|
||||
}
|
||||
if snapshot.PrimaryUsedPercent != nil || snapshot.PrimaryResetAfterSeconds != nil || snapshot.PrimaryWindowMinutes != nil {
|
||||
use7dFromPrimary = true
|
||||
if normalized.Reset7dSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *normalized.Reset7dSeconds
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 5h fields
|
||||
if use5hFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use5hFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_5h_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_5h_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_5h_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
}
|
||||
}
|
||||
|
||||
// Write canonical 7d fields
|
||||
if use7dFromPrimary {
|
||||
if snapshot.PrimaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.PrimaryUsedPercent
|
||||
}
|
||||
if snapshot.PrimaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.PrimaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.PrimaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.PrimaryWindowMinutes
|
||||
}
|
||||
} else if use7dFromSecondary {
|
||||
if snapshot.SecondaryUsedPercent != nil {
|
||||
updates["codex_7d_used_percent"] = *snapshot.SecondaryUsedPercent
|
||||
}
|
||||
if snapshot.SecondaryResetAfterSeconds != nil {
|
||||
updates["codex_7d_reset_after_seconds"] = *snapshot.SecondaryResetAfterSeconds
|
||||
}
|
||||
if snapshot.SecondaryWindowMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *snapshot.SecondaryWindowMinutes
|
||||
if normalized.Window7dMinutes != nil {
|
||||
updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -343,9 +343,48 @@ func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *A
|
||||
// handle429 处理429限流错误
|
||||
// 解析响应头获取重置时间,标记账号为限流状态
|
||||
func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
|
||||
// 解析重置时间戳
|
||||
// 1. OpenAI 平台:优先尝试解析 x-codex-* 响应头(用于 rate_limit_exceeded)
|
||||
if account.Platform == PlatformOpenAI {
|
||||
if resetAt := s.calculateOpenAI429ResetTime(headers); resetAt != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, *resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("openai_account_rate_limited", "account_id", account.ID, "reset_at", *resetAt)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 尝试从响应头解析重置时间(Anthropic)
|
||||
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
|
||||
|
||||
// 3. 如果响应头没有,尝试从响应体解析(OpenAI usage_limit_reached, Gemini)
|
||||
if resetTimestamp == "" {
|
||||
switch account.Platform {
|
||||
case PlatformOpenAI:
|
||||
// 尝试解析 OpenAI 的 usage_limit_reached 错误
|
||||
if resetAt := parseOpenAIRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
case PlatformGemini, PlatformAntigravity:
|
||||
// 尝试解析 Gemini 格式(用于其他平台)
|
||||
if resetAt := ParseGeminiRateLimitResetTime(responseBody); resetAt != nil {
|
||||
resetTime := time.Unix(*resetAt, 0)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Info("account_rate_limited", "account_id", account.ID, "platform", account.Platform, "reset_at", resetTime, "reset_in", time.Until(resetTime).Truncate(time.Second))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
|
||||
@@ -356,6 +395,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
|
||||
}
|
||||
return
|
||||
}
|
||||
slog.Warn("rate_limit_no_reset_time", "account_id", account.ID, "platform", account.Platform, "using_default", "5m")
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
@@ -419,6 +459,108 @@ func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, re
|
||||
return strings.Contains(msg, "sonnet")
|
||||
}
|
||||
|
||||
// calculateOpenAI429ResetTime 从 OpenAI 429 响应头计算正确的重置时间
|
||||
// 返回 nil 表示无法从响应头中确定重置时间
|
||||
func (s *RateLimitService) calculateOpenAI429ResetTime(headers http.Header) *time.Time {
|
||||
snapshot := ParseCodexRateLimitHeaders(headers)
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
|
||||
// 判断哪个限制被触发(used_percent >= 100)
|
||||
is7dExhausted := normalized.Used7dPercent != nil && *normalized.Used7dPercent >= 100
|
||||
is5hExhausted := normalized.Used5hPercent != nil && *normalized.Used5hPercent >= 100
|
||||
|
||||
// 优先使用被触发限制的重置时间
|
||||
if is7dExhausted && normalized.Reset7dSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset7dSeconds) * time.Second)
|
||||
slog.Info("openai_429_7d_limit_exhausted", "reset_after_seconds", *normalized.Reset7dSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
if is5hExhausted && normalized.Reset5hSeconds != nil {
|
||||
resetAt := now.Add(time.Duration(*normalized.Reset5hSeconds) * time.Second)
|
||||
slog.Info("openai_429_5h_limit_exhausted", "reset_after_seconds", *normalized.Reset5hSeconds, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
// 都未达到100%但收到429,使用较长的重置时间
|
||||
var maxResetSecs int
|
||||
if normalized.Reset7dSeconds != nil && *normalized.Reset7dSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset7dSeconds
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil && *normalized.Reset5hSeconds > maxResetSecs {
|
||||
maxResetSecs = *normalized.Reset5hSeconds
|
||||
}
|
||||
if maxResetSecs > 0 {
|
||||
resetAt := now.Add(time.Duration(maxResetSecs) * time.Second)
|
||||
slog.Info("openai_429_using_max_reset", "max_reset_seconds", maxResetSecs, "reset_at", resetAt)
|
||||
return &resetAt
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseOpenAIRateLimitResetTime 解析 OpenAI 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||
// OpenAI 的 usage_limit_reached 错误格式:
|
||||
//
|
||||
// {
|
||||
// "error": {
|
||||
// "message": "The usage limit has been reached",
|
||||
// "type": "usage_limit_reached",
|
||||
// "resets_at": 1769404154,
|
||||
// "resets_in_seconds": 133107
|
||||
// }
|
||||
// }
|
||||
func parseOpenAIRateLimitResetTime(body []byte) *int64 {
|
||||
var parsed map[string]any
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
errObj, ok := parsed["error"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查是否为 usage_limit_reached 或 rate_limit_exceeded 类型
|
||||
errType, _ := errObj["type"].(string)
|
||||
if errType != "usage_limit_reached" && errType != "rate_limit_exceeded" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 优先使用 resets_at(Unix 时间戳)
|
||||
if resetsAt, ok := errObj["resets_at"].(float64); ok {
|
||||
ts := int64(resetsAt)
|
||||
return &ts
|
||||
}
|
||||
if resetsAt, ok := errObj["resets_at"].(string); ok {
|
||||
if ts, err := strconv.ParseInt(resetsAt, 10, 64); err == nil {
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 resets_at,尝试使用 resets_in_seconds
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(float64); ok {
|
||||
ts := time.Now().Unix() + int64(resetsInSeconds)
|
||||
return &ts
|
||||
}
|
||||
if resetsInSeconds, ok := errObj["resets_in_seconds"].(string); ok {
|
||||
if sec, err := strconv.ParseInt(resetsInSeconds, 10, 64); err == nil {
|
||||
ts := time.Now().Unix() + sec
|
||||
return &ts
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handle529 处理529过载错误
|
||||
// 根据配置设置过载冷却时间
|
||||
func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
|
||||
|
||||
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
364
backend/internal/service/ratelimit_service_openai_test.go
Normal file
@@ -0,0 +1,364 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 7d limit is exhausted (100% used)
|
||||
// Primary = 7d (10080 minutes), Secondary = 5h (300 minutes)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607") // ~4.5 days
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369") // ~4.8 hours
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 384607 seconds from now
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5hExhausted(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers when 5h limit is exhausted (100% used)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "50")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should be approximately 3600 seconds from now
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NeitherExhausted_UsesMax(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Neither limit at 100%, should use the longer reset time
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "80")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "100000")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "90")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "5000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should use the max (100000 seconds from 7d window)
|
||||
expectedDuration := 100000 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_NoCodexHeaders(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// No codex headers at all
|
||||
headers := http.Header{}
|
||||
headers.Set("content-type", "application/json")
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil resetAt when no codex headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_ReversedWindowOrder(t *testing.T) {
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Test when OpenAI sends primary as 5h and secondary as 7d (reversed)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100") // This is 5h
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "3600") // 1 hour
|
||||
headers.Set("x-codex-primary-window-minutes", "300") // 5 hours - smaller!
|
||||
headers.Set("x-codex-secondary-used-percent", "50")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "500000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "10080") // 7 days - larger!
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt")
|
||||
}
|
||||
|
||||
// Should correctly identify that primary is 5h (smaller window) and use its reset time
|
||||
expectedDuration := 3600 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits(t *testing.T) {
|
||||
// Test the Normalize() method directly
|
||||
pUsed := 100.0
|
||||
pReset := 384607
|
||||
pWindow := 10080
|
||||
sUsed := 3.0
|
||||
sReset := 17369
|
||||
sWindow := 300
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
PrimaryWindowMinutes: &pWindow,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
SecondaryWindowMinutes: &sWindow,
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Primary has larger window (10080 > 300), so primary should be 7d
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 384607 {
|
||||
t.Errorf("expected Reset7dSeconds=384607, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 3.0 {
|
||||
t.Errorf("expected Used5hPercent=3, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 17369 {
|
||||
t.Errorf("expected Reset5hSeconds=17369, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) {
|
||||
// Test when only primary has data, no window_minutes
|
||||
pUsed := 80.0
|
||||
pReset := 50000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
// No window_minutes, no secondary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 80.0 {
|
||||
t.Errorf("expected Used7dPercent=80, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 50000 {
|
||||
t.Errorf("expected Reset7dSeconds=50000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
// Secondary (5h) should be nil
|
||||
if normalized.Used5hPercent != nil {
|
||||
t.Errorf("expected Used5hPercent=nil, got %v", *normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds != nil {
|
||||
t.Errorf("expected Reset5hSeconds=nil, got %v", *normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) {
|
||||
// Test when only secondary has data, no window_minutes
|
||||
sUsed := 60.0
|
||||
sReset := 3000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes, no primary data
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
// So secondary goes to 5h
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 60.0 {
|
||||
t.Errorf("expected Used5hPercent=60, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 3000 {
|
||||
t.Errorf("expected Reset5hSeconds=3000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
// Primary (7d) should be nil
|
||||
if normalized.Used7dPercent != nil {
|
||||
t.Errorf("expected Used7dPercent=nil, got %v", *normalized.Used7dPercent)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizedCodexLimits_BothDataNoWindowMinutes(t *testing.T) {
|
||||
// Test when both have data but no window_minutes
|
||||
pUsed := 100.0
|
||||
pReset := 400000
|
||||
sUsed := 50.0
|
||||
sReset := 10000
|
||||
|
||||
snapshot := &OpenAICodexUsageSnapshot{
|
||||
PrimaryUsedPercent: &pUsed,
|
||||
PrimaryResetAfterSeconds: &pReset,
|
||||
SecondaryUsedPercent: &sUsed,
|
||||
SecondaryResetAfterSeconds: &sReset,
|
||||
// No window_minutes
|
||||
}
|
||||
|
||||
normalized := snapshot.Normalize()
|
||||
if normalized == nil {
|
||||
t.Fatal("expected non-nil normalized")
|
||||
}
|
||||
|
||||
// Legacy assumption: primary=7d, secondary=5h
|
||||
if normalized.Used7dPercent == nil || *normalized.Used7dPercent != 100.0 {
|
||||
t.Errorf("expected Used7dPercent=100, got %v", normalized.Used7dPercent)
|
||||
}
|
||||
if normalized.Reset7dSeconds == nil || *normalized.Reset7dSeconds != 400000 {
|
||||
t.Errorf("expected Reset7dSeconds=400000, got %v", normalized.Reset7dSeconds)
|
||||
}
|
||||
if normalized.Used5hPercent == nil || *normalized.Used5hPercent != 50.0 {
|
||||
t.Errorf("expected Used5hPercent=50, got %v", normalized.Used5hPercent)
|
||||
}
|
||||
if normalized.Reset5hSeconds == nil || *normalized.Reset5hSeconds != 10000 {
|
||||
t.Errorf("expected Reset5hSeconds=10000, got %v", normalized.Reset5hSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandle429_AnthropicPlatformUnaffected(t *testing.T) {
|
||||
// Verify that Anthropic platform accounts still use the original logic
|
||||
// This test ensures we don't break existing Claude account rate limiting
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate Anthropic 429 headers
|
||||
headers := http.Header{}
|
||||
headers.Set("anthropic-ratelimit-unified-reset", "1737820800") // A future Unix timestamp
|
||||
|
||||
// For Anthropic platform, calculateOpenAI429ResetTime should return nil
|
||||
// because it only handles OpenAI platform
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there are no x-codex-* headers
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil for Anthropic headers, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_UserProvidedScenario(t *testing.T) {
|
||||
// This is the exact scenario from the user:
|
||||
// codex_7d_used_percent: 100
|
||||
// codex_7d_reset_after_seconds: 384607 (约4.5天后重置)
|
||||
// codex_5h_used_percent: 3
|
||||
// codex_5h_reset_after_seconds: 17369 (约4.8小时后重置)
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
// Simulate headers matching user's data
|
||||
// Note: We need to map the canonical 5h/7d back to primary/secondary
|
||||
// Based on typical OpenAI behavior: primary=7d (larger window), secondary=5h (smaller window)
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "384607")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080") // 7 days = 10080 minutes
|
||||
headers.Set("x-codex-secondary-used-percent", "3")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "17369")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300") // 5 hours = 300 minutes
|
||||
|
||||
before := time.Now()
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
after := time.Now()
|
||||
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected non-nil resetAt for user scenario")
|
||||
}
|
||||
|
||||
// Should use the 7d reset time (384607 seconds) since 7d limit is exhausted (100%)
|
||||
expectedDuration := 384607 * time.Second
|
||||
minExpected := before.Add(expectedDuration)
|
||||
maxExpected := after.Add(expectedDuration)
|
||||
|
||||
if resetAt.Before(minExpected) || resetAt.After(maxExpected) {
|
||||
t.Errorf("resetAt %v not in expected range [%v, %v]", resetAt, minExpected, maxExpected)
|
||||
}
|
||||
|
||||
// Verify it's approximately 4.45 days (384607 seconds)
|
||||
duration := resetAt.Sub(before)
|
||||
actualDays := duration.Hours() / 24.0
|
||||
|
||||
// 384607 / 86400 = ~4.45 days
|
||||
if actualDays < 4.4 || actualDays > 4.5 {
|
||||
t.Errorf("expected ~4.45 days, got %.2f days", actualDays)
|
||||
}
|
||||
|
||||
t.Logf("User scenario: reset_at=%v, duration=%.2f days", resetAt, actualDays)
|
||||
}
|
||||
|
||||
func TestCalculateOpenAI429ResetTime_5MinFallbackWhenNoReset(t *testing.T) {
|
||||
// Test that we return nil when there's used_percent but no reset_after_seconds
|
||||
// This should cause the caller to use the default 5-minute fallback
|
||||
|
||||
svc := &RateLimitService{}
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
// No reset_after_seconds!
|
||||
|
||||
resetAt := svc.calculateOpenAI429ResetTime(headers)
|
||||
|
||||
// Should return nil since there's no reset time available
|
||||
if resetAt != nil {
|
||||
t.Errorf("expected nil when no reset_after_seconds, got %v", resetAt)
|
||||
}
|
||||
}
|
||||
@@ -61,6 +61,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyRegistrationEnabled,
|
||||
SettingKeyEmailVerifyEnabled,
|
||||
SettingKeyPromoCodeEnabled,
|
||||
SettingKeyPasswordResetEnabled,
|
||||
SettingKeyTotpEnabled,
|
||||
SettingKeyTurnstileEnabled,
|
||||
SettingKeyTurnstileSiteKey,
|
||||
SettingKeySiteName,
|
||||
@@ -86,21 +88,27 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
// Password reset requires email verification to be enabled
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
passwordResetEnabled := emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true"
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: passwordResetEnabled,
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "TianShuAPI"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -125,37 +133,41 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
|
||||
// Return a struct that matches the frontend's expected format
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -167,6 +179,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyRegistrationEnabled] = strconv.FormatBool(settings.RegistrationEnabled)
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
updates[SettingKeyPromoCodeEnabled] = strconv.FormatBool(settings.PromoCodeEnabled)
|
||||
updates[SettingKeyPasswordResetEnabled] = strconv.FormatBool(settings.PasswordResetEnabled)
|
||||
updates[SettingKeyTotpEnabled] = strconv.FormatBool(settings.TotpEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
@@ -262,6 +276,35 @@ func (s *SettingService) IsPromoCodeEnabled(ctx context.Context) bool {
|
||||
return value != "false"
|
||||
}
|
||||
|
||||
// IsPasswordResetEnabled 检查是否启用密码重置功能
|
||||
// 要求:必须同时开启邮件验证
|
||||
func (s *SettingService) IsPasswordResetEnabled(ctx context.Context) bool {
|
||||
// Password reset requires email verification to be enabled
|
||||
if !s.IsEmailVerifyEnabled(ctx) {
|
||||
return false
|
||||
}
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyPasswordResetEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsTotpEnabled 检查是否启用 TOTP 双因素认证功能
|
||||
func (s *SettingService) IsTotpEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyTotpEnabled)
|
||||
if err != nil {
|
||||
return false // 默认关闭
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// IsTotpEncryptionKeyConfigured 检查 TOTP 加密密钥是否已手动配置
|
||||
// 只有手动配置了密钥才允许在管理后台启用 TOTP 功能
|
||||
func (s *SettingService) IsTotpEncryptionKeyConfigured() bool {
|
||||
return s.cfg.Totp.EncryptionKeyConfigured
|
||||
}
|
||||
|
||||
// GetSiteName 获取网站名称
|
||||
func (s *SettingService) GetSiteName(ctx context.Context) string {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||
@@ -340,10 +383,13 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
|
||||
// parseSettings 解析设置到结构体
|
||||
func (s *SettingService) parseSettings(settings map[string]string) *SystemSettings {
|
||||
emailVerifyEnabled := settings[SettingKeyEmailVerifyEnabled] == "true"
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
EmailVerifyEnabled: emailVerifyEnabled,
|
||||
PromoCodeEnabled: settings[SettingKeyPromoCodeEnabled] != "false", // 默认启用
|
||||
PasswordResetEnabled: emailVerifyEnabled && settings[SettingKeyPasswordResetEnabled] == "true",
|
||||
TotpEnabled: settings[SettingKeyTotpEnabled] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package service
|
||||
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
@@ -57,21 +59,23 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
PromoCodeEnabled bool
|
||||
PasswordResetEnabled bool
|
||||
TotpEnabled bool // TOTP 双因素认证
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
|
||||
|
||||
71
backend/internal/service/subscription_expiry_service.go
Normal file
71
backend/internal/service/subscription_expiry_service.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionExpiryService periodically updates expired subscription status.
|
||||
type SubscriptionExpiryService struct {
|
||||
userSubRepo UserSubscriptionRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewSubscriptionExpiryService(userSubRepo UserSubscriptionRepository, interval time.Duration) *SubscriptionExpiryService {
|
||||
return &SubscriptionExpiryService{
|
||||
userSubRepo: userSubRepo,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) Start() {
|
||||
if s == nil || s.userSubRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *SubscriptionExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
updated, err := s.userSubRepo.BatchUpdateExpiredStatus(ctx)
|
||||
if err != nil {
|
||||
log.Printf("[SubscriptionExpiry] Update expired subscriptions failed: %v", err)
|
||||
return
|
||||
}
|
||||
if updated > 0 {
|
||||
log.Printf("[SubscriptionExpiry] Updated %d expired subscriptions", updated)
|
||||
}
|
||||
}
|
||||
@@ -324,18 +324,31 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
days = -MaxValidityDays
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
isExpired := !sub.ExpiresAt.After(now)
|
||||
|
||||
// 如果订阅已过期,不允许负向调整
|
||||
if isExpired && days < 0 {
|
||||
return nil, infraerrors.BadRequest("CANNOT_SHORTEN_EXPIRED", "cannot shorten an expired subscription")
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
var newExpiresAt time.Time
|
||||
if isExpired {
|
||||
// 已过期:从当前时间开始增加天数
|
||||
newExpiresAt = now.AddDate(0, 0, days)
|
||||
} else {
|
||||
// 未过期:从原过期时间增加/减少天数
|
||||
newExpiresAt = sub.ExpiresAt.AddDate(0, 0, days)
|
||||
}
|
||||
|
||||
if newExpiresAt.After(MaxExpiresAt) {
|
||||
newExpiresAt = MaxExpiresAt
|
||||
}
|
||||
|
||||
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
|
||||
if days < 0 {
|
||||
now := time.Now()
|
||||
if !newExpiresAt.After(now) {
|
||||
return nil, ErrAdjustWouldExpire
|
||||
}
|
||||
// 检查新的过期时间必须大于当前时间
|
||||
if !newExpiresAt.After(now) {
|
||||
return nil, ErrAdjustWouldExpire
|
||||
}
|
||||
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
@@ -383,6 +396,7 @@ func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID
|
||||
return nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
normalizeSubscriptionStatus(subs)
|
||||
return subs, nil
|
||||
}
|
||||
|
||||
@@ -404,17 +418,19 @@ func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupI
|
||||
return nil, nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
normalizeSubscriptionStatus(subs)
|
||||
return subs, pag, nil
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
// List 获取所有订阅(分页,支持筛选和排序)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
subs, pag, err := s.userSubRepo.List(ctx, params, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
normalizeExpiredWindows(subs)
|
||||
normalizeSubscriptionStatus(subs)
|
||||
return subs, pag, nil
|
||||
}
|
||||
|
||||
@@ -441,6 +457,18 @@ func normalizeExpiredWindows(subs []UserSubscription) {
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeSubscriptionStatus 根据实际过期时间修正状态(仅影响返回数据,不影响数据库)
|
||||
// 这确保前端显示正确的状态,即使定时任务尚未更新数据库
|
||||
func normalizeSubscriptionStatus(subs []UserSubscription) {
|
||||
now := time.Now()
|
||||
for i := range subs {
|
||||
sub := &subs[i]
|
||||
if sub.Status == SubscriptionStatusActive && !sub.ExpiresAt.After(now) {
|
||||
sub.Status = SubscriptionStatusExpired
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// startOfDay 返回给定时间所在日期的零点(保持原时区)
|
||||
func startOfDay(t time.Time) time.Time {
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
@@ -659,11 +687,6 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
|
||||
return progresses, nil
|
||||
}
|
||||
|
||||
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
|
||||
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
|
||||
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *UserSubscription) error {
|
||||
if sub.Status == SubscriptionStatusExpired {
|
||||
|
||||
@@ -237,7 +237,8 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
}
|
||||
|
||||
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
|
||||
// 这些错误通常表示凭证已失效,需要用户重新授权
|
||||
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
|
||||
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
|
||||
func isNonRetryableRefreshError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
|
||||
506
backend/internal/service/totp_service.go
Normal file
506
backend/internal/service/totp_service.go
Normal file
@@ -0,0 +1,506 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/subtle"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"github.com/pquerna/otp/totp"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTotpNotEnabled = infraerrors.BadRequest("TOTP_NOT_ENABLED", "totp feature is not enabled")
|
||||
ErrTotpAlreadyEnabled = infraerrors.BadRequest("TOTP_ALREADY_ENABLED", "totp is already enabled for this account")
|
||||
ErrTotpNotSetup = infraerrors.BadRequest("TOTP_NOT_SETUP", "totp is not set up for this account")
|
||||
ErrTotpInvalidCode = infraerrors.BadRequest("TOTP_INVALID_CODE", "invalid totp code")
|
||||
ErrTotpSetupExpired = infraerrors.BadRequest("TOTP_SETUP_EXPIRED", "totp setup session expired")
|
||||
ErrTotpTooManyAttempts = infraerrors.TooManyRequests("TOTP_TOO_MANY_ATTEMPTS", "too many verification attempts, please try again later")
|
||||
ErrVerifyCodeRequired = infraerrors.BadRequest("VERIFY_CODE_REQUIRED", "email verification code is required")
|
||||
ErrPasswordRequired = infraerrors.BadRequest("PASSWORD_REQUIRED", "password is required")
|
||||
)
|
||||
|
||||
// TotpCache defines cache operations for TOTP service
|
||||
type TotpCache interface {
|
||||
// Setup session methods
|
||||
GetSetupSession(ctx context.Context, userID int64) (*TotpSetupSession, error)
|
||||
SetSetupSession(ctx context.Context, userID int64, session *TotpSetupSession, ttl time.Duration) error
|
||||
DeleteSetupSession(ctx context.Context, userID int64) error
|
||||
|
||||
// Login session methods (for 2FA login flow)
|
||||
GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error)
|
||||
SetLoginSession(ctx context.Context, tempToken string, session *TotpLoginSession, ttl time.Duration) error
|
||||
DeleteLoginSession(ctx context.Context, tempToken string) error
|
||||
|
||||
// Rate limiting
|
||||
IncrementVerifyAttempts(ctx context.Context, userID int64) (int, error)
|
||||
GetVerifyAttempts(ctx context.Context, userID int64) (int, error)
|
||||
ClearVerifyAttempts(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
// SecretEncryptor defines encryption operations for TOTP secrets
|
||||
type SecretEncryptor interface {
|
||||
Encrypt(plaintext string) (string, error)
|
||||
Decrypt(ciphertext string) (string, error)
|
||||
}
|
||||
|
||||
// TotpSetupSession represents a TOTP setup session
|
||||
type TotpSetupSession struct {
|
||||
Secret string // Plain text TOTP secret (not encrypted yet)
|
||||
SetupToken string // Random token to verify setup request
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// TotpLoginSession represents a pending 2FA login session
|
||||
type TotpLoginSession struct {
|
||||
UserID int64
|
||||
Email string
|
||||
TokenExpiry time.Time
|
||||
}
|
||||
|
||||
// TotpStatus represents the TOTP status for a user
|
||||
type TotpStatus struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *time.Time `json:"enabled_at,omitempty"`
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the response for initiating TOTP setup
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"` // seconds until setup expires
|
||||
}
|
||||
|
||||
const (
|
||||
totpSetupTTL = 5 * time.Minute
|
||||
totpLoginTTL = 5 * time.Minute
|
||||
totpAttemptsTTL = 15 * time.Minute
|
||||
maxTotpAttempts = 5
|
||||
totpIssuer = "Sub2API"
|
||||
)
|
||||
|
||||
// TotpService handles TOTP operations
|
||||
type TotpService struct {
|
||||
userRepo UserRepository
|
||||
encryptor SecretEncryptor
|
||||
cache TotpCache
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
emailQueueService *EmailQueueService
|
||||
}
|
||||
|
||||
// NewTotpService creates a new TOTP service
|
||||
func NewTotpService(
|
||||
userRepo UserRepository,
|
||||
encryptor SecretEncryptor,
|
||||
cache TotpCache,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
emailQueueService *EmailQueueService,
|
||||
) *TotpService {
|
||||
return &TotpService{
|
||||
userRepo: userRepo,
|
||||
encryptor: encryptor,
|
||||
cache: cache,
|
||||
settingService: settingService,
|
||||
emailService: emailService,
|
||||
emailQueueService: emailQueueService,
|
||||
}
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for a user
|
||||
func (s *TotpService) GetStatus(ctx context.Context, userID int64) (*TotpStatus, error) {
|
||||
featureEnabled := s.settingService.IsTotpEnabled(ctx)
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
return &TotpStatus{
|
||||
Enabled: user.TotpEnabled,
|
||||
EnabledAt: user.TotpEnabledAt,
|
||||
FeatureEnabled: featureEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// If email verification is enabled, emailCode is required; otherwise password is required
|
||||
func (s *TotpService) InitiateSetup(ctx context.Context, userID int64, emailCode, password string) (*TotpSetupResponse, error) {
|
||||
// Check if TOTP feature is enabled globally
|
||||
if !s.settingService.IsTotpEnabled(ctx) {
|
||||
return nil, ErrTotpNotEnabled
|
||||
}
|
||||
|
||||
// Get user and check if TOTP is already enabled
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
if user.TotpEnabled {
|
||||
return nil, ErrTotpAlreadyEnabled
|
||||
}
|
||||
|
||||
// Verify identity based on email verification setting
|
||||
if s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// Email verification enabled - verify email code
|
||||
if emailCode == "" {
|
||||
return nil, ErrVerifyCodeRequired
|
||||
}
|
||||
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Email verification disabled - verify password
|
||||
if password == "" {
|
||||
return nil, ErrPasswordRequired
|
||||
}
|
||||
if !user.CheckPassword(password) {
|
||||
return nil, ErrPasswordIncorrect
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new TOTP key
|
||||
key, err := totp.Generate(totp.GenerateOpts{
|
||||
Issuer: totpIssuer,
|
||||
AccountName: user.Email,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate totp key: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random setup token
|
||||
setupToken, err := generateRandomToken(32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate setup token: %w", err)
|
||||
}
|
||||
|
||||
// Store the setup session in cache
|
||||
session := &TotpSetupSession{
|
||||
Secret: key.Secret(),
|
||||
SetupToken: setupToken,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := s.cache.SetSetupSession(ctx, userID, session, totpSetupTTL); err != nil {
|
||||
return nil, fmt.Errorf("store setup session: %w", err)
|
||||
}
|
||||
|
||||
return &TotpSetupResponse{
|
||||
Secret: key.Secret(),
|
||||
QRCodeURL: key.URL(),
|
||||
SetupToken: setupToken,
|
||||
Countdown: int(totpSetupTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompleteSetup completes the TOTP setup by verifying the code
|
||||
func (s *TotpService) CompleteSetup(ctx context.Context, userID int64, totpCode, setupToken string) error {
|
||||
// Check if TOTP feature is enabled globally
|
||||
if !s.settingService.IsTotpEnabled(ctx) {
|
||||
return ErrTotpNotEnabled
|
||||
}
|
||||
|
||||
// Get the setup session
|
||||
session, err := s.cache.GetSetupSession(ctx, userID)
|
||||
if err != nil {
|
||||
return ErrTotpSetupExpired
|
||||
}
|
||||
|
||||
if session == nil {
|
||||
return ErrTotpSetupExpired
|
||||
}
|
||||
|
||||
// Verify the setup token (constant-time comparison)
|
||||
if subtle.ConstantTimeCompare([]byte(session.SetupToken), []byte(setupToken)) != 1 {
|
||||
return ErrTotpSetupExpired
|
||||
}
|
||||
|
||||
// Verify the TOTP code
|
||||
if !totp.Validate(totpCode, session.Secret) {
|
||||
return ErrTotpInvalidCode
|
||||
}
|
||||
|
||||
setupSecretPrefix := "N/A"
|
||||
if len(session.Secret) >= 4 {
|
||||
setupSecretPrefix = session.Secret[:4]
|
||||
}
|
||||
slog.Debug("totp_complete_setup_before_encrypt",
|
||||
"user_id", userID,
|
||||
"secret_len", len(session.Secret),
|
||||
"secret_prefix", setupSecretPrefix)
|
||||
|
||||
// Encrypt the secret
|
||||
encryptedSecret, err := s.encryptor.Encrypt(session.Secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt totp secret: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("totp_complete_setup_encrypted",
|
||||
"user_id", userID,
|
||||
"encrypted_len", len(encryptedSecret))
|
||||
|
||||
// Verify encryption by decrypting
|
||||
decrypted, decErr := s.encryptor.Decrypt(encryptedSecret)
|
||||
if decErr != nil {
|
||||
slog.Debug("totp_complete_setup_verify_failed",
|
||||
"user_id", userID,
|
||||
"error", decErr)
|
||||
} else {
|
||||
decryptedPrefix := "N/A"
|
||||
if len(decrypted) >= 4 {
|
||||
decryptedPrefix = decrypted[:4]
|
||||
}
|
||||
slog.Debug("totp_complete_setup_verified",
|
||||
"user_id", userID,
|
||||
"original_len", len(session.Secret),
|
||||
"decrypted_len", len(decrypted),
|
||||
"match", session.Secret == decrypted,
|
||||
"decrypted_prefix", decryptedPrefix)
|
||||
}
|
||||
|
||||
// Update user with encrypted TOTP secret
|
||||
if err := s.userRepo.UpdateTotpSecret(ctx, userID, &encryptedSecret); err != nil {
|
||||
return fmt.Errorf("update totp secret: %w", err)
|
||||
}
|
||||
|
||||
// Enable TOTP for the user
|
||||
if err := s.userRepo.EnableTotp(ctx, userID); err != nil {
|
||||
return fmt.Errorf("enable totp: %w", err)
|
||||
}
|
||||
|
||||
// Clean up the setup session
|
||||
_ = s.cache.DeleteSetupSession(ctx, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disable disables TOTP for a user
|
||||
// If email verification is enabled, emailCode is required; otherwise password is required
|
||||
func (s *TotpService) Disable(ctx context.Context, userID int64, emailCode, password string) error {
|
||||
// Get user
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
if !user.TotpEnabled {
|
||||
return ErrTotpNotSetup
|
||||
}
|
||||
|
||||
// Verify identity based on email verification setting
|
||||
if s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
// Email verification enabled - verify email code
|
||||
if emailCode == "" {
|
||||
return ErrVerifyCodeRequired
|
||||
}
|
||||
if err := s.emailService.VerifyCode(ctx, user.Email, emailCode); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Email verification disabled - verify password
|
||||
if password == "" {
|
||||
return ErrPasswordRequired
|
||||
}
|
||||
if !user.CheckPassword(password) {
|
||||
return ErrPasswordIncorrect
|
||||
}
|
||||
}
|
||||
|
||||
// Disable TOTP
|
||||
if err := s.userRepo.DisableTotp(ctx, userID); err != nil {
|
||||
return fmt.Errorf("disable totp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// VerifyCode verifies a TOTP code for a user
|
||||
func (s *TotpService) VerifyCode(ctx context.Context, userID int64, code string) error {
|
||||
slog.Debug("totp_verify_code_called",
|
||||
"user_id", userID,
|
||||
"code_len", len(code))
|
||||
|
||||
// Check rate limiting
|
||||
attempts, err := s.cache.GetVerifyAttempts(ctx, userID)
|
||||
if err == nil && attempts >= maxTotpAttempts {
|
||||
return ErrTotpTooManyAttempts
|
||||
}
|
||||
|
||||
// Get user
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
slog.Debug("totp_verify_get_user_failed",
|
||||
"user_id", userID,
|
||||
"error", err)
|
||||
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
|
||||
}
|
||||
|
||||
if !user.TotpEnabled || user.TotpSecretEncrypted == nil {
|
||||
slog.Debug("totp_verify_not_setup",
|
||||
"user_id", userID,
|
||||
"enabled", user.TotpEnabled,
|
||||
"has_secret", user.TotpSecretEncrypted != nil)
|
||||
return ErrTotpNotSetup
|
||||
}
|
||||
|
||||
slog.Debug("totp_verify_encrypted_secret",
|
||||
"user_id", userID,
|
||||
"encrypted_len", len(*user.TotpSecretEncrypted))
|
||||
|
||||
// Decrypt the secret
|
||||
secret, err := s.encryptor.Decrypt(*user.TotpSecretEncrypted)
|
||||
if err != nil {
|
||||
slog.Debug("totp_verify_decrypt_failed",
|
||||
"user_id", userID,
|
||||
"error", err)
|
||||
return infraerrors.InternalServer("TOTP_VERIFY_ERROR", "failed to verify totp code")
|
||||
}
|
||||
|
||||
secretPrefix := "N/A"
|
||||
if len(secret) >= 4 {
|
||||
secretPrefix = secret[:4]
|
||||
}
|
||||
slog.Debug("totp_verify_decrypted",
|
||||
"user_id", userID,
|
||||
"secret_len", len(secret),
|
||||
"secret_prefix", secretPrefix)
|
||||
|
||||
// Verify the code
|
||||
valid := totp.Validate(code, secret)
|
||||
slog.Debug("totp_verify_result",
|
||||
"user_id", userID,
|
||||
"valid", valid,
|
||||
"secret_len", len(secret),
|
||||
"secret_prefix", secretPrefix,
|
||||
"server_time", time.Now().UTC().Format(time.RFC3339))
|
||||
|
||||
if !valid {
|
||||
// Increment failed attempts
|
||||
_, _ = s.cache.IncrementVerifyAttempts(ctx, userID)
|
||||
return ErrTotpInvalidCode
|
||||
}
|
||||
|
||||
// Clear attempt counter on success
|
||||
_ = s.cache.ClearVerifyAttempts(ctx, userID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateLoginSession creates a temporary login session for 2FA
|
||||
func (s *TotpService) CreateLoginSession(ctx context.Context, userID int64, email string) (string, error) {
|
||||
// Generate a random temp token
|
||||
tempToken, err := generateRandomToken(32)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate temp token: %w", err)
|
||||
}
|
||||
|
||||
session := &TotpLoginSession{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
TokenExpiry: time.Now().Add(totpLoginTTL),
|
||||
}
|
||||
|
||||
if err := s.cache.SetLoginSession(ctx, tempToken, session, totpLoginTTL); err != nil {
|
||||
return "", fmt.Errorf("store login session: %w", err)
|
||||
}
|
||||
|
||||
return tempToken, nil
|
||||
}
|
||||
|
||||
// GetLoginSession retrieves a login session
|
||||
func (s *TotpService) GetLoginSession(ctx context.Context, tempToken string) (*TotpLoginSession, error) {
|
||||
return s.cache.GetLoginSession(ctx, tempToken)
|
||||
}
|
||||
|
||||
// DeleteLoginSession deletes a login session
|
||||
func (s *TotpService) DeleteLoginSession(ctx context.Context, tempToken string) error {
|
||||
return s.cache.DeleteLoginSession(ctx, tempToken)
|
||||
}
|
||||
|
||||
// IsTotpEnabledForUser checks if TOTP is enabled for a specific user
|
||||
func (s *TotpService) IsTotpEnabledForUser(ctx context.Context, userID int64) (bool, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
return user.TotpEnabled, nil
|
||||
}
|
||||
|
||||
// MaskEmail masks an email address for display
|
||||
func MaskEmail(email string) string {
|
||||
if len(email) < 3 {
|
||||
return "***"
|
||||
}
|
||||
|
||||
atIdx := -1
|
||||
for i, c := range email {
|
||||
if c == '@' {
|
||||
atIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if atIdx == -1 || atIdx < 1 {
|
||||
return email[:1] + "***"
|
||||
}
|
||||
|
||||
localPart := email[:atIdx]
|
||||
domain := email[atIdx:]
|
||||
|
||||
if len(localPart) <= 2 {
|
||||
return localPart[:1] + "***" + domain
|
||||
}
|
||||
|
||||
return localPart[:1] + "***" + localPart[len(localPart)-1:] + domain
|
||||
}
|
||||
|
||||
// generateRandomToken generates a random hex-encoded token
|
||||
func generateRandomToken(byteLength int) (string, error) {
|
||||
b := make([]byte, byteLength)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// VerificationMethod represents the method required for TOTP operations
|
||||
type VerificationMethod struct {
|
||||
Method string `json:"method"` // "email" or "password"
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
func (s *TotpService) GetVerificationMethod(ctx context.Context) *VerificationMethod {
|
||||
if s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return &VerificationMethod{Method: "email"}
|
||||
}
|
||||
return &VerificationMethod{Method: "password"}
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
func (s *TotpService) SendVerifyCode(ctx context.Context, userID int64) error {
|
||||
// Check if email verification is enabled
|
||||
if !s.settingService.IsEmailVerifyEnabled(ctx) {
|
||||
return infraerrors.BadRequest("EMAIL_VERIFY_NOT_ENABLED", "email verification is not enabled")
|
||||
}
|
||||
|
||||
// Get user email
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// Get site name for email
|
||||
siteName := s.settingService.GetSiteName(ctx)
|
||||
|
||||
// Send verification code via queue
|
||||
return s.emailQueueService.EnqueueVerifyCode(user.Email, siteName)
|
||||
}
|
||||
@@ -21,6 +21,11 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// TOTP 双因素认证字段
|
||||
TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥
|
||||
TotpEnabled bool // 是否启用 TOTP
|
||||
TotpEnabledAt *time.Time // TOTP 启用时间
|
||||
|
||||
APIKeys []APIKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
@@ -38,6 +38,11 @@ type UserRepository interface {
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
|
||||
// TOTP 相关方法
|
||||
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
|
||||
EnableTotp(ctx context.Context, userID int64) error
|
||||
DisableTotp(ctx context.Context, userID int64) error
|
||||
}
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
|
||||
@@ -18,7 +18,7 @@ type UserSubscriptionRepository interface {
|
||||
ListByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]UserSubscription, *pagination.PaginationResult, error)
|
||||
|
||||
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
|
||||
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
|
||||
|
||||
@@ -72,6 +72,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideSubscriptionExpiryService creates and starts SubscriptionExpiryService.
|
||||
func ProvideSubscriptionExpiryService(userSubRepo UserSubscriptionRepository) *SubscriptionExpiryService {
|
||||
svc := NewSubscriptionExpiryService(userSubRepo, time.Minute)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideTimingWheelService creates and starts TimingWheelService
|
||||
func ProvideTimingWheelService() (*TimingWheelService, error) {
|
||||
svc, err := NewTimingWheelService()
|
||||
@@ -256,6 +263,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideUpdateService,
|
||||
ProvideTokenRefreshService,
|
||||
ProvideAccountExpiryService,
|
||||
ProvideSubscriptionExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDashboardAggregationService,
|
||||
ProvideUsageCleanupService,
|
||||
@@ -263,4 +271,5 @@ var ProviderSet = wire.NewSet(
|
||||
NewAntigravityQuotaFetcher,
|
||||
NewUserAttributeService,
|
||||
NewUsageCache,
|
||||
NewTotpService,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user