diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 881f2e69..a5501181 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -74,7 +74,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { authService := service.NewAuthService(client, userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService, affiliateService) userService := service.NewUserService(userRepository, settingRepository, apiKeyAuthCacheInvalidator, billingCache) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator, affiliateService) secretEncryptor, err := repository.NewAESEncryptor(configConfig) if err != nil { return nil, err diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index b187b47f..2fef94f1 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -175,6 +175,10 @@ func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, return &user, nil } +func (s *stubAdminService) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) { + return len(userIDs), nil +} + func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]service.APIKey, int64, error) { return s.apiKeys, int64(len(s.apiKeys)), nil } diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 2ff94fe6..6485971a 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -998,17 +998,27 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") return } - if strings.TrimSpace(item.URL) == "" { - response.BadRequest(c, "Custom menu item URL is required") - return - } - if len(item.URL) > maxMenuItemURLLen { - response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") - return - } - if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { - response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") - return + urlTrimmed := strings.TrimSpace(item.URL) + if strings.HasPrefix(urlTrimmed, "md:") { + // Markdown page mode: URL = "md:" + slug := strings.TrimPrefix(urlTrimmed, "md:") + if slug == "" { + response.BadRequest(c, "Custom menu item markdown slug cannot be empty (use md:slug format)") + return + } + } else { + if urlTrimmed == "" { + response.BadRequest(c, "Custom menu item URL is required (use md:slug for markdown pages)") + return + } + if len(item.URL) > maxMenuItemURLLen { + response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") + return + } + if err := config.ValidateAbsoluteHTTPURL(urlTrimmed); err != nil { + response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL or md:") + return + } } if item.Visibility != "user" && item.Visibility != "admin" { response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index a297c56c..db35472e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -477,3 +477,63 @@ func (h *UserHandler) GetUserRPMStatus(c *gin.Context) { response.Success(c, status) } + +// BatchUpdateConcurrency 批量修改用户并发数 +// POST /api/v1/admin/users/batch-concurrency +type BatchUpdateConcurrencyRequest struct { + UserIDs []int64 `json:"user_ids"` + All bool `json:"all"` + Concurrency int `json:"concurrency"` + Mode string `json:"mode" binding:"required,oneof=set add"` +} + +func (h *UserHandler) BatchUpdateConcurrency(c *gin.Context) { + var req BatchUpdateConcurrencyRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.All && len(req.UserIDs) == 0 { + response.BadRequest(c, "user_ids is required unless all=true") + return + } + if len(req.UserIDs) > 500 { + response.BadRequest(c, "user_ids cannot exceed 500") + return + } + + var userIDs []int64 + if req.All { + // Fetch all user IDs via pagination + page := 1 + const pageSize = 500 + for { + users, _, err := h.adminService.ListUsers(c.Request.Context(), page, pageSize, service.UserListFilters{}, "id", "asc") + if err != nil { + response.ErrorFrom(c, err) + return + } + for _, u := range users { + userIDs = append(userIDs, u.ID) + } + if len(users) < pageSize { + break + } + page++ + } + } else { + userIDs = req.UserIDs + } + + if len(userIDs) == 0 { + response.Success(c, gin.H{"affected": 0}) + return + } + + affected, err := h.adminService.BatchUpdateConcurrency(c.Request.Context(), userIDs, req.Concurrency, req.Mode) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"affected": affected}) +} diff --git a/backend/internal/handler/auth_oauth_pending_flow_test.go b/backend/internal/handler/auth_oauth_pending_flow_test.go index ffe9ff5f..b598eae1 100644 --- a/backend/internal/handler/auth_oauth_pending_flow_test.go +++ b/backend/internal/handler/auth_oauth_pending_flow_test.go @@ -2798,6 +2798,14 @@ func (r *oauthPendingFlowUserRepo) UpdateConcurrency(context.Context, int64, int panic("unexpected UpdateConcurrency call") } +func (r *oauthPendingFlowUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { + panic("unexpected BatchSetConcurrency call") +} + +func (r *oauthPendingFlowUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { + panic("unexpected BatchAddConcurrency call") +} + func (r *oauthPendingFlowUserRepo) GetLatestUsedAtByUserIDs(context.Context, []int64) (map[int64]*time.Time, error) { return map[int64]*time.Time{}, nil } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index fba85cf2..261d9c78 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -11,6 +11,7 @@ type CustomMenuItem struct { Label string `json:"label"` IconSVG string `json:"icon_svg"` URL string `json:"url"` + PageSlug string `json:"page_slug,omitempty"` Visibility string `json:"visibility"` // "user" or "admin" SortOrder int `json:"sort_order"` } diff --git a/backend/internal/handler/page_handler.go b/backend/internal/handler/page_handler.go new file mode 100644 index 00000000..a3e4f5d2 --- /dev/null +++ b/backend/internal/handler/page_handler.go @@ -0,0 +1,215 @@ +package handler + +import ( + "encoding/json" + "net/http" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +var validSlugPattern = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_-]*$`) + +const maxPageFileSize = 1 << 20 // 1MB + +type PageHandler struct { + pagesDir string + settingService *service.SettingService +} + +func NewPageHandler(dataDir string, settingService *service.SettingService) *PageHandler { + pagesDir := filepath.Join(dataDir, "pages") + _ = os.MkdirAll(pagesDir, 0755) + return &PageHandler{pagesDir: pagesDir, settingService: settingService} +} + +// GetPageContent serves raw markdown content for a given slug. +// GET /api/v1/pages/:slug +func (h *PageHandler) GetPageContent(c *gin.Context) { + slug := c.Param("slug") + if !validSlugPattern.MatchString(slug) || len(slug) > 64 { + response.BadRequest(c, "Invalid page slug") + return + } + + // Visibility check: slug must be configured in custom_menu_items + // and the user must have permission based on visibility setting + if !h.checkSlugVisibility(c, slug) { + c.JSON(http.StatusNotFound, gin.H{"error": "page not found"}) + return + } + + filePath := filepath.Join(h.pagesDir, slug+".md") + cleaned := filepath.Clean(filePath) + if !strings.HasPrefix(cleaned, filepath.Clean(h.pagesDir)) { + response.BadRequest(c, "Invalid page slug") + return + } + + info, err := os.Stat(cleaned) + if err != nil || info.IsDir() { + c.JSON(http.StatusNotFound, gin.H{"error": "page not found"}) + return + } + if info.Size() > maxPageFileSize { + c.JSON(http.StatusRequestEntityTooLarge, gin.H{"error": "page too large"}) + return + } + + content, err := os.ReadFile(cleaned) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to read page"}) + return + } + + c.Data(http.StatusOK, "text/markdown; charset=utf-8", content) +} + +// ListPages returns available page slugs. +// GET /api/v1/pages +func (h *PageHandler) ListPages(c *gin.Context) { + entries, err := os.ReadDir(h.pagesDir) + if err != nil { + response.Success(c, []string{}) + return + } + + slugs := make([]string, 0, len(entries)) + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if strings.HasSuffix(name, ".md") { + slugs = append(slugs, strings.TrimSuffix(name, ".md")) + } + } + response.Success(c, slugs) +} + +// ServePageImage serves images from data/pages/{slug}/ directory. +// GET /api/v1/pages/:slug/images/*filename +// No JWT required (browser img tags can't carry tokens), but visibility is checked. +func (h *PageHandler) ServePageImage(c *gin.Context) { + slug := c.Param("slug") + filename := c.Param("filename") + filename = strings.TrimPrefix(filename, "/") + + if !validSlugPattern.MatchString(slug) || len(slug) > 64 { + c.Status(http.StatusNotFound) + return + } + + if !h.checkImageSlugVisibility(c, slug) { + c.Status(http.StatusNotFound) + return + } + + if filename == "" || strings.Contains(filename, "..") || strings.Contains(filename, "/") || strings.Contains(filename, "\\") { + c.Status(http.StatusNotFound) + return + } + + imagesDir := filepath.Join(h.pagesDir, slug) + filePath := filepath.Join(imagesDir, filename) + cleaned := filepath.Clean(filePath) + if !strings.HasPrefix(cleaned, filepath.Clean(imagesDir)) { + c.Status(http.StatusNotFound) + return + } + + info, err := os.Stat(cleaned) + if err != nil || info.IsDir() { + c.Status(http.StatusNotFound) + return + } + + c.File(cleaned) +} + +// findSlugVisibility looks up the slug in custom_menu_items and returns (visibility, found). +func (h *PageHandler) findSlugVisibility(c *gin.Context, slug string) (string, bool) { + if h.settingService == nil { + return "", false + } + + raw := h.settingService.GetCustomMenuItemsRaw(c.Request.Context()) + if raw == "" || raw == "[]" { + return "", false + } + + var items []struct { + URL string `json:"url"` + PageSlug string `json:"page_slug"` + Visibility string `json:"visibility"` + } + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return "", false + } + + for _, item := range items { + itemSlug := item.PageSlug + if itemSlug == "" && strings.HasPrefix(item.URL, "md:") { + itemSlug = strings.TrimPrefix(item.URL, "md:") + } + if itemSlug == slug { + return item.Visibility, true + } + } + return "", false +} + +// checkSlugVisibility verifies the slug is configured in custom_menu_items +// and the authenticated user has permission to view it. +func (h *PageHandler) checkSlugVisibility(c *gin.Context, slug string) bool { + visibility, found := h.findSlugVisibility(c, slug) + if !found { + return false + } + if visibility == "admin" { + role, _ := middleware2.GetUserRoleFromContext(c) + return role == "admin" + } + return true +} + +// checkImageSlugVisibility checks visibility for image requests (no JWT available). +// Only allows user-visible pages; admin-only pages are blocked. +func (h *PageHandler) checkImageSlugVisibility(c *gin.Context, slug string) bool { + visibility, found := h.findSlugVisibility(c, slug) + if !found { + return false + } + return visibility != "admin" +} + +// RegisterPageRoutes registers page routes on a router group. +func RegisterPageRoutes(v1 *gin.RouterGroup, dataDir string, jwtAuth gin.HandlerFunc, adminAuth gin.HandlerFunc, settingService *service.SettingService) { + h := NewPageHandler(dataDir, settingService) + + // Authenticated page content (JWT required + visibility check) + pages := v1.Group("/pages") + pages.Use(jwtAuth) + { + pages.GET("/:slug", h.GetPageContent) + } + + // Images: no JWT (browser img tags can't carry tokens), visibility check in handler + pageImages := v1.Group("/pages") + { + pageImages.GET("/:slug/images/*filename", h.ServePageImage) + } + + // Admin-only: list all available pages + adminPages := v1.Group("/pages") + adminPages.Use(adminAuth) + { + adminPages.GET("", h.ListPages) + } +} diff --git a/backend/internal/handler/user_handler_test.go b/backend/internal/handler/user_handler_test.go index 8a864b51..ffca86dc 100644 --- a/backend/internal/handler/user_handler_test.go +++ b/backend/internal/handler/user_handler_test.go @@ -87,6 +87,8 @@ func (s *userHandlerRepoStub) ListWithFilters(context.Context, pagination.Pagina func (s *userHandlerRepoStub) UpdateBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) DeductBalance(context.Context, int64, float64) error { return nil } func (s *userHandlerRepoStub) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (s *userHandlerRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userHandlerRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (s *userHandlerRepoStub) ExistsByEmail(context.Context, string) (bool, error) { return false, nil } func (s *userHandlerRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index d1f10cbd..1566756d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -737,6 +737,37 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } +func (r *userRepository) BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) { + if len(userIDs) == 0 { + return 0, nil + } + if value < 0 { + value = 0 + } + res, err := r.sql.ExecContext(ctx, + "UPDATE users SET concurrency = $1, updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL", + value, pq.Array(userIDs)) + if err != nil { + return 0, fmt.Errorf("batch set concurrency: %w", err) + } + affected, _ := res.RowsAffected() + return int(affected), nil +} + +func (r *userRepository) BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) { + if len(userIDs) == 0 { + return 0, nil + } + res, err := r.sql.ExecContext(ctx, + "UPDATE users SET concurrency = GREATEST(concurrency + $1, 0), updated_at = NOW() WHERE id = ANY($2) AND deleted_at IS NULL", + delta, pq.Array(userIDs)) + if err != nil { + return 0, fmt.Errorf("batch add concurrency: %w", err) + } + affected, _ := res.RowsAffected() + return int(affected), nil +} + func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { return r.client.User.Query().Where(userEmailLookupPredicate(email)).Exist(ctx) } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 37606d94..7de4e17b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1125,7 +1125,7 @@ func newContractDeps(t *testing.T) *contractDeps { subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) + redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil, nil) redeemHandler := handler.NewRedeemHandler(redeemService) settingRepo := newStubSettingRepo() @@ -1296,6 +1296,9 @@ func (r *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i return errors.New("not implemented") } +func (r *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (r *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (r *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { return false, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index dde92dfd..3fbbb716 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -198,6 +198,9 @@ func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount i panic("unexpected UpdateConcurrency call") } +func (s *stubUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *stubUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { panic("unexpected ExistsByEmail call") } diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index a507b6f8..f477f3a7 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -112,4 +112,6 @@ func registerRoutes( routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg) routes.RegisterPaymentRoutes(v1, h.Payment, h.PaymentWebhook, h.Admin.Payment, jwtAuth, adminAuth, settingService) + + handler.RegisterPageRoutes(v1, cfg.Pricing.DataDir, gin.HandlerFunc(jwtAuth), gin.HandlerFunc(adminAuth), settingService) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index a2d225e0..5eb0d34b 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -245,6 +245,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) + users.POST("/batch-concurrency", h.Admin.User.BatchUpdateConcurrency) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 793d60d8..eb5994d5 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -33,6 +33,7 @@ type AdminService interface { UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) DeleteUser(ctx context.Context, id int64) error UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) + BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) @@ -817,6 +818,39 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return nil } +func (s *adminServiceImpl) BatchUpdateConcurrency(ctx context.Context, userIDs []int64, value int, mode string) (int, error) { + cleaned := make([]int64, 0, len(userIDs)) + for _, uid := range userIDs { + if uid > 0 { + cleaned = append(cleaned, uid) + } + } + if len(cleaned) == 0 { + return 0, nil + } + + var affected int + var err error + switch mode { + case "set": + affected, err = s.userRepo.BatchSetConcurrency(ctx, cleaned, value) + case "add": + affected, err = s.userRepo.BatchAddConcurrency(ctx, cleaned, value) + default: + return 0, errors.New("invalid mode: must be 'set' or 'add'") + } + if err != nil { + return 0, err + } + + if s.authCacheInvalidator != nil { + for _, uid := range cleaned { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, uid) + } + } + return affected, nil +} + func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go index fcde5cbf..3b3dbc21 100644 --- a/backend/internal/service/admin_service_apikey_test.go +++ b/backend/internal/service/admin_service_apikey_test.go @@ -68,6 +68,9 @@ func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") } + +func (s *userRepoStubForGroupUpdate) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userRepoStubForGroupUpdate) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index fe9e7701..a9492a1d 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -131,6 +131,9 @@ func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount i panic("unexpected UpdateConcurrency call") } +func (s *userRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *userRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) { if s.existsErr != nil { return false, s.existsErr diff --git a/backend/internal/service/admin_service_email_identity_sync_test.go b/backend/internal/service/admin_service_email_identity_sync_test.go index 2232c9c3..c791b747 100644 --- a/backend/internal/service/admin_service_email_identity_sync_test.go +++ b/backend/internal/service/admin_service_email_identity_sync_test.go @@ -113,6 +113,9 @@ func (s *emailSyncRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) return 0, nil } +func (s *emailSyncRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailSyncRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *emailSyncRepoStub) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (s *emailSyncRepoStub) RemoveGroupFromUserAllowedGroups(context.Context, int64, int64) error { diff --git a/backend/internal/service/auth_service_email_bind_test.go b/backend/internal/service/auth_service_email_bind_test.go index ea2308f7..8f03f857 100644 --- a/backend/internal/service/auth_service_email_bind_test.go +++ b/backend/internal/service/auth_service_email_bind_test.go @@ -820,6 +820,9 @@ func (s *emailBindUserRepoStub) ExistsByEmail(_ context.Context, email string) ( return ok, nil } +func (s *emailBindUserRepoStub) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (s *emailBindUserRepoStub) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } + func (s *emailBindUserRepoStub) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } diff --git a/backend/internal/service/payment_fulfillment.go b/backend/internal/service/payment_fulfillment.go index 4ae6d134..f96684a4 100644 --- a/backend/internal/service/payment_fulfillment.go +++ b/backend/internal/service/payment_fulfillment.go @@ -282,7 +282,7 @@ func (s *PaymentService) doBalance(ctx context.Context, o *dbent.PaymentOrder) e case redeemActionRedeem: // Code exists but unused — skip creation, proceed to redeem } - if _, err := s.redeemService.Redeem(ctx, o.UserID, o.RechargeCode); err != nil { + if _, err := s.redeemService.Redeem(ContextSkipRedeemAffiliate(ctx), o.UserID, o.RechargeCode); err != nil { return fmt.Errorf("redeem balance: %w", err) } if err := s.applyAffiliateRebateForOrder(ctx, o); err != nil { diff --git a/backend/internal/service/payment_order_lifecycle_test.go b/backend/internal/service/payment_order_lifecycle_test.go index 8dfd2e7e..d8595715 100644 --- a/backend/internal/service/payment_order_lifecycle_test.go +++ b/backend/internal/service/payment_order_lifecycle_test.go @@ -208,6 +208,7 @@ func TestVerifyOrderByOutTradeNoBackfillsTradeNoFromPaidQuery(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -308,6 +309,7 @@ func TestVerifyOrderByOutTradeNoRetriesZeroAmountPaidQueryOnce(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -398,6 +400,7 @@ func TestVerifyOrderByOutTradeNoRejectsPaidQueryWithZeroAmount(t *testing.T) { nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ @@ -496,6 +499,7 @@ func TestVerifyOrderByOutTradeNoUsesOutTradeNoWhenPaymentTradeNoAlreadyExistsFor nil, client, nil, + nil, ) registry := payment.NewRegistry() provider := &paymentOrderLifecycleQueryProvider{ diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 9ced6201..dcf293c5 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -11,6 +11,7 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) @@ -28,6 +29,15 @@ const ( redeemLockDuration = 10 * time.Second // 锁超时时间,防止死锁 ) +type ctxKeySkipRedeemAffiliate struct{} + +// ContextSkipRedeemAffiliate returns a context that suppresses the redeem-level +// affiliate rebate. Used by payment fulfillment which handles rebate separately +// via applyAffiliateRebateForOrder (with audit-log deduplication). +func ContextSkipRedeemAffiliate(ctx context.Context) context.Context { + return context.WithValue(ctx, ctxKeySkipRedeemAffiliate{}, true) +} + // RedeemCache defines cache operations for redeem service type RedeemCache interface { GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) @@ -80,6 +90,7 @@ type RedeemService struct { billingCacheService *BillingCacheService entClient *dbent.Client authCacheInvalidator APIKeyAuthCacheInvalidator + affiliateService *AffiliateService } // NewRedeemService 创建兑换码服务实例 @@ -91,6 +102,7 @@ func NewRedeemService( billingCacheService *BillingCacheService, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator, + affiliateService *AffiliateService, ) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, @@ -100,6 +112,7 @@ func NewRedeemService( billingCacheService: billingCacheService, entClient: entClient, authCacheInvalidator: authCacheInvalidator, + affiliateService: affiliateService, } } @@ -369,6 +382,11 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // 事务提交成功后失效缓存 s.invalidateRedeemCaches(ctx, userID, redeemCode) + // 余额类正数兑换码触发邀请返利(best-effort,失败不影响兑换结果) + if redeemCode.Type == RedeemTypeBalance && redeemCode.Value > 0 { + s.tryAccrueAffiliateRebateForRedeem(ctx, userID, redeemCode.Value) + } + // 重新获取更新后的兑换码 redeemCode, err = s.redeemRepo.GetByID(ctx, redeemCode.ID) if err != nil { @@ -418,6 +436,26 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64 } } +func (s *RedeemService) tryAccrueAffiliateRebateForRedeem(ctx context.Context, userID int64, amount float64) { + if ctx.Value(ctxKeySkipRedeemAffiliate{}) != nil { + return + } + if s.affiliateService == nil { + return + } + if !s.affiliateService.IsEnabled(ctx) { + return + } + rebate, err := s.affiliateService.AccrueInviteRebate(ctx, userID, amount) + if err != nil { + logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate failed for user %d amount %.2f: %v", userID, amount, err) + return + } + if rebate > 0 { + logger.LegacyPrintf("service.redeem", "[Redeem] affiliate rebate accrued %.8f for inviter of user %d", rebate, userID) + } +} + // GetByID 根据ID获取兑换码 func (s *RedeemService) GetByID(ctx context.Context, id int64) (*RedeemCode, error) { code, err := s.redeemRepo.GetByID(ctx, id) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index bf6294db..c0b7816a 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -1542,6 +1542,15 @@ func (s *SettingService) IsInvitationCodeEnabled(ctx context.Context) bool { return value == "true" } +// GetCustomMenuItemsRaw returns the raw JSON string of custom_menu_items setting. +func (s *SettingService) GetCustomMenuItemsRaw(ctx context.Context) string { + value, err := s.settingRepo.GetValue(ctx, SettingKeyCustomMenuItems) + if err != nil { + return "[]" + } + return value +} + // IsAffiliateEnabled 检查是否启用邀请返利功能(总开关) func (s *SettingService) IsAffiliateEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyAffiliateEnabled) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7279e6a..f84e6f0a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -96,6 +96,8 @@ type UserRepository interface { UpdateBalance(ctx context.Context, id int64, amount float64) error DeductBalance(ctx context.Context, id int64, amount float64) error UpdateConcurrency(ctx context.Context, id int64, amount int) error + BatchSetConcurrency(ctx context.Context, userIDs []int64, value int) (int, error) + BatchAddConcurrency(ctx context.Context, userIDs []int64, delta int) (int, error) ExistsByEmail(ctx context.Context, email string) (bool, error) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) // AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略) diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index ff55c2a5..775dd602 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -199,6 +199,9 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } + +func (m *mockUserRepo) BatchSetConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } +func (m *mockUserRepo) BatchAddConcurrency(context.Context, []int64, int) (int, error) { return 0, nil } func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } func (m *mockUserRepo) ListUserAuthIdentities(context.Context, int64) ([]UserAuthIdentityRecord, error) { out := make([]UserAuthIdentityRecord, len(m.identities)) diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 727b9436..26e1681b 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -168,6 +168,7 @@ export interface CustomMenuItem { label: string icon_svg: string url: string + page_slug?: string visibility: 'user' | 'admin' sort_order: number } diff --git a/frontend/src/views/user/CustomPageView.vue b/frontend/src/views/user/CustomPageView.vue index ce930d96..0752d5e3 100644 --- a/frontend/src/views/user/CustomPageView.vue +++ b/frontend/src/views/user/CustomPageView.vue @@ -27,6 +27,56 @@ + +
+ + + + + + + +
+
+ +
+