Merge branch 'main' of github.com:Wei-Shaw/sub2api
This commit is contained in:
@@ -547,9 +547,18 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,先更新凭证,再标记账户为 error
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,更新凭证但不标记为 error
|
||||
// LoadCodeAssist 失败可能是临时网络问题,给它机会在下次自动刷新时重试
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
// 先更新凭证
|
||||
// 先更新凭证(token 本身刷新成功了)
|
||||
_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
|
||||
Credentials: newCredentials,
|
||||
})
|
||||
@@ -557,14 +566,10 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
|
||||
response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
|
||||
return
|
||||
}
|
||||
// 标记账户为 error
|
||||
if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: 账户缺少project id,可能无法使用Antigravity"); setErr != nil {
|
||||
response.InternalError(c, "Failed to set account error: "+setErr.Error())
|
||||
return
|
||||
}
|
||||
// 不标记为 error,只返回警告信息
|
||||
response.Success(c, gin.H{
|
||||
"message": "Token refreshed but project_id is missing, account marked as error",
|
||||
"warning": "missing_project_id",
|
||||
"message": "Token refreshed successfully, but project_id could not be retrieved (will retry automatically)",
|
||||
"warning": "missing_project_id_temporary",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -668,6 +673,15 @@ func (h *AccountHandler) ClearError(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 清除错误后,同时清除 token 缓存,确保下次请求会获取最新的 token(触发刷新或从 DB 读取)
|
||||
// 这解决了管理员重置账号状态后,旧的失效 token 仍在缓存中导致立即再次 401 的问题
|
||||
if h.tokenCacheInvalidator != nil && account.IsOAuth() {
|
||||
if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), account); invalidateErr != nil {
|
||||
// 缓存失效失败只记录日志,不影响主流程
|
||||
_ = c.Error(invalidateErr)
|
||||
}
|
||||
}
|
||||
|
||||
response.Success(c, dto.AccountFromService(account))
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,9 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TotpEnabled: settings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
@@ -89,9 +92,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
// UpdateSettingsRequest 更新设置请求
|
||||
type UpdateSettingsRequest struct {
|
||||
// 注册设置
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
|
||||
// 邮件服务设置
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
@@ -198,6 +203,16 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// TOTP 双因素认证参数验证
|
||||
// 只有手动配置了加密密钥才允许启用 TOTP 功能
|
||||
if req.TotpEnabled && !previousSettings.TotpEnabled {
|
||||
// 尝试启用 TOTP,检查加密密钥是否已手动配置
|
||||
if !h.settingService.IsTotpEncryptionKeyConfigured() {
|
||||
response.BadRequest(c, "Cannot enable TOTP: TOTP_ENCRYPTION_KEY environment variable must be configured first. Generate a key with 'openssl rand -hex 32' and set it in your environment.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
@@ -243,6 +258,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: req.PromoCodeEnabled,
|
||||
PasswordResetEnabled: req.PasswordResetEnabled,
|
||||
TotpEnabled: req.TotpEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
@@ -318,6 +335,9 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: updatedSettings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: updatedSettings.PasswordResetEnabled,
|
||||
TotpEnabled: updatedSettings.TotpEnabled,
|
||||
TotpEncryptionKeyConfigured: h.settingService.IsTotpEncryptionKeyConfigured(),
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
@@ -384,6 +404,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.EmailVerifyEnabled != after.EmailVerifyEnabled {
|
||||
changed = append(changed, "email_verify_enabled")
|
||||
}
|
||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||
changed = append(changed, "password_reset_enabled")
|
||||
}
|
||||
if before.TotpEnabled != after.TotpEnabled {
|
||||
changed = append(changed, "totp_enabled")
|
||||
}
|
||||
if before.SMTPHost != after.SMTPHost {
|
||||
changed = append(changed, "smtp_host")
|
||||
}
|
||||
|
||||
@@ -77,7 +77,11 @@ func (h *SubscriptionHandler) List(c *gin.Context) {
|
||||
}
|
||||
status := c.Query("status")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status)
|
||||
// Parse sorting parameters
|
||||
sortBy := c.DefaultQuery("sort_by", "created_at")
|
||||
sortOrder := c.DefaultQuery("sort_order", "desc")
|
||||
|
||||
subscriptions, pagination, err := h.subscriptionService.List(c.Request.Context(), page, pageSize, userID, groupID, status, sortBy, sortOrder)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
@@ -18,16 +20,18 @@ type AuthHandler struct {
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService, totpService *service.TotpService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -144,6 +148,100 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if TOTP 2FA is enabled for this user
|
||||
if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled {
|
||||
// Create a temporary login session for 2FA
|
||||
tempToken, err := h.totpService.CreateLoginSession(c.Request.Context(), user.ID, user.Email)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to create 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpLoginResponse{
|
||||
Requires2FA: true,
|
||||
TempToken: tempToken,
|
||||
UserEmailMasked: service.MaskEmail(user.Email),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
User: dto.UserFromService(user),
|
||||
})
|
||||
}
|
||||
|
||||
// TotpLoginResponse represents the response when 2FA is required
|
||||
type TotpLoginResponse struct {
|
||||
Requires2FA bool `json:"requires_2fa"`
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
UserEmailMasked string `json:"user_email_masked,omitempty"`
|
||||
}
|
||||
|
||||
// Login2FARequest represents the 2FA login request
|
||||
type Login2FARequest struct {
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
}
|
||||
|
||||
// Login2FA completes the login with 2FA verification
|
||||
// POST /api/v1/auth/login/2fa
|
||||
func (h *AuthHandler) Login2FA(c *gin.Context) {
|
||||
var req Login2FARequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_request",
|
||||
"temp_token_len", len(req.TempToken),
|
||||
"totp_code_len", len(req.TotpCode))
|
||||
|
||||
// Get the login session
|
||||
session, err := h.totpService.GetLoginSession(c.Request.Context(), req.TempToken)
|
||||
if err != nil || session == nil {
|
||||
tokenPrefix := ""
|
||||
if len(req.TempToken) >= 8 {
|
||||
tokenPrefix = req.TempToken[:8]
|
||||
}
|
||||
slog.Debug("login_2fa_session_invalid",
|
||||
"temp_token_prefix", tokenPrefix,
|
||||
"error", err)
|
||||
response.BadRequest(c, "Invalid or expired 2FA session")
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("login_2fa_session_found",
|
||||
"user_id", session.UserID,
|
||||
"email", session.Email)
|
||||
|
||||
// Verify the TOTP code
|
||||
if err := h.totpService.VerifyCode(c.Request.Context(), session.UserID, req.TotpCode); err != nil {
|
||||
slog.Debug("login_2fa_verify_failed",
|
||||
"user_id", session.UserID,
|
||||
"error", err)
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Delete the login session
|
||||
_ = h.totpService.DeleteLoginSession(c.Request.Context(), req.TempToken)
|
||||
|
||||
// Get the user
|
||||
user, err := h.userService.GetByID(c.Request.Context(), session.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Generate the JWT token
|
||||
token, err := h.authService.GenerateToken(user)
|
||||
if err != nil {
|
||||
response.InternalError(c, "Failed to generate token")
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, AuthResponse{
|
||||
AccessToken: token,
|
||||
TokenType: "Bearer",
|
||||
@@ -247,3 +345,85 @@ func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
// ForgotPasswordRequest 忘记密码请求
|
||||
type ForgotPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
}
|
||||
|
||||
// ForgotPasswordResponse 忘记密码响应
|
||||
type ForgotPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ForgotPassword 请求密码重置
|
||||
// POST /api/v1/auth/forgot-password
|
||||
func (h *AuthHandler) ForgotPassword(c *gin.Context) {
|
||||
var req ForgotPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Turnstile 验证
|
||||
if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Build frontend base URL from request
|
||||
scheme := "https"
|
||||
if c.Request.TLS == nil {
|
||||
// Check X-Forwarded-Proto header (common in reverse proxy setups)
|
||||
if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" {
|
||||
scheme = proto
|
||||
} else {
|
||||
scheme = "http"
|
||||
}
|
||||
}
|
||||
frontendBaseURL := scheme + "://" + c.Request.Host
|
||||
|
||||
// Request password reset (async)
|
||||
// Note: This returns success even if email doesn't exist (to prevent enumeration)
|
||||
if err := h.authService.RequestPasswordResetAsync(c.Request.Context(), req.Email, frontendBaseURL); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ForgotPasswordResponse{
|
||||
Message: "If your email is registered, you will receive a password reset link shortly.",
|
||||
})
|
||||
}
|
||||
|
||||
// ResetPasswordRequest 重置密码请求
|
||||
type ResetPasswordRequest struct {
|
||||
Email string `json:"email" binding:"required,email"`
|
||||
Token string `json:"token" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required,min=6"`
|
||||
}
|
||||
|
||||
// ResetPasswordResponse 重置密码响应
|
||||
type ResetPasswordResponse struct {
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
// ResetPassword 重置密码
|
||||
// POST /api/v1/auth/reset-password
|
||||
func (h *AuthHandler) ResetPassword(c *gin.Context) {
|
||||
var req ResetPasswordRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// Reset password
|
||||
if err := h.authService.ResetPassword(c.Request.Context(), req.Email, req.Token, req.NewPassword); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ResetPasswordResponse{
|
||||
Message: "Your password has been reset successfully. You can now log in with your new password.",
|
||||
})
|
||||
}
|
||||
|
||||
@@ -2,9 +2,12 @@ package dto
|
||||
|
||||
// SystemSettings represents the admin settings API response payload.
|
||||
type SystemSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TotpEncryptionKeyConfigured bool `json:"totp_encryption_key_configured"` // TOTP 加密密钥是否已配置
|
||||
|
||||
SMTPHost string `json:"smtp_host"`
|
||||
SMTPPort int `json:"smtp_port"`
|
||||
@@ -54,21 +57,23 @@ type SystemSettings struct {
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
PromoCodeEnabled bool `json:"promo_code_enabled"`
|
||||
PasswordResetEnabled bool `json:"password_reset_enabled"`
|
||||
TotpEnabled bool `json:"totp_enabled"` // TOTP 双因素认证
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
SiteSubtitle string `json:"site_subtitle"`
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
// StreamTimeoutSettings 流超时处理配置 DTO
|
||||
|
||||
@@ -209,17 +209,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -344,17 +347,20 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检查预热请求拦截(在账号选择后、转发前检查)
|
||||
if account.IsInterceptWarmupEnabled() && isWarmupRequest(body) {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
// 检查请求拦截(预热请求、SUGGESTION MODE等)
|
||||
if account.IsInterceptWarmupEnabled() {
|
||||
interceptType := detectInterceptType(body)
|
||||
if interceptType != InterceptTypeNone {
|
||||
if selection.Acquired && selection.ReleaseFunc != nil {
|
||||
selection.ReleaseFunc()
|
||||
}
|
||||
if reqStream {
|
||||
sendMockInterceptStream(c, reqModel, interceptType)
|
||||
} else {
|
||||
sendMockInterceptResponse(c, reqModel, interceptType)
|
||||
}
|
||||
return
|
||||
}
|
||||
if reqStream {
|
||||
sendMockWarmupStream(c, reqModel)
|
||||
} else {
|
||||
sendMockWarmupResponse(c, reqModel)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 3. 获取账号并发槽位
|
||||
@@ -768,17 +774,30 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// isWarmupRequest 检测是否为预热请求(标题生成、Warmup等)
|
||||
func isWarmupRequest(body []byte) bool {
|
||||
// 快速检查:如果body不包含关键字,直接返回false
|
||||
// InterceptType 表示请求拦截类型
|
||||
type InterceptType int
|
||||
|
||||
const (
|
||||
InterceptTypeNone InterceptType = iota
|
||||
InterceptTypeWarmup // 预热请求(返回 "New Conversation")
|
||||
InterceptTypeSuggestionMode // SUGGESTION MODE(返回空字符串)
|
||||
)
|
||||
|
||||
// detectInterceptType 检测请求是否需要拦截,返回拦截类型
|
||||
func detectInterceptType(body []byte) InterceptType {
|
||||
// 快速检查:如果不包含任何关键字,直接返回
|
||||
bodyStr := string(body)
|
||||
if !strings.Contains(bodyStr, "title") && !strings.Contains(bodyStr, "Warmup") {
|
||||
return false
|
||||
hasSuggestionMode := strings.Contains(bodyStr, "[SUGGESTION MODE:")
|
||||
hasWarmupKeyword := strings.Contains(bodyStr, "title") || strings.Contains(bodyStr, "Warmup")
|
||||
|
||||
if !hasSuggestionMode && !hasWarmupKeyword {
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 解析完整请求
|
||||
// 解析请求(只解析一次)
|
||||
var req struct {
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content []struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
@@ -789,43 +808,71 @@ func isWarmupRequest(body []byte) bool {
|
||||
} `json:"system"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return true
|
||||
// 检查 SUGGESTION MODE(最后一条 user 消息)
|
||||
if hasSuggestionMode && len(req.Messages) > 0 {
|
||||
lastMsg := req.Messages[len(req.Messages)-1]
|
||||
if lastMsg.Role == "user" && len(lastMsg.Content) > 0 &&
|
||||
lastMsg.Content[0].Type == "text" &&
|
||||
strings.HasPrefix(lastMsg.Content[0].Text, "[SUGGESTION MODE:") {
|
||||
return InterceptTypeSuggestionMode
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 Warmup 请求
|
||||
if hasWarmupKeyword {
|
||||
// 检查 messages 中的标题提示模式
|
||||
for _, msg := range req.Messages {
|
||||
for _, content := range msg.Content {
|
||||
if content.Type == "text" {
|
||||
if strings.Contains(content.Text, "Please write a 5-10 word title for the following conversation:") ||
|
||||
content.Text == "Warmup" {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, sys := range req.System {
|
||||
if strings.Contains(sys.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return InterceptTypeWarmup
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查 system 中的标题提取模式
|
||||
for _, system := range req.System {
|
||||
if strings.Contains(system.Text, "nalyze if this message indicates a new conversation topic. If it does, extract a 2-3 word title") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return InterceptTypeNone
|
||||
}
|
||||
|
||||
// sendMockWarmupStream 发送流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
// sendMockInterceptStream 发送流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptStream(c *gin.Context, model string, interceptType InterceptType) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
// 根据拦截类型决定响应内容
|
||||
var msgID string
|
||||
var outputTokens int
|
||||
var textDeltas []string
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
outputTokens = 1
|
||||
textDeltas = []string{""} // 空内容
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
outputTokens = 2
|
||||
textDeltas = []string{"New", " Conversation"}
|
||||
}
|
||||
|
||||
// Build message_start event with proper JSON marshaling
|
||||
messageStart := map[string]any{
|
||||
"type": "message_start",
|
||||
"message": map[string]any{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
@@ -840,16 +887,46 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
|
||||
// Build events
|
||||
events := []string{
|
||||
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||
`event: content_block_stop` + "\n" + `data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta` + "\n" + `data: {"delta":{"stop_reason":"end_turn","stop_sequence":null},"type":"message_delta","usage":{"input_tokens":10,"output_tokens":2}}`,
|
||||
`event: message_stop` + "\n" + `data: {"type":"message_stop"}`,
|
||||
}
|
||||
|
||||
// Add text deltas
|
||||
for _, text := range textDeltas {
|
||||
delta := map[string]any{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]string{
|
||||
"type": "text_delta",
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
deltaJSON, _ := json.Marshal(delta)
|
||||
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
|
||||
}
|
||||
|
||||
// Add final events
|
||||
messageDelta := map[string]any{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]any{
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]int{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
}
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
|
||||
events = append(events,
|
||||
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
|
||||
`event: message_delta`+"\n"+`data: `+string(messageDeltaJSON),
|
||||
`event: message_stop`+"\n"+`data: {"type":"message_stop"}`,
|
||||
)
|
||||
|
||||
for _, event := range events {
|
||||
_, _ = c.Writer.WriteString(event + "\n\n")
|
||||
c.Writer.Flush()
|
||||
@@ -857,18 +934,32 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
||||
}
|
||||
}
|
||||
|
||||
// sendMockWarmupResponse 发送非流式 mock 响应(用于预热请求拦截)
|
||||
func sendMockWarmupResponse(c *gin.Context, model string) {
|
||||
// sendMockInterceptResponse 发送非流式 mock 响应(用于请求拦截)
|
||||
func sendMockInterceptResponse(c *gin.Context, model string, interceptType InterceptType) {
|
||||
var msgID, text string
|
||||
var outputTokens int
|
||||
|
||||
switch interceptType {
|
||||
case InterceptTypeSuggestionMode:
|
||||
msgID = "msg_mock_suggestion"
|
||||
text = ""
|
||||
outputTokens = 1
|
||||
default: // InterceptTypeWarmup
|
||||
msgID = "msg_mock_warmup"
|
||||
text = "New Conversation"
|
||||
outputTokens = 2
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"id": "msg_mock_warmup",
|
||||
"id": msgID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": model,
|
||||
"content": []gin.H{{"type": "text", "text": "New Conversation"}},
|
||||
"content": []gin.H{{"type": "text", "text": text}},
|
||||
"stop_reason": "end_turn",
|
||||
"usage": gin.H{
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 2,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
122
backend/internal/handler/gemini_cli_session_test.go
Normal file
@@ -0,0 +1,122 @@
|
||||
//go:build unit
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractGeminiCLISessionHash(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
privilegedUserID string
|
||||
wantEmpty bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "with privileged-user-id and tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: false,
|
||||
wantHash: func() string {
|
||||
combined := "90785f52-8bbe-4b17-b111-a1ddea1636c3:f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "without privileged-user-id but with tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740"}]}]}`,
|
||||
privilegedUserID: "",
|
||||
wantEmpty: false,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "without tmp dir",
|
||||
body: `{"contents":[{"parts":[{"text":"Hello world"}]}]}`,
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
privilegedUserID: "90785f52-8bbe-4b17-b111-a1ddea1636c3",
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 创建测试上下文
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest("POST", "/test", nil)
|
||||
if tt.privilegedUserID != "" {
|
||||
c.Request.Header.Set("x-gemini-api-privileged-user-id", tt.privilegedUserID)
|
||||
}
|
||||
|
||||
// 调用函数
|
||||
result := extractGeminiCLISessionHash(c, []byte(tt.body))
|
||||
|
||||
// 验证结果
|
||||
if tt.wantEmpty {
|
||||
require.Empty(t, result, "expected empty session hash")
|
||||
} else {
|
||||
require.NotEmpty(t, result, "expected non-empty session hash")
|
||||
require.Equal(t, tt.wantHash, result, "session hash mismatch")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiCLITmpDirRegex(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantMatch bool
|
||||
wantHash string
|
||||
}{
|
||||
{
|
||||
name: "valid tmp dir path",
|
||||
input: "/Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "valid tmp dir path in text",
|
||||
input: "The project's temporary directory is: /Users/ianshaw/.gemini/tmp/f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740\nOther text",
|
||||
wantMatch: true,
|
||||
wantHash: "f7851b009ed314d1baee62e83115f486160283f4a55a582d89fdac8b9fe3b740",
|
||||
},
|
||||
{
|
||||
name: "invalid hash length",
|
||||
input: "/Users/ianshaw/.gemini/tmp/abc123",
|
||||
wantMatch: false,
|
||||
},
|
||||
{
|
||||
name: "no tmp dir",
|
||||
input: "Hello world",
|
||||
wantMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
match := geminiCLITmpDirRegex.FindStringSubmatch(tt.input)
|
||||
if tt.wantMatch {
|
||||
require.NotNil(t, match, "expected regex to match")
|
||||
require.Len(t, match, 2, "expected 2 capture groups")
|
||||
require.Equal(t, tt.wantHash, match[1], "hash mismatch")
|
||||
} else {
|
||||
require.Nil(t, match, "expected regex not to match")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,15 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -19,6 +23,17 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值
|
||||
// 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希]
|
||||
var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`)
|
||||
|
||||
func isGeminiCLIRequest(c *gin.Context, body []byte) bool {
|
||||
if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" {
|
||||
return true
|
||||
}
|
||||
return geminiCLITmpDirRegex.Match(body)
|
||||
}
|
||||
|
||||
// GeminiV1BetaListModels proxies:
|
||||
// GET /v1beta/models
|
||||
func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) {
|
||||
@@ -214,12 +229,26 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 3) select account (sticky session based on request body)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
// 优先使用 Gemini CLI 的会话标识(privileged-user-id + tmp 目录哈希)
|
||||
sessionHash := extractGeminiCLISessionHash(c, body)
|
||||
if sessionHash == "" {
|
||||
// Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端)
|
||||
parsedReq, _ := service.ParseGatewayRequest(body)
|
||||
sessionHash = h.gatewayService.GenerateSessionHash(parsedReq)
|
||||
}
|
||||
sessionKey := sessionHash
|
||||
if sessionHash != "" {
|
||||
sessionKey = "gemini:" + sessionHash
|
||||
}
|
||||
|
||||
// 查询粘性会话绑定的账号 ID(用于检测账号切换)
|
||||
var sessionBoundAccountID int64
|
||||
if sessionKey != "" {
|
||||
sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey)
|
||||
}
|
||||
isCLI := isGeminiCLIRequest(c, body)
|
||||
cleanedForUnknownBinding := false
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitchesGemini
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
@@ -238,6 +267,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
||||
account := selection.Account
|
||||
setOpsSelectedAccount(c, account.ID)
|
||||
|
||||
// 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature
|
||||
// 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。
|
||||
if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID {
|
||||
log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID)
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) {
|
||||
// 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。
|
||||
// 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。
|
||||
log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively")
|
||||
body = service.CleanGeminiNativeThoughtSignatures(body)
|
||||
cleanedForUnknownBinding = true
|
||||
sessionBoundAccountID = account.ID
|
||||
} else if sessionBoundAccountID == 0 {
|
||||
// 记录本次请求中首次选择到的账号,便于同一请求内 failover 时检测切换。
|
||||
sessionBoundAccountID = account.ID
|
||||
}
|
||||
|
||||
// 4) account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
@@ -433,3 +480,38 @@ func shouldFallbackGeminiModels(res *service.UpstreamHTTPResult) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// extractGeminiCLISessionHash 从 Gemini CLI 请求中提取会话标识。
|
||||
// 组合 x-gemini-api-privileged-user-id header 和请求体中的 tmp 目录哈希。
|
||||
//
|
||||
// 会话标识生成策略:
|
||||
// 1. 从请求体中提取 tmp 目录哈希(64位十六进制)
|
||||
// 2. 从 header 中提取 privileged-user-id(UUID)
|
||||
// 3. 组合两者生成 SHA256 哈希作为最终的会话标识
|
||||
//
|
||||
// 如果找不到 tmp 目录哈希,返回空字符串(不使用粘性会话)。
|
||||
//
|
||||
// extractGeminiCLISessionHash extracts session identifier from Gemini CLI requests.
|
||||
// Combines x-gemini-api-privileged-user-id header with tmp directory hash from request body.
|
||||
func extractGeminiCLISessionHash(c *gin.Context, body []byte) string {
|
||||
// 1. 从请求体中提取 tmp 目录哈希
|
||||
match := geminiCLITmpDirRegex.FindSubmatch(body)
|
||||
if len(match) < 2 {
|
||||
return "" // 没有找到 tmp 目录,不使用粘性会话
|
||||
}
|
||||
tmpDirHash := string(match[1])
|
||||
|
||||
// 2. 提取 privileged-user-id
|
||||
privilegedUserID := strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id"))
|
||||
|
||||
// 3. 组合生成最终的 session hash
|
||||
if privilegedUserID != "" {
|
||||
// 组合两个标识符:privileged-user-id + tmp 目录哈希
|
||||
combined := privilegedUserID + ":" + tmpDirHash
|
||||
hash := sha256.Sum256([]byte(combined))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// 如果没有 privileged-user-id,直接使用 tmp 目录哈希
|
||||
return tmpDirHash
|
||||
}
|
||||
|
||||
@@ -37,6 +37,7 @@ type Handlers struct {
|
||||
Gateway *GatewayHandler
|
||||
OpenAIGateway *OpenAIGatewayHandler
|
||||
Setting *SettingHandler
|
||||
Totp *TotpHandler
|
||||
}
|
||||
|
||||
// BuildInfo contains build-time information
|
||||
|
||||
@@ -32,20 +32,21 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.PublicSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
PromoCodeEnabled: settings.PromoCodeEnabled,
|
||||
PasswordResetEnabled: settings.PasswordResetEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
}
|
||||
|
||||
181
backend/internal/handler/totp_handler.go
Normal file
181
backend/internal/handler/totp_handler.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// TotpHandler handles TOTP-related requests
|
||||
type TotpHandler struct {
|
||||
totpService *service.TotpService
|
||||
}
|
||||
|
||||
// NewTotpHandler creates a new TotpHandler
|
||||
func NewTotpHandler(totpService *service.TotpService) *TotpHandler {
|
||||
return &TotpHandler{
|
||||
totpService: totpService,
|
||||
}
|
||||
}
|
||||
|
||||
// TotpStatusResponse represents the TOTP status response
|
||||
type TotpStatusResponse struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
EnabledAt *int64 `json:"enabled_at,omitempty"` // Unix timestamp
|
||||
FeatureEnabled bool `json:"feature_enabled"`
|
||||
}
|
||||
|
||||
// GetStatus returns the TOTP status for the current user
|
||||
// GET /api/v1/user/totp/status
|
||||
func (h *TotpHandler) GetStatus(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
status, err := h.totpService.GetStatus(c.Request.Context(), subject.UserID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := TotpStatusResponse{
|
||||
Enabled: status.Enabled,
|
||||
FeatureEnabled: status.FeatureEnabled,
|
||||
}
|
||||
|
||||
if status.EnabledAt != nil {
|
||||
ts := status.EnabledAt.Unix()
|
||||
resp.EnabledAt = &ts
|
||||
}
|
||||
|
||||
response.Success(c, resp)
|
||||
}
|
||||
|
||||
// TotpSetupRequest represents the request to initiate TOTP setup
|
||||
type TotpSetupRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// TotpSetupResponse represents the TOTP setup response
|
||||
type TotpSetupResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeURL string `json:"qr_code_url"`
|
||||
SetupToken string `json:"setup_token"`
|
||||
Countdown int `json:"countdown"`
|
||||
}
|
||||
|
||||
// InitiateSetup starts the TOTP setup process
|
||||
// POST /api/v1/user/totp/setup
|
||||
func (h *TotpHandler) InitiateSetup(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpSetupRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
// Allow empty body (optional params)
|
||||
req = TotpSetupRequest{}
|
||||
}
|
||||
|
||||
result, err := h.totpService.InitiateSetup(c.Request.Context(), subject.UserID, req.EmailCode, req.Password)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, TotpSetupResponse{
|
||||
Secret: result.Secret,
|
||||
QRCodeURL: result.QRCodeURL,
|
||||
SetupToken: result.SetupToken,
|
||||
Countdown: result.Countdown,
|
||||
})
|
||||
}
|
||||
|
||||
// TotpEnableRequest represents the request to enable TOTP
|
||||
type TotpEnableRequest struct {
|
||||
TotpCode string `json:"totp_code" binding:"required,len=6"`
|
||||
SetupToken string `json:"setup_token" binding:"required"`
|
||||
}
|
||||
|
||||
// Enable completes the TOTP setup
|
||||
// POST /api/v1/user/totp/enable
|
||||
func (h *TotpHandler) Enable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpEnableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.CompleteSetup(c.Request.Context(), subject.UserID, req.TotpCode, req.SetupToken); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// TotpDisableRequest represents the request to disable TOTP
|
||||
type TotpDisableRequest struct {
|
||||
EmailCode string `json:"email_code"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// Disable disables TOTP for the current user
|
||||
// POST /api/v1/user/totp/disable
|
||||
func (h *TotpHandler) Disable(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
var req TotpDisableRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.Disable(c.Request.Context(), subject.UserID, req.EmailCode, req.Password); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
|
||||
// GetVerificationMethod returns the verification method for TOTP operations
|
||||
// GET /api/v1/user/totp/verification-method
|
||||
func (h *TotpHandler) GetVerificationMethod(c *gin.Context) {
|
||||
method := h.totpService.GetVerificationMethod(c.Request.Context())
|
||||
response.Success(c, method)
|
||||
}
|
||||
|
||||
// SendVerifyCode sends an email verification code for TOTP operations
|
||||
// POST /api/v1/user/totp/send-code
|
||||
func (h *TotpHandler) SendVerifyCode(c *gin.Context) {
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
response.Unauthorized(c, "User not authenticated")
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.totpService.SendVerifyCode(c.Request.Context(), subject.UserID); err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"success": true})
|
||||
}
|
||||
@@ -70,6 +70,7 @@ func ProvideHandlers(
|
||||
gatewayHandler *GatewayHandler,
|
||||
openaiGatewayHandler *OpenAIGatewayHandler,
|
||||
settingHandler *SettingHandler,
|
||||
totpHandler *TotpHandler,
|
||||
) *Handlers {
|
||||
return &Handlers{
|
||||
Auth: authHandler,
|
||||
@@ -82,6 +83,7 @@ func ProvideHandlers(
|
||||
Gateway: gatewayHandler,
|
||||
OpenAIGateway: openaiGatewayHandler,
|
||||
Setting: settingHandler,
|
||||
Totp: totpHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -96,6 +98,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewSubscriptionHandler,
|
||||
NewGatewayHandler,
|
||||
NewOpenAIGatewayHandler,
|
||||
NewTotpHandler,
|
||||
ProvideSettingHandler,
|
||||
|
||||
// Admin handlers
|
||||
|
||||
Reference in New Issue
Block a user