merge: 正确合并 main 分支改动
合并 origin/main 最新改动,正确保留所有配置: - Ops 运维监控配置和功能 - LinuxDo Connect OAuth 配置 - Update 在线更新配置 - 优惠码功能 - 其他 main 分支新功能 修复之前合并时错误删除 LinuxDo 和 Update 配置的问题。
This commit is contained in:
@@ -294,6 +294,13 @@ type DatabaseConfig struct {
|
||||
}
|
||||
|
||||
func (d *DatabaseConfig) DSN() string {
|
||||
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
|
||||
if d.Password == "" {
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s dbname=%s sslmode=%s",
|
||||
d.Host, d.Port, d.User, d.DBName, d.SSLMode,
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
|
||||
@@ -305,6 +312,13 @@ func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
|
||||
if tz == "" {
|
||||
tz = "Asia/Shanghai"
|
||||
}
|
||||
// 当密码为空时不包含 password 参数,避免 libpq 解析错误
|
||||
if d.Password == "" {
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz,
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
|
||||
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
|
||||
@@ -538,6 +552,22 @@ func setDefaults() {
|
||||
// Turnstile
|
||||
viper.SetDefault("turnstile.required", false)
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
viper.SetDefault("linuxdo_connect.enabled", false)
|
||||
viper.SetDefault("linuxdo_connect.client_id", "")
|
||||
viper.SetDefault("linuxdo_connect.client_secret", "")
|
||||
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
|
||||
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
|
||||
viper.SetDefault("linuxdo_connect.scopes", "user")
|
||||
viper.SetDefault("linuxdo_connect.redirect_url", "")
|
||||
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
|
||||
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
|
||||
viper.SetDefault("linuxdo_connect.use_pkce", false)
|
||||
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_id_path", "")
|
||||
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
|
||||
|
||||
// Database
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
viper.SetDefault("database.port", 5432)
|
||||
@@ -659,6 +689,61 @@ func (c *Config) Validate() error {
|
||||
if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" {
|
||||
return fmt.Errorf("security.csp.policy is required when CSP is enabled")
|
||||
}
|
||||
if c.LinuxDo.Enabled {
|
||||
if strings.TrimSpace(c.LinuxDo.ClientID) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.TokenURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod))
|
||||
switch method {
|
||||
case "", "client_secret_post", "client_secret_basic", "none":
|
||||
default:
|
||||
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
|
||||
}
|
||||
if method == "none" && !c.LinuxDo.UsePKCE {
|
||||
return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none")
|
||||
}
|
||||
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
|
||||
strings.TrimSpace(c.LinuxDo.ClientSecret) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic")
|
||||
}
|
||||
if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true")
|
||||
}
|
||||
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err)
|
||||
}
|
||||
if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil {
|
||||
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
|
||||
}
|
||||
|
||||
warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL)
|
||||
warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL)
|
||||
warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL)
|
||||
warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL)
|
||||
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL)
|
||||
}
|
||||
if c.Billing.CircuitBreaker.Enabled {
|
||||
if c.Billing.CircuitBreaker.FailureThreshold <= 0 {
|
||||
return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive")
|
||||
@@ -928,3 +1013,13 @@ func ValidateFrontendRedirectURL(raw string) error {
|
||||
func isHTTPScheme(scheme string) bool {
|
||||
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
|
||||
}
|
||||
|
||||
func warnIfInsecureURL(field, raw string) {
|
||||
u, err := url.Parse(strings.TrimSpace(raw))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(u.Scheme, "http") {
|
||||
log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field)
|
||||
}
|
||||
}
|
||||
|
||||
209
backend/internal/handler/admin/promo_handler.go
Normal file
209
backend/internal/handler/admin/promo_handler.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PromoHandler handles admin promo code management
|
||||
type PromoHandler struct {
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewPromoHandler creates a new admin promo handler
|
||||
func NewPromoHandler(promoService *service.PromoService) *PromoHandler {
|
||||
return &PromoHandler{
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePromoCodeRequest represents create promo code request
|
||||
type CreatePromoCodeRequest struct {
|
||||
Code string `json:"code"` // 可选,为空则自动生成
|
||||
BonusAmount float64 `json:"bonus_amount" binding:"required,min=0"` // 赠送余额
|
||||
MaxUses int `json:"max_uses" binding:"min=0"` // 最大使用次数,0=无限
|
||||
ExpiresAt *int64 `json:"expires_at"` // 过期时间戳(秒)
|
||||
Notes string `json:"notes"` // 备注
|
||||
}
|
||||
|
||||
// UpdatePromoCodeRequest represents update promo code request
|
||||
type UpdatePromoCodeRequest struct {
|
||||
Code *string `json:"code"`
|
||||
BonusAmount *float64 `json:"bonus_amount" binding:"omitempty,min=0"`
|
||||
MaxUses *int `json:"max_uses" binding:"omitempty,min=0"`
|
||||
Status *string `json:"status" binding:"omitempty,oneof=active disabled"`
|
||||
ExpiresAt *int64 `json:"expires_at"`
|
||||
Notes *string `json:"notes"`
|
||||
}
|
||||
|
||||
// List handles listing all promo codes with pagination
|
||||
// GET /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) List(c *gin.Context) {
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
status := c.Query("status")
|
||||
search := strings.TrimSpace(c.Query("search"))
|
||||
if len(search) > 100 {
|
||||
search = search[:100]
|
||||
}
|
||||
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
codes, paginationResult, err := h.promoService.List(c.Request.Context(), params, status, search)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCode, 0, len(codes))
|
||||
for i := range codes {
|
||||
out = append(out, *dto.PromoCodeFromService(&codes[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
|
||||
// GetByID handles getting a promo code by ID
|
||||
// GET /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) GetByID(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
code, err := h.promoService.GetByID(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Create handles creating a new promo code
|
||||
// POST /api/v1/admin/promo-codes
|
||||
func (h *PromoHandler) Create(c *gin.Context) {
|
||||
var req CreatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.CreatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
|
||||
code, err := h.promoService.Create(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Update handles updating a promo code
|
||||
// PUT /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Update(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
input := &service.UpdatePromoCodeInput{
|
||||
Code: req.Code,
|
||||
BonusAmount: req.BonusAmount,
|
||||
MaxUses: req.MaxUses,
|
||||
Status: req.Status,
|
||||
Notes: req.Notes,
|
||||
}
|
||||
|
||||
if req.ExpiresAt != nil {
|
||||
if *req.ExpiresAt == 0 {
|
||||
// 0 表示清除过期时间
|
||||
input.ExpiresAt = nil
|
||||
} else {
|
||||
t := time.Unix(*req.ExpiresAt, 0)
|
||||
input.ExpiresAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
code, err := h.promoService.Update(c.Request.Context(), codeID, input)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, dto.PromoCodeFromService(code))
|
||||
}
|
||||
|
||||
// Delete handles deleting a promo code
|
||||
// DELETE /api/v1/admin/promo-codes/:id
|
||||
func (h *PromoHandler) Delete(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = h.promoService.Delete(c.Request.Context(), codeID)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, gin.H{"message": "Promo code deleted successfully"})
|
||||
}
|
||||
|
||||
// GetUsages handles getting usage records for a promo code
|
||||
// GET /api/v1/admin/promo-codes/:id/usages
|
||||
func (h *PromoHandler) GetUsages(c *gin.Context) {
|
||||
codeID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
response.BadRequest(c, "Invalid promo code ID")
|
||||
return
|
||||
}
|
||||
|
||||
page, pageSize := response.ParsePagination(c)
|
||||
params := pagination.PaginationParams{
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}
|
||||
|
||||
usages, paginationResult, err := h.promoService.ListUsages(c.Request.Context(), codeID, params)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
out := make([]dto.PromoCodeUsage, 0, len(usages))
|
||||
for i := range usages {
|
||||
out = append(out, *dto.PromoCodeUsageFromService(&usages[i]))
|
||||
}
|
||||
response.Paginated(c, out, paginationResult.Total, page, pageSize)
|
||||
}
|
||||
@@ -2,8 +2,10 @@ package admin
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -38,37 +40,42 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: settings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
SMTPHost: settings.SMTPHost,
|
||||
SMTPPort: settings.SMTPPort,
|
||||
SMTPUsername: settings.SMTPUsername,
|
||||
SMTPPasswordConfigured: settings.SMTPPasswordConfigured,
|
||||
SMTPFrom: settings.SMTPFrom,
|
||||
SMTPFromName: settings.SMTPFromName,
|
||||
SMTPUseTLS: settings.SMTPUseTLS,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: settings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
DefaultConcurrency: settings.DefaultConcurrency,
|
||||
DefaultBalance: settings.DefaultBalance,
|
||||
EnableModelFallback: settings.EnableModelFallback,
|
||||
FallbackModelAnthropic: settings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: settings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: settings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: settings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: settings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: settings.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: settings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: settings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -92,6 +99,12 @@ type UpdateSettingsRequest struct {
|
||||
TurnstileSiteKey string `json:"turnstile_site_key"`
|
||||
TurnstileSecretKey string `json:"turnstile_secret_key"`
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"`
|
||||
LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"`
|
||||
LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"`
|
||||
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
|
||||
|
||||
// OEM设置
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo"`
|
||||
@@ -99,6 +112,7 @@ type UpdateSettingsRequest struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
|
||||
// 默认配置
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
@@ -175,6 +189,35 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
// LinuxDo Connect 参数验证
|
||||
if req.LinuxDoConnectEnabled {
|
||||
req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID)
|
||||
req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret)
|
||||
req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL)
|
||||
|
||||
if req.LinuxDoConnectClientID == "" {
|
||||
response.BadRequest(c, "LinuxDo Client ID is required when enabled")
|
||||
return
|
||||
}
|
||||
if req.LinuxDoConnectRedirectURL == "" {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL is required when enabled")
|
||||
return
|
||||
}
|
||||
if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil {
|
||||
response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL")
|
||||
return
|
||||
}
|
||||
|
||||
// 如果未提供 client_secret,则保留现有值(如有)。
|
||||
if req.LinuxDoConnectClientSecret == "" {
|
||||
if previousSettings.LinuxDoConnectClientSecret == "" {
|
||||
response.BadRequest(c, "LinuxDo Client Secret is required when enabled")
|
||||
return
|
||||
}
|
||||
req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret
|
||||
}
|
||||
}
|
||||
|
||||
// Ops metrics collector interval validation (seconds).
|
||||
if req.OpsMetricsIntervalSeconds != nil {
|
||||
v := *req.OpsMetricsIntervalSeconds
|
||||
@@ -188,33 +231,38 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
settings := &service.SystemSettings{
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
RegistrationEnabled: req.RegistrationEnabled,
|
||||
EmailVerifyEnabled: req.EmailVerifyEnabled,
|
||||
SMTPHost: req.SMTPHost,
|
||||
SMTPPort: req.SMTPPort,
|
||||
SMTPUsername: req.SMTPUsername,
|
||||
SMTPPassword: req.SMTPPassword,
|
||||
SMTPFrom: req.SMTPFrom,
|
||||
SMTPFromName: req.SMTPFromName,
|
||||
SMTPUseTLS: req.SMTPUseTLS,
|
||||
TurnstileEnabled: req.TurnstileEnabled,
|
||||
TurnstileSiteKey: req.TurnstileSiteKey,
|
||||
TurnstileSecretKey: req.TurnstileSecretKey,
|
||||
LinuxDoConnectEnabled: req.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: req.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret,
|
||||
LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL,
|
||||
SiteName: req.SiteName,
|
||||
SiteLogo: req.SiteLogo,
|
||||
SiteSubtitle: req.SiteSubtitle,
|
||||
APIBaseURL: req.APIBaseURL,
|
||||
ContactInfo: req.ContactInfo,
|
||||
DocURL: req.DocURL,
|
||||
HomeContent: req.HomeContent,
|
||||
DefaultConcurrency: req.DefaultConcurrency,
|
||||
DefaultBalance: req.DefaultBalance,
|
||||
EnableModelFallback: req.EnableModelFallback,
|
||||
FallbackModelAnthropic: req.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: req.FallbackModelOpenAI,
|
||||
FallbackModelGemini: req.FallbackModelGemini,
|
||||
FallbackModelAntigravity: req.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: req.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: req.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: func() bool {
|
||||
if req.OpsMonitoringEnabled != nil {
|
||||
return *req.OpsMonitoringEnabled
|
||||
@@ -256,37 +304,42 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
|
||||
}
|
||||
|
||||
response.Success(c, dto.SystemSettings{
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
RegistrationEnabled: updatedSettings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled,
|
||||
SMTPHost: updatedSettings.SMTPHost,
|
||||
SMTPPort: updatedSettings.SMTPPort,
|
||||
SMTPUsername: updatedSettings.SMTPUsername,
|
||||
SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured,
|
||||
SMTPFrom: updatedSettings.SMTPFrom,
|
||||
SMTPFromName: updatedSettings.SMTPFromName,
|
||||
SMTPUseTLS: updatedSettings.SMTPUseTLS,
|
||||
TurnstileEnabled: updatedSettings.TurnstileEnabled,
|
||||
TurnstileSiteKey: updatedSettings.TurnstileSiteKey,
|
||||
TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured,
|
||||
LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled,
|
||||
LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID,
|
||||
LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured,
|
||||
LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL,
|
||||
SiteName: updatedSettings.SiteName,
|
||||
SiteLogo: updatedSettings.SiteLogo,
|
||||
SiteSubtitle: updatedSettings.SiteSubtitle,
|
||||
APIBaseURL: updatedSettings.APIBaseURL,
|
||||
ContactInfo: updatedSettings.ContactInfo,
|
||||
DocURL: updatedSettings.DocURL,
|
||||
HomeContent: updatedSettings.HomeContent,
|
||||
DefaultConcurrency: updatedSettings.DefaultConcurrency,
|
||||
DefaultBalance: updatedSettings.DefaultBalance,
|
||||
EnableModelFallback: updatedSettings.EnableModelFallback,
|
||||
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
|
||||
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
|
||||
FallbackModelGemini: updatedSettings.FallbackModelGemini,
|
||||
FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity,
|
||||
EnableIdentityPatch: updatedSettings.EnableIdentityPatch,
|
||||
IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt,
|
||||
OpsMonitoringEnabled: updatedSettings.OpsMonitoringEnabled,
|
||||
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
|
||||
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
|
||||
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -348,6 +401,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if req.TurnstileSecretKey != "" {
|
||||
changed = append(changed, "turnstile_secret_key")
|
||||
}
|
||||
if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled {
|
||||
changed = append(changed, "linuxdo_connect_enabled")
|
||||
}
|
||||
if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID {
|
||||
changed = append(changed, "linuxdo_connect_client_id")
|
||||
}
|
||||
if req.LinuxDoConnectClientSecret != "" {
|
||||
changed = append(changed, "linuxdo_connect_client_secret")
|
||||
}
|
||||
if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL {
|
||||
changed = append(changed, "linuxdo_connect_redirect_url")
|
||||
}
|
||||
if before.SiteName != after.SiteName {
|
||||
changed = append(changed, "site_name")
|
||||
}
|
||||
@@ -366,6 +431,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.DocURL != after.DocURL {
|
||||
changed = append(changed, "doc_url")
|
||||
}
|
||||
if before.HomeContent != after.HomeContent {
|
||||
changed = append(changed, "home_content")
|
||||
}
|
||||
if before.DefaultConcurrency != after.DefaultConcurrency {
|
||||
changed = append(changed, "default_concurrency")
|
||||
}
|
||||
@@ -387,6 +455,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
||||
if before.FallbackModelAntigravity != after.FallbackModelAntigravity {
|
||||
changed = append(changed, "fallback_model_antigravity")
|
||||
}
|
||||
if before.EnableIdentityPatch != after.EnableIdentityPatch {
|
||||
changed = append(changed, "enable_identity_patch")
|
||||
}
|
||||
if before.IdentityPatchPrompt != after.IdentityPatchPrompt {
|
||||
changed = append(changed, "identity_patch_prompt")
|
||||
}
|
||||
if before.OpsMonitoringEnabled != after.OpsMonitoringEnabled {
|
||||
changed = append(changed, "ops_monitoring_enabled")
|
||||
}
|
||||
|
||||
@@ -12,19 +12,21 @@ import (
|
||||
|
||||
// AuthHandler handles authentication-related requests
|
||||
type AuthHandler struct {
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
cfg *config.Config
|
||||
authService *service.AuthService
|
||||
userService *service.UserService
|
||||
settingSvc *service.SettingService
|
||||
promoService *service.PromoService
|
||||
}
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler {
|
||||
func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService, promoService *service.PromoService) *AuthHandler {
|
||||
return &AuthHandler{
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
cfg: cfg,
|
||||
authService: authService,
|
||||
userService: userService,
|
||||
settingSvc: settingService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,6 +36,7 @@ type RegisterRequest struct {
|
||||
Password string `json:"password" binding:"required,min=6"`
|
||||
VerifyCode string `json:"verify_code"`
|
||||
TurnstileToken string `json:"turnstile_token"`
|
||||
PromoCode string `json:"promo_code"` // 注册优惠码
|
||||
}
|
||||
|
||||
// SendVerifyCodeRequest 发送验证码请求
|
||||
@@ -79,7 +82,7 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode)
|
||||
token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode)
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
return
|
||||
@@ -174,3 +177,63 @@ func (h *AuthHandler) GetCurrentUser(c *gin.Context) {
|
||||
|
||||
response.Success(c, UserResponse{User: dto.UserFromService(user), RunMode: runMode})
|
||||
}
|
||||
|
||||
// ValidatePromoCodeRequest 验证优惠码请求
|
||||
type ValidatePromoCodeRequest struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
}
|
||||
|
||||
// ValidatePromoCodeResponse 验证优惠码响应
|
||||
type ValidatePromoCodeResponse struct {
|
||||
Valid bool `json:"valid"`
|
||||
BonusAmount float64 `json:"bonus_amount,omitempty"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(公开接口,注册前调用)
|
||||
// POST /api/v1/auth/validate-promo-code
|
||||
func (h *AuthHandler) ValidatePromoCode(c *gin.Context) {
|
||||
var req ValidatePromoCodeRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
response.BadRequest(c, "Invalid request: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
promoCode, err := h.promoService.ValidatePromoCode(c.Request.Context(), req.Code)
|
||||
if err != nil {
|
||||
// 根据错误类型返回对应的错误码
|
||||
errorCode := "PROMO_CODE_INVALID"
|
||||
switch err {
|
||||
case service.ErrPromoCodeNotFound:
|
||||
errorCode = "PROMO_CODE_NOT_FOUND"
|
||||
case service.ErrPromoCodeExpired:
|
||||
errorCode = "PROMO_CODE_EXPIRED"
|
||||
case service.ErrPromoCodeDisabled:
|
||||
errorCode = "PROMO_CODE_DISABLED"
|
||||
case service.ErrPromoCodeMaxUsed:
|
||||
errorCode = "PROMO_CODE_MAX_USED"
|
||||
case service.ErrPromoCodeAlreadyUsed:
|
||||
errorCode = "PROMO_CODE_ALREADY_USED"
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: errorCode,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if promoCode == nil {
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: false,
|
||||
ErrorCode: "PROMO_CODE_INVALID",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
response.Success(c, ValidatePromoCodeResponse{
|
||||
Valid: true,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -370,3 +370,35 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult
|
||||
Errors: r.Errors,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeFromService(pc *service.PromoCode) *PromoCode {
|
||||
if pc == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCode{
|
||||
ID: pc.ID,
|
||||
Code: pc.Code,
|
||||
BonusAmount: pc.BonusAmount,
|
||||
MaxUses: pc.MaxUses,
|
||||
UsedCount: pc.UsedCount,
|
||||
Status: pc.Status,
|
||||
ExpiresAt: pc.ExpiresAt,
|
||||
Notes: pc.Notes,
|
||||
CreatedAt: pc.CreatedAt,
|
||||
UpdatedAt: pc.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func PromoCodeUsageFromService(u *service.PromoCodeUsage) *PromoCodeUsage {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &PromoCodeUsage{
|
||||
ID: u.ID,
|
||||
PromoCodeID: u.PromoCodeID,
|
||||
UserID: u.UserID,
|
||||
BonusAmount: u.BonusAmount,
|
||||
UsedAt: u.UsedAt,
|
||||
User: UserFromServiceShallow(u.User),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,7 @@ type SystemSettings struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
|
||||
DefaultConcurrency int `json:"default_concurrency"`
|
||||
DefaultBalance float64 `json:"default_balance"`
|
||||
@@ -61,6 +62,7 @@ type PublicSettings struct {
|
||||
APIBaseURL string `json:"api_base_url"`
|
||||
ContactInfo string `json:"contact_info"`
|
||||
DocURL string `json:"doc_url"`
|
||||
HomeContent string `json:"home_content"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
@@ -250,3 +250,28 @@ type BulkAssignResult struct {
|
||||
Subscriptions []UserSubscription `json:"subscriptions"`
|
||||
Errors []string `json:"errors"`
|
||||
}
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64 `json:"id"`
|
||||
Code string `json:"code"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
MaxUses int `json:"max_uses"`
|
||||
UsedCount int `json:"used_count"`
|
||||
Status string `json:"status"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
Notes string `json:"notes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64 `json:"id"`
|
||||
PromoCodeID int64 `json:"promo_code_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
BonusAmount float64 `json:"bonus_amount"`
|
||||
UsedAt time.Time `json:"used_at"`
|
||||
|
||||
User *User `json:"user,omitempty"`
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ type AdminHandlers struct {
|
||||
AntigravityOAuth *admin.AntigravityOAuthHandler
|
||||
Proxy *admin.ProxyHandler
|
||||
Redeem *admin.RedeemHandler
|
||||
Promo *admin.PromoHandler
|
||||
Setting *admin.SettingHandler
|
||||
Ops *admin.OpsHandler
|
||||
System *admin.SystemHandler
|
||||
|
||||
@@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) {
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: h.version,
|
||||
})
|
||||
|
||||
@@ -19,6 +19,7 @@ func ProvideAdminHandlers(
|
||||
antigravityOAuthHandler *admin.AntigravityOAuthHandler,
|
||||
proxyHandler *admin.ProxyHandler,
|
||||
redeemHandler *admin.RedeemHandler,
|
||||
promoHandler *admin.PromoHandler,
|
||||
settingHandler *admin.SettingHandler,
|
||||
opsHandler *admin.OpsHandler,
|
||||
systemHandler *admin.SystemHandler,
|
||||
@@ -37,6 +38,7 @@ func ProvideAdminHandlers(
|
||||
AntigravityOAuth: antigravityOAuthHandler,
|
||||
Proxy: proxyHandler,
|
||||
Redeem: redeemHandler,
|
||||
Promo: promoHandler,
|
||||
Setting: settingHandler,
|
||||
Ops: opsHandler,
|
||||
System: systemHandler,
|
||||
@@ -107,6 +109,7 @@ var ProviderSet = wire.NewSet(
|
||||
admin.NewAntigravityOAuthHandler,
|
||||
admin.NewProxyHandler,
|
||||
admin.NewRedeemHandler,
|
||||
admin.NewPromoHandler,
|
||||
admin.NewSettingHandler,
|
||||
admin.NewOpsHandler,
|
||||
ProvideSystemHandler,
|
||||
|
||||
60
backend/internal/middleware/rate_limiter.go
Normal file
60
backend/internal/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RateLimiter Redis 速率限制器
|
||||
type RateLimiter struct {
|
||||
redis *redis.Client
|
||||
prefix string
|
||||
}
|
||||
|
||||
// NewRateLimiter 创建速率限制器实例
|
||||
func NewRateLimiter(redisClient *redis.Client) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
redis: redisClient,
|
||||
prefix: "rate_limit:",
|
||||
}
|
||||
}
|
||||
|
||||
// Limit 返回速率限制中间件
|
||||
// key: 限制类型标识
|
||||
// limit: 时间窗口内最大请求数
|
||||
// window: 时间窗口
|
||||
func (r *RateLimiter) Limit(key string, limit int, window time.Duration) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := c.ClientIP()
|
||||
redisKey := r.prefix + key + ":" + ip
|
||||
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// 使用 INCR 原子操作增加计数
|
||||
count, err := r.redis.Incr(ctx, redisKey).Result()
|
||||
if err != nil {
|
||||
// Redis 错误时放行,避免影响正常服务
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// 首次访问时设置过期时间
|
||||
if count == 1 {
|
||||
r.redis.Expire(ctx, redisKey, window)
|
||||
}
|
||||
|
||||
// 超过限制
|
||||
if count > int64(limit) {
|
||||
c.AbortWithStatusJSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate limit exceeded",
|
||||
"message": "Too many requests, please try again later",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -16,4 +16,6 @@ const (
|
||||
|
||||
// IsClaudeCodeClient 标识当前请求是否来自 Claude Code 客户端
|
||||
IsClaudeCodeClient Key = "ctx_is_claude_code_client"
|
||||
// Group 认证后的分组信息,由 API Key 认证中间件设置
|
||||
Group Key = "ctx_group"
|
||||
)
|
||||
|
||||
@@ -339,6 +339,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
|
||||
@@ -60,6 +60,17 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
out, err := r.GetByIDLite(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
@@ -67,10 +78,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
out := groupEntityToService(m)
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
|
||||
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -45,6 +45,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewSettingRepository,
|
||||
NewOpsRepository,
|
||||
|
||||
@@ -327,10 +327,7 @@ func TestAPIContracts(t *testing.T) {
|
||||
"fallback_model_openai": "gpt-4o",
|
||||
"enable_identity_patch": true,
|
||||
"identity_patch_prompt": "",
|
||||
"ops_monitoring_enabled": true,
|
||||
"ops_realtime_monitoring_enabled": true,
|
||||
"ops_query_mode_default": "auto",
|
||||
"ops_metrics_interval_seconds": 60
|
||||
"home_content": ""
|
||||
}
|
||||
}`,
|
||||
},
|
||||
@@ -402,7 +399,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService)
|
||||
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil)
|
||||
@@ -579,6 +576,10 @@ func (stubGroupRepo) GetByID(ctx context.Context, id int64) (*service.Group, err
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (stubGroupRepo) Update(ctx context.Context, group *service.Group) error {
|
||||
return errors.New("not implemented")
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/wire"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// ProviderSet 提供服务器层的依赖
|
||||
@@ -31,6 +32,8 @@ func ProvideRouter(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
if cfg.Server.Mode == "release" {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
@@ -48,7 +51,7 @@ func ProvideRouter(
|
||||
}
|
||||
}
|
||||
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient)
|
||||
}
|
||||
|
||||
// ProvideHTTPServer 提供 HTTP 服务器
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -103,6 +105,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -161,6 +164,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
@@ -185,3 +189,14 @@ func GetSubscriptionFromContext(c *gin.Context) (*service.UserSubscription, bool
|
||||
subscription, ok := value.(*service.UserSubscription)
|
||||
return subscription, ok
|
||||
}
|
||||
|
||||
func setGroupContext(c *gin.Context, group *service.Group) {
|
||||
if !service.IsGroupContextValid(group) {
|
||||
return
|
||||
}
|
||||
if existing, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group); ok && existing != nil && existing.ID == group.ID && service.IsGroupContextValid(existing) {
|
||||
return
|
||||
}
|
||||
ctx := context.WithValue(c.Request.Context(), ctxkey.Group, group)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
@@ -63,6 +63,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -102,6 +103,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
})
|
||||
c.Set(string(ContextKeyUserRole), apiKey.User.Role)
|
||||
setGroupContext(c, apiKey.Group)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
@@ -133,6 +134,70 @@ func TestApiKeyAuthWithSubscriptionGoogle_QueryApiKeyRejected(t *testing.T) {
|
||||
require.Equal(t, "INVALID_ARGUMENT", resp.Error.Status)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 99,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformGemini,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyService := service.NewAPIKeyService(
|
||||
fakeAPIKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
&config.Config{RunMode: config.RunModeSimple},
|
||||
)
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
r := gin.New()
|
||||
r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg))
|
||||
r.GET("/v1beta/test", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
rec := httptest.NewRecorder()
|
||||
r.ServeHTTP(rec, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, rec.Code)
|
||||
}
|
||||
|
||||
func TestApiKeyAuthWithSubscriptionGoogle_QueryKeyAllowedOnV1Beta(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -25,6 +26,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
ID: 42,
|
||||
Name: "sub",
|
||||
Status: service.StatusActive,
|
||||
Hydrated: true,
|
||||
SubscriptionType: service.SubscriptionTypeSubscription,
|
||||
DailyLimitUSD: &limit,
|
||||
}
|
||||
@@ -110,6 +112,129 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthSetsGroupContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
group := &service.Group{
|
||||
ID: 101,
|
||||
Name: "g1",
|
||||
Status: service.StatusActive,
|
||||
Platform: service.PlatformAnthropic,
|
||||
Hydrated: true,
|
||||
}
|
||||
user := &service.User{
|
||||
ID: 7,
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
}
|
||||
apiKey := &service.APIKey{
|
||||
ID: 100,
|
||||
UserID: user.ID,
|
||||
Key: "test-key",
|
||||
Status: service.StatusActive,
|
||||
User: user,
|
||||
Group: group,
|
||||
}
|
||||
apiKey.GroupID = &group.ID
|
||||
|
||||
apiKeyRepo := &stubApiKeyRepo{
|
||||
getByKey: func(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
if key != apiKey.Key {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
clone := *apiKey
|
||||
return &clone, nil
|
||||
},
|
||||
}
|
||||
|
||||
cfg := &config.Config{RunMode: config.RunModeSimple}
|
||||
apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg)
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg)))
|
||||
|
||||
invalidGroup := &service.Group{
|
||||
ID: group.ID,
|
||||
Platform: group.Platform,
|
||||
Status: group.Status,
|
||||
}
|
||||
router.GET("/t", func(c *gin.Context) {
|
||||
groupFromCtx, ok := c.Request.Context().Value(ctxkey.Group).(*service.Group)
|
||||
if !ok || groupFromCtx == nil || groupFromCtx.ID != group.ID || !groupFromCtx.Hydrated || groupFromCtx == invalidGroup {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"ok": false})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, "/t", nil)
|
||||
req.Header.Set("x-api-key", apiKey.Key)
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.Group, invalidGroup))
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
require.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine {
|
||||
router := gin.New()
|
||||
router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg)))
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
@@ -9,6 +11,7 @@ import (
|
||||
"github.com/Wei-Shaw/sub2api/internal/web"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRouter 配置路由器中间件和路由
|
||||
@@ -21,20 +24,30 @@ func SetupRouter(
|
||||
apiKeyService *service.APIKeyService,
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
settingService *service.SettingService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) *gin.Engine {
|
||||
// 应用中间件
|
||||
r.Use(middleware2.Logger())
|
||||
r.Use(middleware2.CORS(cfg.CORS))
|
||||
r.Use(middleware2.SecurityHeaders(cfg.Security.CSP))
|
||||
|
||||
// Serve embedded frontend if available
|
||||
// Serve embedded frontend with settings injection if available
|
||||
if web.HasEmbeddedFrontend() {
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
frontendServer, err := web.NewFrontendServer(settingService)
|
||||
if err != nil {
|
||||
log.Printf("Warning: Failed to create frontend server with settings injection: %v, using legacy mode", err)
|
||||
r.Use(web.ServeEmbeddedFrontend())
|
||||
} else {
|
||||
// Register cache invalidation callback
|
||||
settingService.SetOnUpdateCallback(frontendServer.InvalidateCache)
|
||||
r.Use(frontendServer.Middleware())
|
||||
}
|
||||
}
|
||||
|
||||
// 注册路由
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
registerRoutes(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg, redisClient)
|
||||
|
||||
return r
|
||||
}
|
||||
@@ -50,6 +63,7 @@ func registerRoutes(
|
||||
subscriptionService *service.SubscriptionService,
|
||||
opsService *service.OpsService,
|
||||
cfg *config.Config,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 通用路由(健康检查、状态等)
|
||||
routes.RegisterCommonRoutes(r)
|
||||
@@ -58,7 +72,7 @@ func registerRoutes(
|
||||
v1 := r.Group("/api/v1")
|
||||
|
||||
// 注册各模块路由
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient)
|
||||
routes.RegisterUserRoutes(v1, h, jwtAuth)
|
||||
routes.RegisterAdminRoutes(v1, h, adminAuth)
|
||||
routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg)
|
||||
|
||||
@@ -44,6 +44,9 @@ func RegisterAdminRoutes(
|
||||
// 卡密管理
|
||||
registerRedeemCodeRoutes(admin, h)
|
||||
|
||||
// 优惠码管理
|
||||
registerPromoCodeRoutes(admin, h)
|
||||
|
||||
// 系统设置
|
||||
registerSettingsRoutes(admin, h)
|
||||
|
||||
@@ -252,6 +255,18 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerPromoCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
promoCodes := admin.Group("/promo-codes")
|
||||
{
|
||||
promoCodes.GET("", h.Admin.Promo.List)
|
||||
promoCodes.GET("/:id", h.Admin.Promo.GetByID)
|
||||
promoCodes.POST("", h.Admin.Promo.Create)
|
||||
promoCodes.PUT("/:id", h.Admin.Promo.Update)
|
||||
promoCodes.DELETE("/:id", h.Admin.Promo.Delete)
|
||||
promoCodes.GET("/:id/usages", h.Admin.Promo.GetUsages)
|
||||
}
|
||||
}
|
||||
|
||||
func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
|
||||
adminSettings := admin.Group("/settings")
|
||||
{
|
||||
|
||||
@@ -1,24 +1,34 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/handler"
|
||||
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/middleware"
|
||||
servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// RegisterAuthRoutes 注册认证相关路由
|
||||
func RegisterAuthRoutes(
|
||||
v1 *gin.RouterGroup,
|
||||
h *handler.Handlers,
|
||||
jwtAuth middleware.JWTAuthMiddleware,
|
||||
jwtAuth servermiddleware.JWTAuthMiddleware,
|
||||
redisClient *redis.Client,
|
||||
) {
|
||||
// 创建速率限制器
|
||||
rateLimiter := middleware.NewRateLimiter(redisClient)
|
||||
|
||||
// 公开接口
|
||||
auth := v1.Group("/auth")
|
||||
{
|
||||
auth.POST("/register", h.Auth.Register)
|
||||
auth.POST("/login", h.Auth.Login)
|
||||
auth.POST("/send-verify-code", h.Auth.SendVerifyCode)
|
||||
// 优惠码验证接口添加速率限制:每分钟最多 10 次
|
||||
auth.POST("/validate-promo-code", rateLimiter.Limit("validate-promo", 10, time.Minute), h.Auth.ValidatePromoCode)
|
||||
auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart)
|
||||
auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback)
|
||||
}
|
||||
|
||||
@@ -576,18 +576,33 @@ func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGro
|
||||
return fmt.Errorf("cannot set self as fallback group")
|
||||
}
|
||||
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
visited := map[int64]struct{}{}
|
||||
nextID := fallbackGroupID
|
||||
for {
|
||||
if _, seen := visited[nextID]; seen {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[nextID] = struct{}{}
|
||||
if currentGroupID > 0 && nextID == currentGroupID {
|
||||
return fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
// 检查降级分组是否存在
|
||||
fallbackGroup, err := s.groupRepo.GetByIDLite(ctx, nextID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fallback group not found: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
// 降级分组不能启用 claude_code_only,否则会造成死循环
|
||||
if nextID == fallbackGroupID && fallbackGroup.ClaudeCodeOnly {
|
||||
return fmt.Errorf("fallback group cannot have claude_code_only enabled")
|
||||
}
|
||||
|
||||
if fallbackGroup.FallbackGroupID == nil {
|
||||
return nil
|
||||
}
|
||||
nextID = *fallbackGroup.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
|
||||
|
||||
@@ -107,6 +107,10 @@ func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -45,6 +45,13 @@ func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, err
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
@@ -290,3 +297,84 @@ func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
@@ -35,9 +35,6 @@ var (
|
||||
// maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。
|
||||
const maxTokenLength = 8192
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain LinuxDo Connect 生成的合成邮箱域名后缀
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo.synthetic"
|
||||
|
||||
// JWTClaims JWT载荷数据
|
||||
type JWTClaims struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
@@ -55,6 +52,7 @@ type AuthService struct {
|
||||
emailService *EmailService
|
||||
turnstileService *TurnstileService
|
||||
emailQueueService *EmailQueueService
|
||||
promoService *PromoService
|
||||
}
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
@@ -65,6 +63,7 @@ func NewAuthService(
|
||||
emailService *EmailService,
|
||||
turnstileService *TurnstileService,
|
||||
emailQueueService *EmailQueueService,
|
||||
promoService *PromoService,
|
||||
) *AuthService {
|
||||
return &AuthService{
|
||||
userRepo: userRepo,
|
||||
@@ -73,16 +72,17 @@ func NewAuthService(
|
||||
emailService: emailService,
|
||||
turnstileService: turnstileService,
|
||||
emailQueueService: emailQueueService,
|
||||
promoService: promoService,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 用户注册,返回token和用户
|
||||
func (s *AuthService) Register(ctx context.Context, email, password string) (string, *User, error) {
|
||||
return s.RegisterWithVerification(ctx, email, password, "")
|
||||
return s.RegisterWithVerification(ctx, email, password, "", "")
|
||||
}
|
||||
|
||||
// RegisterWithVerification 用户注册(支持邮件验证),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) {
|
||||
// RegisterWithVerification 用户注册(支持邮件验证和优惠码),返回token和用户
|
||||
func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode, promoCode string) (string, *User, error) {
|
||||
// 检查是否开放注册(默认关闭:settingService 未配置时不允许注册)
|
||||
if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) {
|
||||
return "", nil, ErrRegDisabled
|
||||
@@ -153,6 +153,19 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
|
||||
return "", nil, ErrServiceUnavailable
|
||||
}
|
||||
|
||||
// 应用优惠码(如果提供)
|
||||
if promoCode != "" && s.promoService != nil {
|
||||
if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil {
|
||||
// 优惠码应用失败不影响注册,只记录日志
|
||||
log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err)
|
||||
} else {
|
||||
// 重新获取用户信息以获取更新后的余额
|
||||
if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil {
|
||||
user = updatedUser
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成token
|
||||
token, err := s.GenerateToken(user)
|
||||
if err != nil {
|
||||
|
||||
@@ -100,6 +100,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
)
|
||||
}
|
||||
|
||||
@@ -131,7 +132,7 @@ func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testi
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
@@ -143,7 +144,7 @@ func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
@@ -157,7 +158,7 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong")
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
@@ -38,6 +38,12 @@ const (
|
||||
RedeemTypeSubscription = "subscription"
|
||||
)
|
||||
|
||||
// PromoCode status constants
|
||||
const (
|
||||
PromoCodeStatusActive = "active"
|
||||
PromoCodeStatusDisabled = "disabled"
|
||||
)
|
||||
|
||||
// Admin adjustment type constants
|
||||
const (
|
||||
AdjustmentTypeAdminBalance = "admin_balance" // 管理员调整余额
|
||||
@@ -57,6 +63,9 @@ const (
|
||||
SubscriptionStatusSuspended = "suspended"
|
||||
)
|
||||
|
||||
// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。
|
||||
const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid"
|
||||
|
||||
// Setting keys
|
||||
const (
|
||||
// 注册设置
|
||||
@@ -90,6 +99,7 @@ const (
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -23,9 +24,11 @@ type mockAccountRepoForPlatform struct {
|
||||
accounts []Account
|
||||
accountsByID map[int64]*Account
|
||||
listPlatformFunc func(ctx context.Context, platform string) ([]Account, error)
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
m.getByIDCalls++
|
||||
if acc, ok := m.accountsByID[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
@@ -191,6 +194,56 @@ func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, gro
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockGroupRepoForGateway struct {
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGateway) Create(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Update(ctx context.Context, group *Group) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) Delete(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockGroupRepoForGateway) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
return nil, nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActive(ctx context.Context) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return false, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (m *mockGroupRepoForGateway) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func ptr[T any](v T) *T {
|
||||
return &v
|
||||
}
|
||||
@@ -897,6 +950,74 @@ func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, acc
|
||||
return m.accountWaitCounts[accountID], nil
|
||||
}
|
||||
|
||||
type mockConcurrencyCache struct {
|
||||
acquireAccountCalls int
|
||||
loadBatchCalls int
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
m.acquireAccountCalls++
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
m.loadBatchCalls++
|
||||
result := make(map[int64]*AccountLoadInfo, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
result[acc.ID] = &AccountLoadInfo{
|
||||
AccountID: acc.ID,
|
||||
CurrentConcurrency: 0,
|
||||
WaitingCount: 0,
|
||||
LoadRate: 0,
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockConcurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
|
||||
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
@@ -995,6 +1116,78 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
|
||||
})
|
||||
|
||||
t.Run("粘性命中-不调用GetByID", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(1), result.Account.ID)
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性命中不应调用GetByID")
|
||||
require.Equal(t, 0, concurrencyCache.loadBatchCalls, "粘性命中应在负载批量查询前返回")
|
||||
})
|
||||
|
||||
t.Run("粘性账号不在候选集-回退负载感知选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{"sticky": 1},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
concurrencyCache := &mockConcurrencyCache{}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: NewConcurrencyService(concurrencyCache),
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "粘性账号不在候选集时应回退到可用账号")
|
||||
require.Equal(t, 0, repo.getByIDCalls, "粘性账号缺失不应回退到GetByID")
|
||||
require.Equal(t, 1, concurrencyCache.loadBatchCalls, "应继续进行负载批量查询")
|
||||
})
|
||||
|
||||
t.Run("无可用账号-返回错误", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{},
|
||||
@@ -1019,3 +1212,190 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "no available accounts")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(42)
|
||||
ctxGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, ctxGroup)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupContext_OverwritesInvalidContextGroup(t *testing.T) {
|
||||
groupID := int64(42)
|
||||
invalidGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
hydratedGroup := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), ctxkey.Group, invalidGroup)
|
||||
svc := &GatewayService{}
|
||||
ctx = svc.withGroupContext(ctx, hydratedGroup)
|
||||
|
||||
got, ok := ctx.Value(ctxkey.Group).(*Group)
|
||||
require.True(t, ok)
|
||||
require.Same(t, hydratedGroup, got)
|
||||
}
|
||||
|
||||
func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
Hydrated: true,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{fallbackID: fallbackGroup},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cfg: testConfig(),
|
||||
}
|
||||
|
||||
account, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "claude-3-5-sonnet-20241022", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, account)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGatewayService_ResolveGatewayGroup_DetectsFallbackCycle(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(10)
|
||||
fallbackID := int64(11)
|
||||
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &fallbackID,
|
||||
}
|
||||
fallbackGroup := &Group{
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
ClaudeCodeOnly: true,
|
||||
FallbackGroupID: &groupID,
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: group,
|
||||
fallbackID: fallbackGroup,
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GatewayService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
|
||||
gotGroup, gotID, err := svc.resolveGatewayGroup(ctx, &groupID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, gotGroup)
|
||||
require.Nil(t, gotID)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
@@ -361,27 +361,13 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedGroupID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, err
|
||||
}
|
||||
groupID = resolvedGroupID
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
platform = group.Platform
|
||||
|
||||
// 检查 Claude Code 客户端限制
|
||||
if group.ClaudeCodeOnly {
|
||||
isClaudeCode := IsClaudeCodeClient(ctx)
|
||||
if !isClaudeCode {
|
||||
// 非 Claude Code 客户端,检查是否有降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
// 使用降级分组重新调度
|
||||
fallbackGroupID := *group.FallbackGroupID
|
||||
return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs)
|
||||
}
|
||||
// 无降级分组,拒绝访问
|
||||
return nil, ErrClaudeCodeOnly
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 无分组时只使用原生 anthropic 平台
|
||||
platform = PlatformAnthropic
|
||||
@@ -409,10 +395,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组)
|
||||
groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
@@ -452,7 +439,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -478,8 +465,13 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) &&
|
||||
// 粘性命中仅在当前可调度候选集中生效。
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
account.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
@@ -655,51 +647,97 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GatewayService) withGroupContext(ctx context.Context, group *Group) context.Context {
|
||||
if !IsGroupContextValid(group) {
|
||||
return ctx
|
||||
}
|
||||
if existing, ok := ctx.Value(ctxkey.Group).(*Group); ok && existing != nil && existing.ID == group.ID && IsGroupContextValid(existing) {
|
||||
return ctx
|
||||
}
|
||||
return context.WithValue(ctx, ctxkey.Group, group)
|
||||
}
|
||||
|
||||
func (s *GatewayService) groupFromContext(ctx context.Context, groupID int64) *Group {
|
||||
if group, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(group) && group.ID == groupID {
|
||||
return group
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*Group, error) {
|
||||
if group := s.groupFromContext(ctx, groupID); group != nil {
|
||||
return group, nil
|
||||
}
|
||||
group, err := s.groupRepo.GetByIDLite(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
currentID := *groupID
|
||||
visited := map[int64]struct{}{}
|
||||
for {
|
||||
if _, seen := visited[currentID]; seen {
|
||||
return nil, nil, fmt.Errorf("fallback group cycle detected")
|
||||
}
|
||||
visited[currentID] = struct{}{}
|
||||
|
||||
group, err := s.resolveGroupByID(ctx, currentID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly || IsClaudeCodeClient(ctx) {
|
||||
return group, ¤tID, nil
|
||||
}
|
||||
|
||||
if group.FallbackGroupID == nil {
|
||||
return nil, nil, ErrClaudeCodeOnly
|
||||
}
|
||||
currentID = *group.FallbackGroupID
|
||||
}
|
||||
}
|
||||
|
||||
// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制
|
||||
// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端:
|
||||
// - 有降级分组:返回降级分组的 ID
|
||||
// - 无降级分组:返回 ErrClaudeCodeOnly 错误
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) {
|
||||
func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
// 强制平台模式不检查 Claude Code 限制
|
||||
if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform {
|
||||
return groupID, nil
|
||||
return nil, groupID, nil
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, resolvedID, err := s.resolveGatewayGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !group.ClaudeCodeOnly {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 分组启用了 Claude Code 限制
|
||||
if IsClaudeCodeClient(ctx) {
|
||||
return groupID, nil
|
||||
}
|
||||
|
||||
// 非 Claude Code 客户端,检查降级分组
|
||||
if group.FallbackGroupID != nil {
|
||||
return group.FallbackGroupID, nil
|
||||
}
|
||||
|
||||
return nil, ErrClaudeCodeOnly
|
||||
return group, resolvedID, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
|
||||
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, group *Group) (string, bool, error) {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform != "" {
|
||||
return forcePlatform, true, nil
|
||||
}
|
||||
if group != nil {
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
if groupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("get group failed: %w", err)
|
||||
return "", false, err
|
||||
}
|
||||
return group.Platform, false, nil
|
||||
}
|
||||
|
||||
@@ -86,9 +86,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
|
||||
platform = forcePlatform
|
||||
} else if groupID != nil {
|
||||
// 根据分组 platform 决定查询哪种账号
|
||||
group, err := s.groupRepo.GetByID(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
var group *Group
|
||||
if ctxGroup, ok := ctx.Value(ctxkey.Group).(*Group); ok && IsGroupContextValid(ctxGroup) && ctxGroup.ID == *groupID {
|
||||
group = ctxGroup
|
||||
} else {
|
||||
var err error
|
||||
group, err = s.groupRepo.GetByIDLite(ctx, *groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group failed: %w", err)
|
||||
}
|
||||
}
|
||||
platform = group.Platform
|
||||
} else {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -152,10 +153,21 @@ var _ AccountRepository = (*mockAccountRepoForGemini)(nil)
|
||||
|
||||
// mockGroupRepoForGemini Gemini 测试用的 group repo mock
|
||||
type mockGroupRepoForGemini struct {
|
||||
groups map[int64]*Group
|
||||
groups map[int64]*Group
|
||||
getByIDCalls int
|
||||
getByIDLiteCalls int
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, errors.New("group not found")
|
||||
}
|
||||
|
||||
func (m *mockGroupRepoForGemini) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
m.getByIDLiteCalls++
|
||||
if g, ok := m.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
@@ -248,6 +260,77 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiP
|
||||
require.Equal(t, PlatformGemini, acc.Platform, "无分组时应只返回 gemini 平台账户")
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_ReusesContextGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
group := &Group{
|
||||
ID: groupID,
|
||||
Platform: PlatformGemini,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{groups: map[int64]*Group{}}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 0, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
func TestGeminiMessagesCompatService_GroupResolution_UsesLiteFetch(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
groupID := int64(7)
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForGemini{}
|
||||
groupRepo := &mockGroupRepoForGemini{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {ID: groupID, Platform: PlatformGemini},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
acc, err := svc.SelectAccountForModelWithExclusions(ctx, &groupID, "", "gemini-2.5-flash", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, acc)
|
||||
require.Equal(t, 0, groupRepo.getByIDCalls)
|
||||
require.Equal(t, 1, groupRepo.getByIDLiteCalls)
|
||||
}
|
||||
|
||||
// TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup 测试 antigravity 分组
|
||||
func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_AntigravityGroup(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
@@ -10,6 +10,7 @@ type Group struct {
|
||||
RateMultiplier float64
|
||||
IsExclusive bool
|
||||
Status string
|
||||
Hydrated bool // indicates the group was loaded from a trusted repository source
|
||||
|
||||
SubscriptionType string
|
||||
DailyLimitUSD *float64
|
||||
@@ -72,3 +73,20 @@ func (g *Group) GetImagePrice(imageSize string) *float64 {
|
||||
return g.ImagePrice2K
|
||||
}
|
||||
}
|
||||
|
||||
// IsGroupContextValid reports whether a group from context has the fields required for routing decisions.
|
||||
func IsGroupContextValid(group *Group) bool {
|
||||
if group == nil {
|
||||
return false
|
||||
}
|
||||
if group.ID <= 0 {
|
||||
return false
|
||||
}
|
||||
if !group.Hydrated {
|
||||
return false
|
||||
}
|
||||
if group.Platform == "" || group.Status == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ var (
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *Group) error
|
||||
GetByID(ctx context.Context, id int64) (*Group, error)
|
||||
GetByIDLite(ctx context.Context, id int64) (*Group, error)
|
||||
Update(ctx context.Context, group *Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteCascade(ctx context.Context, id int64) ([]int64, error)
|
||||
|
||||
@@ -545,6 +545,8 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
// 2. Normalize input format for Codex API compatibility
|
||||
if account.Type == AccountTypeOAuth {
|
||||
reqBody["store"] = false
|
||||
// Codex 上游不接受 max_output_tokens 参数,需要在转发前移除。
|
||||
delete(reqBody, "max_output_tokens")
|
||||
bodyModified = true
|
||||
|
||||
// Normalize input format: convert AI SDK multi-part content format to simplified format
|
||||
|
||||
73
backend/internal/service/promo_code.go
Normal file
73
backend/internal/service/promo_code.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// PromoCode 注册优惠码
|
||||
type PromoCode struct {
|
||||
ID int64
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
UsedCount int
|
||||
Status string
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// 关联
|
||||
UsageRecords []PromoCodeUsage
|
||||
}
|
||||
|
||||
// PromoCodeUsage 优惠码使用记录
|
||||
type PromoCodeUsage struct {
|
||||
ID int64
|
||||
PromoCodeID int64
|
||||
UserID int64
|
||||
BonusAmount float64
|
||||
UsedAt time.Time
|
||||
|
||||
// 关联
|
||||
PromoCode *PromoCode
|
||||
User *User
|
||||
}
|
||||
|
||||
// CanUse 检查优惠码是否可用
|
||||
func (p *PromoCode) CanUse() bool {
|
||||
if p.Status != PromoCodeStatusActive {
|
||||
return false
|
||||
}
|
||||
if p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt) {
|
||||
return false
|
||||
}
|
||||
if p.MaxUses > 0 && p.UsedCount >= p.MaxUses {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsExpired 检查是否已过期
|
||||
func (p *PromoCode) IsExpired() bool {
|
||||
return p.ExpiresAt != nil && time.Now().After(*p.ExpiresAt)
|
||||
}
|
||||
|
||||
// CreatePromoCodeInput 创建优惠码输入
|
||||
type CreatePromoCodeInput struct {
|
||||
Code string
|
||||
BonusAmount float64
|
||||
MaxUses int
|
||||
ExpiresAt *time.Time
|
||||
Notes string
|
||||
}
|
||||
|
||||
// UpdatePromoCodeInput 更新优惠码输入
|
||||
type UpdatePromoCodeInput struct {
|
||||
Code *string
|
||||
BonusAmount *float64
|
||||
MaxUses *int
|
||||
Status *string
|
||||
ExpiresAt *time.Time
|
||||
Notes *string
|
||||
}
|
||||
30
backend/internal/service/promo_code_repository.go
Normal file
30
backend/internal/service/promo_code_repository.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
// PromoCodeRepository 优惠码仓储接口
|
||||
type PromoCodeRepository interface {
|
||||
// 基础 CRUD
|
||||
Create(ctx context.Context, code *PromoCode) error
|
||||
GetByID(ctx context.Context, id int64) (*PromoCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*PromoCode, error)
|
||||
GetByCodeForUpdate(ctx context.Context, code string) (*PromoCode, error) // 带行锁的查询,用于并发控制
|
||||
Update(ctx context.Context, code *PromoCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
// 列表查询
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error)
|
||||
|
||||
// 使用记录
|
||||
CreateUsage(ctx context.Context, usage *PromoCodeUsage) error
|
||||
GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*PromoCodeUsage, error)
|
||||
ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error)
|
||||
|
||||
// 计数操作
|
||||
IncrementUsedCount(ctx context.Context, id int64) error
|
||||
}
|
||||
256
backend/internal/service/promo_service.go
Normal file
256
backend/internal/service/promo_service.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPromoCodeNotFound = infraerrors.NotFound("PROMO_CODE_NOT_FOUND", "promo code not found")
|
||||
ErrPromoCodeExpired = infraerrors.BadRequest("PROMO_CODE_EXPIRED", "promo code has expired")
|
||||
ErrPromoCodeDisabled = infraerrors.BadRequest("PROMO_CODE_DISABLED", "promo code is disabled")
|
||||
ErrPromoCodeMaxUsed = infraerrors.BadRequest("PROMO_CODE_MAX_USED", "promo code has reached maximum uses")
|
||||
ErrPromoCodeAlreadyUsed = infraerrors.Conflict("PROMO_CODE_ALREADY_USED", "you have already used this promo code")
|
||||
ErrPromoCodeInvalid = infraerrors.BadRequest("PROMO_CODE_INVALID", "invalid promo code")
|
||||
)
|
||||
|
||||
// PromoService 优惠码服务
|
||||
type PromoService struct {
|
||||
promoRepo PromoCodeRepository
|
||||
userRepo UserRepository
|
||||
billingCacheService *BillingCacheService
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewPromoService 创建优惠码服务实例
|
||||
func NewPromoService(
|
||||
promoRepo PromoCodeRepository,
|
||||
userRepo UserRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
entClient *dbent.Client,
|
||||
) *PromoService {
|
||||
return &PromoService{
|
||||
promoRepo: promoRepo,
|
||||
userRepo: userRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// ValidatePromoCode 验证优惠码(注册前调用)
|
||||
// 返回 nil, nil 表示空码(不报错)
|
||||
func (s *PromoService) ValidatePromoCode(ctx context.Context, code string) (*PromoCode, error) {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil, nil // 空码不报错,直接返回
|
||||
}
|
||||
|
||||
promoCode, err := s.promoRepo.GetByCode(ctx, code)
|
||||
if err != nil {
|
||||
// 保留原始错误类型,不要统一映射为 NotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// validatePromoCodeStatus 验证优惠码状态
|
||||
func (s *PromoService) validatePromoCodeStatus(promoCode *PromoCode) error {
|
||||
if !promoCode.CanUse() {
|
||||
if promoCode.IsExpired() {
|
||||
return ErrPromoCodeExpired
|
||||
}
|
||||
if promoCode.Status == PromoCodeStatusDisabled {
|
||||
return ErrPromoCodeDisabled
|
||||
}
|
||||
if promoCode.MaxUses > 0 && promoCode.UsedCount >= promoCode.MaxUses {
|
||||
return ErrPromoCodeMaxUsed
|
||||
}
|
||||
return ErrPromoCodeInvalid
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ApplyPromoCode 应用优惠码(注册成功后调用)
|
||||
// 使用事务和行锁确保并发安全
|
||||
func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code string) error {
|
||||
code = strings.TrimSpace(code)
|
||||
if code == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开启事务
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
txCtx := dbent.NewTxContext(ctx, tx)
|
||||
|
||||
// 在事务中获取并锁定优惠码记录(FOR UPDATE)
|
||||
promoCode, err := s.promoRepo.GetByCodeForUpdate(txCtx, code)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中验证优惠码状态
|
||||
if err := s.validatePromoCodeStatus(promoCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 在事务中检查用户是否已使用过此优惠码
|
||||
existing, err := s.promoRepo.GetUsageByPromoCodeAndUser(txCtx, promoCode.ID, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing usage: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return ErrPromoCodeAlreadyUsed
|
||||
}
|
||||
|
||||
// 增加用户余额
|
||||
if err := s.userRepo.UpdateBalance(txCtx, userID, promoCode.BonusAmount); err != nil {
|
||||
return fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
|
||||
// 创建使用记录
|
||||
usage := &PromoCodeUsage{
|
||||
PromoCodeID: promoCode.ID,
|
||||
UserID: userID,
|
||||
BonusAmount: promoCode.BonusAmount,
|
||||
UsedAt: time.Now(),
|
||||
}
|
||||
if err := s.promoRepo.CreateUsage(txCtx, usage); err != nil {
|
||||
return fmt.Errorf("create usage record: %w", err)
|
||||
}
|
||||
|
||||
// 增加使用次数
|
||||
if err := s.promoRepo.IncrementUsedCount(txCtx, promoCode.ID); err != nil {
|
||||
return fmt.Errorf("increment used count: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
|
||||
// 失效余额缓存
|
||||
if s.billingCacheService != nil {
|
||||
go func() {
|
||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID)
|
||||
}()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateRandomCode 生成随机优惠码
|
||||
func (s *PromoService) GenerateRandomCode() (string, error) {
|
||||
bytes := make([]byte, 8)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(bytes)), nil
|
||||
}
|
||||
|
||||
// Create 创建优惠码
|
||||
func (s *PromoService) Create(ctx context.Context, input *CreatePromoCodeInput) (*PromoCode, error) {
|
||||
code := strings.TrimSpace(input.Code)
|
||||
if code == "" {
|
||||
// 自动生成
|
||||
var err error
|
||||
code, err = s.GenerateRandomCode()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
promoCode := &PromoCode{
|
||||
Code: strings.ToUpper(code),
|
||||
BonusAmount: input.BonusAmount,
|
||||
MaxUses: input.MaxUses,
|
||||
UsedCount: 0,
|
||||
Status: PromoCodeStatusActive,
|
||||
ExpiresAt: input.ExpiresAt,
|
||||
Notes: input.Notes,
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Create(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("create promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取优惠码
|
||||
func (s *PromoService) GetByID(ctx context.Context, id int64) (*PromoCode, error) {
|
||||
code, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// Update 更新优惠码
|
||||
func (s *PromoService) Update(ctx context.Context, id int64, input *UpdatePromoCodeInput) (*PromoCode, error) {
|
||||
promoCode, err := s.promoRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Code != nil {
|
||||
promoCode.Code = strings.ToUpper(strings.TrimSpace(*input.Code))
|
||||
}
|
||||
if input.BonusAmount != nil {
|
||||
promoCode.BonusAmount = *input.BonusAmount
|
||||
}
|
||||
if input.MaxUses != nil {
|
||||
promoCode.MaxUses = *input.MaxUses
|
||||
}
|
||||
if input.Status != nil {
|
||||
promoCode.Status = *input.Status
|
||||
}
|
||||
if input.ExpiresAt != nil {
|
||||
promoCode.ExpiresAt = input.ExpiresAt
|
||||
}
|
||||
if input.Notes != nil {
|
||||
promoCode.Notes = *input.Notes
|
||||
}
|
||||
|
||||
if err := s.promoRepo.Update(ctx, promoCode); err != nil {
|
||||
return nil, fmt.Errorf("update promo code: %w", err)
|
||||
}
|
||||
|
||||
return promoCode, nil
|
||||
}
|
||||
|
||||
// Delete 删除优惠码
|
||||
func (s *PromoService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.promoRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete promo code: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// List 获取优惠码列表
|
||||
func (s *PromoService) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]PromoCode, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListWithFilters(ctx, params, status, search)
|
||||
}
|
||||
|
||||
// ListUsages 获取使用记录
|
||||
func (s *PromoService) ListUsages(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
return s.promoRepo.ListUsagesByPromoCode(ctx, promoCodeID, params)
|
||||
}
|
||||
@@ -32,6 +32,8 @@ type SettingRepository interface {
|
||||
type SettingService struct {
|
||||
settingRepo SettingRepository
|
||||
cfg *config.Config
|
||||
onUpdate func() // Callback when settings are updated (for cache invalidation)
|
||||
version string // Application version
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
@@ -65,6 +67,8 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyHomeContent,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -72,6 +76,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
return nil, fmt.Errorf("get public settings: %w", err)
|
||||
}
|
||||
|
||||
linuxDoEnabled := false
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
linuxDoEnabled = raw == "true"
|
||||
} else {
|
||||
linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled
|
||||
}
|
||||
|
||||
return &PublicSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
@@ -83,6 +94,59 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetOnUpdateCallback sets a callback function to be called when settings are updated
|
||||
// This is used for cache invalidation (e.g., HTML cache in frontend server)
|
||||
func (s *SettingService) SetOnUpdateCallback(callback func()) {
|
||||
s.onUpdate = callback
|
||||
}
|
||||
|
||||
// SetVersion sets the application version for injection into public settings
|
||||
func (s *SettingService) SetVersion(version string) {
|
||||
s.version = version
|
||||
}
|
||||
|
||||
// GetPublicSettingsForInjection returns public settings in a format suitable for HTML injection
|
||||
// This implements the web.PublicSettingsProvider interface
|
||||
func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any, error) {
|
||||
settings, err := s.GetPublicSettings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return a struct that matches the frontend's expected format
|
||||
return &struct {
|
||||
RegistrationEnabled bool `json:"registration_enabled"`
|
||||
EmailVerifyEnabled bool `json:"email_verify_enabled"`
|
||||
TurnstileEnabled bool `json:"turnstile_enabled"`
|
||||
TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
|
||||
SiteName string `json:"site_name"`
|
||||
SiteLogo string `json:"site_logo,omitempty"`
|
||||
SiteSubtitle string `json:"site_subtitle,omitempty"`
|
||||
APIBaseURL string `json:"api_base_url,omitempty"`
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
RegistrationEnabled: settings.RegistrationEnabled,
|
||||
EmailVerifyEnabled: settings.EmailVerifyEnabled,
|
||||
TurnstileEnabled: settings.TurnstileEnabled,
|
||||
TurnstileSiteKey: settings.TurnstileSiteKey,
|
||||
SiteName: settings.SiteName,
|
||||
SiteLogo: settings.SiteLogo,
|
||||
SiteSubtitle: settings.SiteSubtitle,
|
||||
APIBaseURL: settings.APIBaseURL,
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -112,6 +176,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey
|
||||
}
|
||||
|
||||
// LinuxDo Connect OAuth 登录(终端用户 SSO)
|
||||
updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled)
|
||||
updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID
|
||||
updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL
|
||||
if settings.LinuxDoConnectClientSecret != "" {
|
||||
updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret
|
||||
}
|
||||
|
||||
// OEM设置
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
@@ -119,6 +191,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
updates[SettingKeyHomeContent] = settings.HomeContent
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -143,7 +216,11 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
err := s.settingRepo.SetMultiple(ctx, updates)
|
||||
if err == nil && s.onUpdate != nil {
|
||||
s.onUpdate() // Invalidate cache after settings update
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// IsRegistrationEnabled 检查是否开放注册
|
||||
@@ -260,6 +337,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
@@ -286,6 +364,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// LinuxDo Connect 设置:
|
||||
// - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭)
|
||||
// - 支持在后台“系统设置”中覆盖并持久化(存储于 DB)
|
||||
linuxDoBase := config.LinuxDoConnectConfig{}
|
||||
if s.cfg != nil {
|
||||
linuxDoBase = s.cfg.LinuxDo
|
||||
}
|
||||
|
||||
if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok {
|
||||
result.LinuxDoConnectEnabled = raw == "true"
|
||||
} else {
|
||||
result.LinuxDoConnectEnabled = linuxDoBase.Enabled
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectClientID = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectClientID = linuxDoBase.ClientID
|
||||
}
|
||||
|
||||
if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" {
|
||||
result.LinuxDoConnectRedirectURL = strings.TrimSpace(v)
|
||||
} else {
|
||||
result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL
|
||||
}
|
||||
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret])
|
||||
if result.LinuxDoConnectClientSecret == "" {
|
||||
result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret)
|
||||
}
|
||||
result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != ""
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
|
||||
@@ -31,6 +31,7 @@ type SystemSettings struct {
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -64,6 +65,7 @@ type PublicSettings struct {
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -140,6 +140,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
NewRedeemService,
|
||||
NewPromoService,
|
||||
NewUsageService,
|
||||
NewDashboardService,
|
||||
ProvidePricingService,
|
||||
|
||||
@@ -4,11 +4,38 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// PublicSettingsProvider is an interface to fetch public settings
|
||||
// This stub is needed for compilation when frontend is not embedded
|
||||
type PublicSettingsProvider interface {
|
||||
GetPublicSettingsForInjection(ctx context.Context) (any, error)
|
||||
}
|
||||
|
||||
// FrontendServer is a stub for non-embed builds
|
||||
type FrontendServer struct{}
|
||||
|
||||
// NewFrontendServer returns an error when frontend is not embedded
|
||||
func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer, error) {
|
||||
return nil, errors.New("frontend not embedded")
|
||||
}
|
||||
|
||||
// InvalidateCache is a no-op for non-embed builds
|
||||
func (s *FrontendServer) InvalidateCache() {}
|
||||
|
||||
// Middleware returns a handler that returns 404 for non-embed builds
|
||||
func (s *FrontendServer) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.String(http.StatusNotFound, "Frontend not embedded. Build with -tags embed to include frontend.")
|
||||
|
||||
@@ -3,11 +3,15 @@
|
||||
package web
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -15,6 +19,162 @@ import (
|
||||
//go:embed all:dist
|
||||
var frontendFS embed.FS
|
||||
|
||||
// PublicSettingsProvider is an interface to fetch public settings
|
||||
type PublicSettingsProvider interface {
|
||||
GetPublicSettingsForInjection(ctx context.Context) (any, error)
|
||||
}
|
||||
|
||||
// FrontendServer serves the embedded frontend with settings injection
|
||||
type FrontendServer struct {
|
||||
distFS fs.FS
|
||||
fileServer http.Handler
|
||||
baseHTML []byte
|
||||
cache *HTMLCache
|
||||
settings PublicSettingsProvider
|
||||
}
|
||||
|
||||
// NewFrontendServer creates a new frontend server with settings injection
|
||||
func NewFrontendServer(settingsProvider PublicSettingsProvider) (*FrontendServer, error) {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read base HTML once
|
||||
file, err := distFS.Open("index.html")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = file.Close() }()
|
||||
|
||||
baseHTML, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cache := NewHTMLCache()
|
||||
cache.SetBaseHTML(baseHTML)
|
||||
|
||||
return &FrontendServer{
|
||||
distFS: distFS,
|
||||
fileServer: http.FileServer(http.FS(distFS)),
|
||||
baseHTML: baseHTML,
|
||||
cache: cache,
|
||||
settings: settingsProvider,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InvalidateCache invalidates the HTML cache (call when settings change)
|
||||
func (s *FrontendServer) InvalidateCache() {
|
||||
if s != nil && s.cache != nil {
|
||||
s.cache.Invalidate()
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware returns the Gin middleware handler
|
||||
func (s *FrontendServer) Middleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
|
||||
// Skip API routes
|
||||
if strings.HasPrefix(path, "/api/") ||
|
||||
strings.HasPrefix(path, "/v1/") ||
|
||||
strings.HasPrefix(path, "/v1beta/") ||
|
||||
strings.HasPrefix(path, "/antigravity/") ||
|
||||
strings.HasPrefix(path, "/setup/") ||
|
||||
path == "/health" ||
|
||||
path == "/responses" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
cleanPath := strings.TrimPrefix(path, "/")
|
||||
if cleanPath == "" {
|
||||
cleanPath = "index.html"
|
||||
}
|
||||
|
||||
// For index.html or SPA routes, serve with injected settings
|
||||
if cleanPath == "index.html" || !s.fileExists(cleanPath) {
|
||||
s.serveIndexHTML(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Serve static files normally
|
||||
s.fileServer.ServeHTTP(c.Writer, c.Request)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *FrontendServer) fileExists(path string) bool {
|
||||
file, err := s.distFS.Open(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
_ = file.Close()
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
|
||||
// Check cache first
|
||||
cached := s.cache.Get()
|
||||
if cached != nil {
|
||||
// Check If-None-Match for 304 response
|
||||
if match := c.GetHeader("If-None-Match"); match == cached.ETag {
|
||||
c.Status(http.StatusNotModified)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Header("ETag", cached.ETag)
|
||||
c.Header("Cache-Control", "no-cache") // Must revalidate
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", cached.Content)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Cache miss - fetch settings and render
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
settings, err := s.settings.GetPublicSettingsForInjection(ctx)
|
||||
if err != nil {
|
||||
// Fallback: serve without injection
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", s.baseHTML)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
settingsJSON, err := json.Marshal(settings)
|
||||
if err != nil {
|
||||
// Fallback: serve without injection
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", s.baseHTML)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
rendered := s.injectSettings(settingsJSON)
|
||||
s.cache.Set(rendered, settingsJSON)
|
||||
|
||||
cached = s.cache.Get()
|
||||
if cached != nil {
|
||||
c.Header("ETag", cached.ETag)
|
||||
}
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Data(http.StatusOK, "text/html; charset=utf-8", rendered)
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
|
||||
// Create the script tag to inject
|
||||
script := []byte(`<script>window.__APP_CONFIG__=` + string(settingsJSON) + `;</script>`)
|
||||
|
||||
// Inject before </head>
|
||||
headClose := []byte("</head>")
|
||||
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
|
||||
}
|
||||
|
||||
// ServeEmbeddedFrontend returns a middleware for serving embedded frontend
|
||||
// This is the legacy function for backward compatibility when no settings provider is available
|
||||
func ServeEmbeddedFrontend() gin.HandlerFunc {
|
||||
distFS, err := fs.Sub(frontendFS, "dist")
|
||||
if err != nil {
|
||||
|
||||
77
backend/internal/web/html_cache.go
Normal file
77
backend/internal/web/html_cache.go
Normal file
@@ -0,0 +1,77 @@
|
||||
//go:build embed
|
||||
|
||||
package web
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HTMLCache manages the cached index.html with injected settings
|
||||
type HTMLCache struct {
|
||||
mu sync.RWMutex
|
||||
cachedHTML []byte
|
||||
etag string
|
||||
baseHTMLHash string // Hash of the original index.html (immutable after build)
|
||||
settingsVersion uint64 // Incremented when settings change
|
||||
}
|
||||
|
||||
// CachedHTML represents the cache state
|
||||
type CachedHTML struct {
|
||||
Content []byte
|
||||
ETag string
|
||||
}
|
||||
|
||||
// NewHTMLCache creates a new HTML cache instance
|
||||
func NewHTMLCache() *HTMLCache {
|
||||
return &HTMLCache{}
|
||||
}
|
||||
|
||||
// SetBaseHTML initializes the cache with the base HTML template
|
||||
func (c *HTMLCache) SetBaseHTML(baseHTML []byte) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
hash := sha256.Sum256(baseHTML)
|
||||
c.baseHTMLHash = hex.EncodeToString(hash[:8]) // First 8 bytes for brevity
|
||||
}
|
||||
|
||||
// Invalidate marks the cache as stale
|
||||
func (c *HTMLCache) Invalidate() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.settingsVersion++
|
||||
c.cachedHTML = nil
|
||||
c.etag = ""
|
||||
}
|
||||
|
||||
// Get returns the cached HTML or nil if cache is stale
|
||||
func (c *HTMLCache) Get() *CachedHTML {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
if c.cachedHTML == nil {
|
||||
return nil
|
||||
}
|
||||
return &CachedHTML{
|
||||
Content: c.cachedHTML,
|
||||
ETag: c.etag,
|
||||
}
|
||||
}
|
||||
|
||||
// Set updates the cache with new rendered HTML
|
||||
func (c *HTMLCache) Set(html []byte, settingsJSON []byte) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cachedHTML = html
|
||||
c.etag = c.generateETag(settingsJSON)
|
||||
}
|
||||
|
||||
// generateETag creates an ETag from base HTML hash + settings hash
|
||||
func (c *HTMLCache) generateETag(settingsJSON []byte) string {
|
||||
settingsHash := sha256.Sum256(settingsJSON)
|
||||
return `"` + c.baseHTMLHash + "-" + hex.EncodeToString(settingsHash[:8]) + `"`
|
||||
}
|
||||
Reference in New Issue
Block a user