diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 70540e35..ae000119 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -15,21 +15,12 @@ import ( "syscall" "time" - "sub2api/internal/config" "sub2api/internal/handler" "sub2api/internal/middleware" - "sub2api/internal/model" - "sub2api/internal/pkg/timezone" - "sub2api/internal/repository" - "sub2api/internal/service" "sub2api/internal/setup" "sub2api/internal/web" "github.com/gin-gonic/gin" - "github.com/redis/go-redis/v9" - "gorm.io/driver/postgres" - "gorm.io/gorm" - "gorm.io/gorm/logger" ) //go:embed VERSION @@ -149,319 +140,3 @@ func runMainServer() { log.Println("Server exited") } - -func initDB(cfg *config.Config) (*gorm.DB, error) { - // 初始化时区(在数据库连接之前,确保时区设置正确) - if err := timezone.Init(cfg.Timezone); err != nil { - return nil, err - } - - gormConfig := &gorm.Config{} - if cfg.Server.Mode == "debug" { - gormConfig.Logger = logger.Default.LogMode(logger.Info) - } - - // 使用带时区的 DSN 连接数据库 - db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig) - if err != nil { - return nil, err - } - - // 自动迁移(始终执行,确保数据库结构与代码同步) - // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 - if err := model.AutoMigrate(db); err != nil { - return nil, err - } - - return db, nil -} - -func initRedis(cfg *config.Config) *redis.Client { - return redis.NewClient(&redis.Options{ - Addr: cfg.Redis.Address(), - Password: cfg.Redis.Password, - DB: cfg.Redis.DB, - }) -} - -func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) { - // 健康检查 - r.GET("/health", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"status": "ok"}) - }) - - // Setup status endpoint (always returns needs_setup: false in normal mode) - // This is used by the frontend to detect when the service has restarted after setup - r.GET("/setup/status", func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{ - "code": 0, - "data": gin.H{ - "needs_setup": false, - "step": "completed", - }, - }) - }) - - // API v1 - v1 := r.Group("/api/v1") - { - // 公开接口 - auth := v1.Group("/auth") - { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) - } - - // 公开设置(无需认证) - settings := v1.Group("/settings") - { - settings.GET("/public", h.Setting.GetPublicSettings) - } - - // 需要认证的接口 - authenticated := v1.Group("") - authenticated.Use(middleware.JWTAuth(s.Auth, repos.User)) - { - // 当前用户信息 - authenticated.GET("/auth/me", h.Auth.GetCurrentUser) - - // 用户接口 - user := authenticated.Group("/users/me") - { - user.GET("", h.User.GetProfile) - user.POST("/password", h.User.ChangePassword) - } - - // API Key管理 - keys := authenticated.Group("/keys") - { - keys.GET("", h.APIKey.List) - keys.GET("/:id", h.APIKey.GetByID) - keys.POST("", h.APIKey.Create) - keys.PUT("/:id", h.APIKey.Update) - keys.DELETE("/:id", h.APIKey.Delete) - } - - // 用户可用分组(非管理员接口) - groups := authenticated.Group("/groups") - { - groups.GET("/available", h.APIKey.GetAvailableGroups) - } - - // 使用记录 - usage := authenticated.Group("/usage") - { - usage.GET("", h.Usage.List) - usage.GET("/:id", h.Usage.GetByID) - usage.GET("/stats", h.Usage.Stats) - // User dashboard endpoints - usage.GET("/dashboard/stats", h.Usage.DashboardStats) - usage.GET("/dashboard/trend", h.Usage.DashboardTrend) - usage.GET("/dashboard/models", h.Usage.DashboardModels) - usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) - } - - // 卡密兑换 - redeem := authenticated.Group("/redeem") - { - redeem.POST("", h.Redeem.Redeem) - redeem.GET("/history", h.Redeem.GetHistory) - } - - // 用户订阅 - subscriptions := authenticated.Group("/subscriptions") - { - subscriptions.GET("", h.Subscription.List) - subscriptions.GET("/active", h.Subscription.GetActive) - subscriptions.GET("/progress", h.Subscription.GetProgress) - subscriptions.GET("/summary", h.Subscription.GetSummary) - } - } - - // 管理员接口 - admin := v1.Group("/admin") - admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly()) - { - // 仪表盘 - dashboard := admin.Group("/dashboard") - { - dashboard.GET("/stats", h.Admin.Dashboard.GetStats) - dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) - dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) - dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) - dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) - dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) - dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) - dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) - } - - // 用户管理 - users := admin.Group("/users") - { - users.GET("", h.Admin.User.List) - users.GET("/:id", h.Admin.User.GetByID) - users.POST("", h.Admin.User.Create) - users.PUT("/:id", h.Admin.User.Update) - users.DELETE("/:id", h.Admin.User.Delete) - users.POST("/:id/balance", h.Admin.User.UpdateBalance) - users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) - users.GET("/:id/usage", h.Admin.User.GetUserUsage) - } - - // 分组管理 - groups := admin.Group("/groups") - { - groups.GET("", h.Admin.Group.List) - groups.GET("/all", h.Admin.Group.GetAll) - groups.GET("/:id", h.Admin.Group.GetByID) - groups.POST("", h.Admin.Group.Create) - groups.PUT("/:id", h.Admin.Group.Update) - groups.DELETE("/:id", h.Admin.Group.Delete) - groups.GET("/:id/stats", h.Admin.Group.GetStats) - groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) - } - - // 账号管理 - accounts := admin.Group("/accounts") - { - accounts.GET("", h.Admin.Account.List) - accounts.GET("/:id", h.Admin.Account.GetByID) - accounts.POST("", h.Admin.Account.Create) - accounts.PUT("/:id", h.Admin.Account.Update) - accounts.DELETE("/:id", h.Admin.Account.Delete) - accounts.POST("/:id/test", h.Admin.Account.Test) - accounts.POST("/:id/refresh", h.Admin.Account.Refresh) - accounts.GET("/:id/stats", h.Admin.Account.GetStats) - accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) - accounts.GET("/:id/usage", h.Admin.Account.GetUsage) - accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) - accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) - accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) - accounts.POST("/batch", h.Admin.Account.BatchCreate) - - // OAuth routes - accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) - accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL) - accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode) - accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode) - accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth) - accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth) - } - - // 代理管理 - proxies := admin.Group("/proxies") - { - proxies.GET("", h.Admin.Proxy.List) - proxies.GET("/all", h.Admin.Proxy.GetAll) - proxies.GET("/:id", h.Admin.Proxy.GetByID) - proxies.POST("", h.Admin.Proxy.Create) - proxies.PUT("/:id", h.Admin.Proxy.Update) - proxies.DELETE("/:id", h.Admin.Proxy.Delete) - proxies.POST("/:id/test", h.Admin.Proxy.Test) - proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) - proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) - proxies.POST("/batch", h.Admin.Proxy.BatchCreate) - } - - // 卡密管理 - codes := admin.Group("/redeem-codes") - { - codes.GET("", h.Admin.Redeem.List) - codes.GET("/stats", h.Admin.Redeem.GetStats) - codes.GET("/export", h.Admin.Redeem.Export) - codes.GET("/:id", h.Admin.Redeem.GetByID) - codes.POST("/generate", h.Admin.Redeem.Generate) - codes.DELETE("/:id", h.Admin.Redeem.Delete) - codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) - codes.POST("/:id/expire", h.Admin.Redeem.Expire) - } - - // 系统设置 - adminSettings := admin.Group("/settings") - { - adminSettings.GET("", h.Admin.Setting.GetSettings) - adminSettings.PUT("", h.Admin.Setting.UpdateSettings) - adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) - adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) - } - - // 系统管理 - system := admin.Group("/system") - { - system.GET("/version", h.Admin.System.GetVersion) - system.GET("/check-updates", h.Admin.System.CheckUpdates) - system.POST("/update", h.Admin.System.PerformUpdate) - system.POST("/rollback", h.Admin.System.Rollback) - system.POST("/restart", h.Admin.System.RestartService) - } - - // 订阅管理 - subscriptions := admin.Group("/subscriptions") - { - subscriptions.GET("", h.Admin.Subscription.List) - subscriptions.GET("/:id", h.Admin.Subscription.GetByID) - subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress) - subscriptions.POST("/assign", h.Admin.Subscription.Assign) - subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) - subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) - subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) - } - - // 分组下的订阅列表 - admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup) - - // 用户下的订阅列表 - admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser) - - // 使用记录管理 - usage := admin.Group("/usage") - { - usage.GET("", h.Admin.Usage.List) - usage.GET("/stats", h.Admin.Usage.Stats) - usage.GET("/search-users", h.Admin.Usage.SearchUsers) - usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) - } - } - } - - // API网关(Claude API兼容) - gateway := r.Group("/v1") - gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription)) - { - gateway.POST("/messages", h.Gateway.Messages) - gateway.GET("/models", h.Gateway.Models) - gateway.GET("/usage", h.Gateway.Usage) - } -} - -// setupRouter 配置路由器中间件和路由 -func setupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { - // 应用中间件 - r.Use(middleware.Logger()) - r.Use(middleware.CORS()) - - // 注册路由 - registerRoutes(r, handlers, services, repos) - - // Serve embedded frontend if available - if web.HasEmbeddedFrontend() { - r.Use(web.ServeEmbeddedFrontend()) - } - - return r -} - -// createHTTPServer 创建HTTP服务器 -func createHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { - return &http.Server{ - Addr: cfg.Server.Address(), - Handler: router, - // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 - ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, - // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 - IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second, - // 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟 - // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 - } -} diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 91cc131a..9e9f8677 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -6,12 +6,16 @@ package main import ( "sub2api/internal/config" "sub2api/internal/handler" + "sub2api/internal/infrastructure" "sub2api/internal/repository" + "sub2api/internal/server" "sub2api/internal/service" + "context" + "log" "net/http" + "time" - "github.com/gin-gonic/gin" "github.com/google/wire" "github.com/redis/go-redis/v9" "gorm.io/gorm" @@ -24,80 +28,76 @@ type Application struct { func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { wire.Build( - // Config provider - provideConfig, + // 基础设施层 ProviderSets + config.ProviderSet, + infrastructure.ProviderSet, - // Database provider - provideDB, + // 业务层 ProviderSets + repository.ProviderSet, + service.ProviderSet, + handler.ProviderSet, - // Redis provider - provideRedis, + // 服务器层 ProviderSet + server.ProviderSet, - // Repository provider - provideRepositories, - - // Service provider - provideServices, - - // Handler provider - provideHandlers, - - // Router provider - provideRouter, - - // HTTP Server provider - provideHTTPServer, - - // Cleanup provider + // 清理函数提供者 provideCleanup, - // Application provider + // 应用程序结构体 wire.Struct(new(Application), "Server", "Cleanup"), ) return nil, nil } -func provideConfig() (*config.Config, error) { - return config.Load() -} - -func provideDB(cfg *config.Config) (*gorm.DB, error) { - return initDB(cfg) -} - -func provideRedis(cfg *config.Config) *redis.Client { - return initRedis(cfg) -} - -func provideRepositories(db *gorm.DB) *repository.Repositories { - return repository.NewRepositories(db) -} - -func provideServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *service.Services { - return service.NewServices(repos, rdb, cfg) -} - -func provideHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo handler.BuildInfo) *handler.Handlers { - return handler.NewHandlers(services, repos, rdb, buildInfo) -} - -func provideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { - if cfg.Server.Mode == "release" { - gin.SetMode(gin.ReleaseMode) - } - - r := gin.New() - r.Use(gin.Recovery()) - - return setupRouter(r, cfg, handlers, services, repos) -} - -func provideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { - return createHTTPServer(cfg, router) -} - -func provideCleanup() func() { +func provideCleanup( + db *gorm.DB, + rdb *redis.Client, + services *service.Services, +) func() { return func() { - // @todo + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Cleanup steps in reverse dependency order + cleanupSteps := []struct { + name string + fn func() error + }{ + {"PricingService", func() error { + services.Pricing.Stop() + return nil + }}, + {"EmailQueueService", func() error { + services.EmailQueue.Stop() + return nil + }}, + {"Redis", func() error { + return rdb.Close() + }}, + {"Database", func() error { + sqlDB, err := db.DB() + if err != nil { + return err + } + return sqlDB.Close() + }}, + } + + for _, step := range cleanupSteps { + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + // Continue with remaining cleanup steps even if one fails + } else { + log.Printf("[Cleanup] %s succeeded", step.name) + } + } + + // Check if context timed out + select { + case <-ctx.Done(): + log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") + default: + log.Printf("[Cleanup] All cleanup steps completed") + } } } diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index e0d46e59..d7783f40 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -7,14 +7,19 @@ package main import ( - "github.com/gin-gonic/gin" + "context" "github.com/redis/go-redis/v9" "gorm.io/gorm" + "log" "net/http" "sub2api/internal/config" "sub2api/internal/handler" + "sub2api/internal/handler/admin" + "sub2api/internal/infrastructure" "sub2api/internal/repository" + "sub2api/internal/server" "sub2api/internal/service" + "time" ) import ( @@ -24,23 +29,114 @@ import ( // Injectors from wire.go: func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { - config, err := provideConfig() + configConfig, err := config.ProvideConfig() if err != nil { return nil, err } - db, err := provideDB(config) + db, err := infrastructure.ProvideDB(configConfig) if err != nil { return nil, err } - repositories := provideRepositories(db) - client := provideRedis(config) - services := provideServices(repositories, client, config) - handlers := provideHandlers(services, repositories, client, buildInfo) - engine := provideRouter(config, handlers, services, repositories) - server := provideHTTPServer(config, engine) - v := provideCleanup() + userRepository := repository.NewUserRepository(db) + settingRepository := repository.NewSettingRepository(db) + settingService := service.NewSettingService(settingRepository, configConfig) + client := infrastructure.ProvideRedis(configConfig) + emailService := service.NewEmailService(settingRepository, client) + turnstileService := service.NewTurnstileService(settingService) + emailQueueService := service.ProvideEmailQueueService(emailService) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) + authHandler := handler.NewAuthHandler(authService) + userService := service.NewUserService(userRepository, configConfig) + userHandler := handler.NewUserHandler(userService) + apiKeyRepository := repository.NewApiKeyRepository(db) + groupRepository := repository.NewGroupRepository(db) + userSubscriptionRepository := repository.NewUserSubscriptionRepository(db) + apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, client, configConfig) + apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) + usageLogRepository := repository.NewUsageLogRepository(db) + usageService := service.NewUsageService(usageLogRepository, userRepository) + usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) + redeemCodeRepository := repository.NewRedeemCodeRepository(db) + accountRepository := repository.NewAccountRepository(db) + proxyRepository := repository.NewProxyRepository(db) + repositories := &repository.Repositories{ + User: userRepository, + ApiKey: apiKeyRepository, + Group: groupRepository, + Account: accountRepository, + Proxy: proxyRepository, + RedeemCode: redeemCodeRepository, + UsageLog: usageLogRepository, + Setting: settingRepository, + UserSubscription: userSubscriptionRepository, + } + billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository) + subscriptionService := service.NewSubscriptionService(repositories, billingCacheService) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService) + redeemHandler := handler.NewRedeemHandler(redeemService) + subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) + adminService := service.NewAdminService(repositories, billingCacheService) + dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository) + adminUserHandler := admin.NewUserHandler(adminService) + groupHandler := admin.NewGroupHandler(adminService) + oAuthService := service.NewOAuthService(proxyRepository) + rateLimitService := service.NewRateLimitService(repositories, configConfig) + accountUsageService := service.NewAccountUsageService(repositories, oAuthService) + accountTestService := service.NewAccountTestService(repositories, oAuthService) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService) + oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService) + proxyHandler := admin.NewProxyHandler(adminService) + adminRedeemHandler := admin.NewRedeemHandler(adminService) + settingHandler := admin.NewSettingHandler(settingService, emailService) + systemHandler := handler.ProvideSystemHandler(client, buildInfo) + adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) + adminUsageHandler := admin.NewUsageHandler(usageLogRepository, apiKeyRepository, usageService, adminService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, oAuthHandler, proxyHandler, adminRedeemHandler, settingHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler) + pricingService, err := service.ProvidePricingService(configConfig) + if err != nil { + return nil, err + } + billingService := service.NewBillingService(configConfig, pricingService) + identityService := service.NewIdentityService(client) + gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) + concurrencyService := service.NewConcurrencyService(client) + gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) + handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, adminHandlers, gatewayHandler, handlerSettingHandler) + groupService := service.NewGroupService(groupRepository) + accountService := service.NewAccountService(accountRepository, groupRepository) + proxyService := service.NewProxyService(proxyRepository) + services := &service.Services{ + Auth: authService, + User: userService, + ApiKey: apiKeyService, + Group: groupService, + Account: accountService, + Proxy: proxyService, + Redeem: redeemService, + Usage: usageService, + Pricing: pricingService, + Billing: billingService, + BillingCache: billingCacheService, + Admin: adminService, + Gateway: gatewayService, + OAuth: oAuthService, + RateLimit: rateLimitService, + AccountUsage: accountUsageService, + AccountTest: accountTestService, + Setting: settingService, + Email: emailService, + EmailQueue: emailQueueService, + Turnstile: turnstileService, + Subscription: subscriptionService, + Concurrency: concurrencyService, + Identity: identityService, + } + engine := server.ProvideRouter(configConfig, handlers, services, repositories) + httpServer := server.ProvideHTTPServer(configConfig, engine) + v := provideCleanup(db, client, services) application := &Application{ - Server: server, + Server: httpServer, Cleanup: v, } return application, nil @@ -53,47 +149,53 @@ type Application struct { Cleanup func() } -func provideConfig() (*config.Config, error) { - return config.Load() -} - -func provideDB(cfg *config.Config) (*gorm.DB, error) { - return initDB(cfg) -} - -func provideRedis(cfg *config.Config) *redis.Client { - return initRedis(cfg) -} - -func provideRepositories(db *gorm.DB) *repository.Repositories { - return repository.NewRepositories(db) -} - -func provideServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *service.Services { - return service.NewServices(repos, rdb, cfg) -} - -func provideHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo handler.BuildInfo) *handler.Handlers { - return handler.NewHandlers(services, repos, rdb, buildInfo) -} - -func provideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { - if cfg.Server.Mode == "release" { - gin.SetMode(gin.ReleaseMode) - } - - r := gin.New() - r.Use(gin.Recovery()) - - return setupRouter(r, cfg, handlers, services, repos) -} - -func provideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { - return createHTTPServer(cfg, router) -} - -func provideCleanup() func() { +func provideCleanup( + db *gorm.DB, + rdb *redis.Client, + services *service.Services, +) func() { return func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + cleanupSteps := []struct { + name string + fn func() error + }{ + {"PricingService", func() error { + services.Pricing.Stop() + return nil + }}, + {"EmailQueueService", func() error { + services.EmailQueue.Stop() + return nil + }}, + {"Redis", func() error { + return rdb.Close() + }}, + {"Database", func() error { + sqlDB, err := db.DB() + if err != nil { + return err + } + return sqlDB.Close() + }}, + } + + for _, step := range cleanupSteps { + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + + } else { + log.Printf("[Cleanup] %s succeeded", step.name) + } + } + + select { + case <-ctx.Done(): + log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") + default: + log.Printf("[Cleanup] All cleanup steps completed") + } } } diff --git a/backend/internal/config/wire.go b/backend/internal/config/wire.go new file mode 100644 index 00000000..ec26c401 --- /dev/null +++ b/backend/internal/config/wire.go @@ -0,0 +1,13 @@ +package config + +import "github.com/google/wire" + +// ProviderSet 提供配置层的依赖 +var ProviderSet = wire.NewSet( + ProvideConfig, +) + +// ProvideConfig 提供应用配置 +func ProvideConfig() (*Config, error) { + return Load() +} diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index 911a2434..830bd5e8 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -2,10 +2,6 @@ package handler import ( "sub2api/internal/handler/admin" - "sub2api/internal/repository" - "sub2api/internal/service" - - "github.com/redis/go-redis/v9" ) // AdminHandlers contains all admin-related HTTP handlers @@ -41,30 +37,3 @@ type BuildInfo struct { Version string BuildType string // "source" for manual builds, "release" for CI builds } - -// NewHandlers creates a new Handlers instance with all handlers initialized -func NewHandlers(services *service.Services, repos *repository.Repositories, rdb *redis.Client, buildInfo BuildInfo) *Handlers { - return &Handlers{ - Auth: NewAuthHandler(services.Auth), - User: NewUserHandler(services.User), - APIKey: NewAPIKeyHandler(services.ApiKey), - Usage: NewUsageHandler(services.Usage, repos.UsageLog, services.ApiKey), - Redeem: NewRedeemHandler(services.Redeem), - Subscription: NewSubscriptionHandler(services.Subscription), - Admin: &AdminHandlers{ - Dashboard: admin.NewDashboardHandler(services.Admin, repos.UsageLog), - User: admin.NewUserHandler(services.Admin), - Group: admin.NewGroupHandler(services.Admin), - Account: admin.NewAccountHandler(services.Admin, services.OAuth, services.RateLimit, services.AccountUsage, services.AccountTest), - OAuth: admin.NewOAuthHandler(services.OAuth, services.Admin), - Proxy: admin.NewProxyHandler(services.Admin), - Redeem: admin.NewRedeemHandler(services.Admin), - Setting: admin.NewSettingHandler(services.Setting, services.Email), - System: admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType), - Subscription: admin.NewSubscriptionHandler(services.Subscription), - Usage: admin.NewUsageHandler(repos.UsageLog, repos.ApiKey, services.Usage, services.Admin), - }, - Gateway: NewGatewayHandler(services.Gateway, services.User, services.Concurrency, services.BillingCache), - Setting: NewSettingHandler(services.Setting, buildInfo.Version), - } -} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go new file mode 100644 index 00000000..02e182e8 --- /dev/null +++ b/backend/internal/handler/wire.go @@ -0,0 +1,103 @@ +package handler + +import ( + "sub2api/internal/handler/admin" + "sub2api/internal/service" + + "github.com/google/wire" + "github.com/redis/go-redis/v9" +) + +// ProvideAdminHandlers creates the AdminHandlers struct +func ProvideAdminHandlers( + dashboardHandler *admin.DashboardHandler, + userHandler *admin.UserHandler, + groupHandler *admin.GroupHandler, + accountHandler *admin.AccountHandler, + oauthHandler *admin.OAuthHandler, + proxyHandler *admin.ProxyHandler, + redeemHandler *admin.RedeemHandler, + settingHandler *admin.SettingHandler, + systemHandler *admin.SystemHandler, + subscriptionHandler *admin.SubscriptionHandler, + usageHandler *admin.UsageHandler, +) *AdminHandlers { + return &AdminHandlers{ + Dashboard: dashboardHandler, + User: userHandler, + Group: groupHandler, + Account: accountHandler, + OAuth: oauthHandler, + Proxy: proxyHandler, + Redeem: redeemHandler, + Setting: settingHandler, + System: systemHandler, + Subscription: subscriptionHandler, + Usage: usageHandler, + } +} + +// ProvideSystemHandler creates admin.SystemHandler with BuildInfo parameters +func ProvideSystemHandler(rdb *redis.Client, buildInfo BuildInfo) *admin.SystemHandler { + return admin.NewSystemHandler(rdb, buildInfo.Version, buildInfo.BuildType) +} + +// ProvideSettingHandler creates SettingHandler with version from BuildInfo +func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler { + return NewSettingHandler(settingService, buildInfo.Version) +} + +// ProvideHandlers creates the Handlers struct +func ProvideHandlers( + authHandler *AuthHandler, + userHandler *UserHandler, + apiKeyHandler *APIKeyHandler, + usageHandler *UsageHandler, + redeemHandler *RedeemHandler, + subscriptionHandler *SubscriptionHandler, + adminHandlers *AdminHandlers, + gatewayHandler *GatewayHandler, + settingHandler *SettingHandler, +) *Handlers { + return &Handlers{ + Auth: authHandler, + User: userHandler, + APIKey: apiKeyHandler, + Usage: usageHandler, + Redeem: redeemHandler, + Subscription: subscriptionHandler, + Admin: adminHandlers, + Gateway: gatewayHandler, + Setting: settingHandler, + } +} + +// ProviderSet is the Wire provider set for all handlers +var ProviderSet = wire.NewSet( + // Top-level handlers + NewAuthHandler, + NewUserHandler, + NewAPIKeyHandler, + NewUsageHandler, + NewRedeemHandler, + NewSubscriptionHandler, + NewGatewayHandler, + ProvideSettingHandler, + + // Admin handlers + admin.NewDashboardHandler, + admin.NewUserHandler, + admin.NewGroupHandler, + admin.NewAccountHandler, + admin.NewOAuthHandler, + admin.NewProxyHandler, + admin.NewRedeemHandler, + admin.NewSettingHandler, + ProvideSystemHandler, + admin.NewSubscriptionHandler, + admin.NewUsageHandler, + + // AdminHandlers and Handlers constructors + ProvideAdminHandlers, + ProvideHandlers, +) diff --git a/backend/internal/infrastructure/database.go b/backend/internal/infrastructure/database.go new file mode 100644 index 00000000..a6dad85f --- /dev/null +++ b/backend/internal/infrastructure/database.go @@ -0,0 +1,38 @@ +package infrastructure + +import ( + "sub2api/internal/config" + "sub2api/internal/model" + "sub2api/internal/pkg/timezone" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// InitDB 初始化数据库连接 +func InitDB(cfg *config.Config) (*gorm.DB, error) { + // 初始化时区(在数据库连接之前,确保时区设置正确) + if err := timezone.Init(cfg.Timezone); err != nil { + return nil, err + } + + gormConfig := &gorm.Config{} + if cfg.Server.Mode == "debug" { + gormConfig.Logger = logger.Default.LogMode(logger.Info) + } + + // 使用带时区的 DSN 连接数据库 + db, err := gorm.Open(postgres.Open(cfg.Database.DSNWithTimezone(cfg.Timezone)), gormConfig) + if err != nil { + return nil, err + } + + // 自动迁移(始终执行,确保数据库结构与代码同步) + // GORM 的 AutoMigrate 只会添加新字段,不会删除或修改已有字段,是安全的 + if err := model.AutoMigrate(db); err != nil { + return nil, err + } + + return db, nil +} diff --git a/backend/internal/infrastructure/redis.go b/backend/internal/infrastructure/redis.go new file mode 100644 index 00000000..77c9fee0 --- /dev/null +++ b/backend/internal/infrastructure/redis.go @@ -0,0 +1,16 @@ +package infrastructure + +import ( + "sub2api/internal/config" + + "github.com/redis/go-redis/v9" +) + +// InitRedis 初始化 Redis 客户端 +func InitRedis(cfg *config.Config) *redis.Client { + return redis.NewClient(&redis.Options{ + Addr: cfg.Redis.Address(), + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + }) +} diff --git a/backend/internal/infrastructure/wire.go b/backend/internal/infrastructure/wire.go new file mode 100644 index 00000000..3b6e4bbd --- /dev/null +++ b/backend/internal/infrastructure/wire.go @@ -0,0 +1,25 @@ +package infrastructure + +import ( + "sub2api/internal/config" + + "github.com/google/wire" + "github.com/redis/go-redis/v9" + "gorm.io/gorm" +) + +// ProviderSet 提供基础设施层的依赖 +var ProviderSet = wire.NewSet( + ProvideDB, + ProvideRedis, +) + +// ProvideDB 提供数据库连接 +func ProvideDB(cfg *config.Config) (*gorm.DB, error) { + return InitDB(cfg) +} + +// ProvideRedis 提供 Redis 客户端 +func ProvideRedis(cfg *config.Config) *redis.Client { + return InitRedis(cfg) +} diff --git a/backend/internal/repository/repository.go b/backend/internal/repository/repository.go index 0c91910e..0e880064 100644 --- a/backend/internal/repository/repository.go +++ b/backend/internal/repository/repository.go @@ -1,9 +1,5 @@ package repository -import ( - "gorm.io/gorm" -) - // Repositories 所有仓库的集合 type Repositories struct { User *UserRepository @@ -17,21 +13,6 @@ type Repositories struct { UserSubscription *UserSubscriptionRepository } -// NewRepositories 创建所有仓库 -func NewRepositories(db *gorm.DB) *Repositories { - return &Repositories{ - User: NewUserRepository(db), - ApiKey: NewApiKeyRepository(db), - Group: NewGroupRepository(db), - Account: NewAccountRepository(db), - Proxy: NewProxyRepository(db), - RedeemCode: NewRedeemCodeRepository(db), - UsageLog: NewUsageLogRepository(db), - Setting: NewSettingRepository(db), - UserSubscription: NewUserSubscriptionRepository(db), - } -} - // PaginationParams 分页参数 type PaginationParams struct { Page int diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go new file mode 100644 index 00000000..d2b56a1c --- /dev/null +++ b/backend/internal/repository/wire.go @@ -0,0 +1,19 @@ +package repository + +import ( + "github.com/google/wire" +) + +// ProviderSet is the Wire provider set for all repositories +var ProviderSet = wire.NewSet( + NewUserRepository, + NewApiKeyRepository, + NewGroupRepository, + NewAccountRepository, + NewProxyRepository, + NewRedeemCodeRepository, + NewUsageLogRepository, + NewSettingRepository, + NewUserSubscriptionRepository, + wire.Struct(new(Repositories), "*"), +) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go new file mode 100644 index 00000000..a02e5c23 --- /dev/null +++ b/backend/internal/server/http.go @@ -0,0 +1,45 @@ +package server + +import ( + "net/http" + "sub2api/internal/config" + "sub2api/internal/handler" + "sub2api/internal/repository" + "sub2api/internal/service" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/wire" +) + +// ProviderSet 提供服务器层的依赖 +var ProviderSet = wire.NewSet( + ProvideRouter, + ProvideHTTPServer, +) + +// ProvideRouter 提供路由器 +func ProvideRouter(cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { + if cfg.Server.Mode == "release" { + gin.SetMode(gin.ReleaseMode) + } + + r := gin.New() + r.Use(gin.Recovery()) + + return SetupRouter(r, cfg, handlers, services, repos) +} + +// ProvideHTTPServer 提供 HTTP 服务器 +func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { + return &http.Server{ + Addr: cfg.Server.Address(), + Handler: router, + // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 + ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, + // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 + IdleTimeout: time.Duration(cfg.Server.IdleTimeout) * time.Second, + // 注意:不设置 WriteTimeout,因为流式响应可能持续十几分钟 + // 不设置 ReadTimeout,因为大请求体可能需要较长时间读取 + } +} diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go new file mode 100644 index 00000000..b28d3c79 --- /dev/null +++ b/backend/internal/server/router.go @@ -0,0 +1,282 @@ +package server + +import ( + "net/http" + "sub2api/internal/config" + "sub2api/internal/handler" + "sub2api/internal/middleware" + "sub2api/internal/repository" + "sub2api/internal/service" + "sub2api/internal/web" + + "github.com/gin-gonic/gin" +) + +// SetupRouter 配置路由器中间件和路由 +func SetupRouter(r *gin.Engine, cfg *config.Config, handlers *handler.Handlers, services *service.Services, repos *repository.Repositories) *gin.Engine { + // 应用中间件 + r.Use(middleware.Logger()) + r.Use(middleware.CORS()) + + // 注册路由 + registerRoutes(r, handlers, services, repos) + + // Serve embedded frontend if available + if web.HasEmbeddedFrontend() { + r.Use(web.ServeEmbeddedFrontend()) + } + + return r +} + +// registerRoutes 注册所有 HTTP 路由 +func registerRoutes(r *gin.Engine, h *handler.Handlers, s *service.Services, repos *repository.Repositories) { + // 健康检查 + r.GET("/health", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + }) + + // Setup status endpoint (always returns needs_setup: false in normal mode) + // This is used by the frontend to detect when the service has restarted after setup + r.GET("/setup/status", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": gin.H{ + "needs_setup": false, + "step": "completed", + }, + }) + }) + + // API v1 + v1 := r.Group("/api/v1") + { + // 公开接口 + auth := v1.Group("/auth") + { + auth.POST("/register", h.Auth.Register) + auth.POST("/login", h.Auth.Login) + auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + } + + // 公开设置(无需认证) + settings := v1.Group("/settings") + { + settings.GET("/public", h.Setting.GetPublicSettings) + } + + // 需要认证的接口 + authenticated := v1.Group("") + authenticated.Use(middleware.JWTAuth(s.Auth, repos.User)) + { + // 当前用户信息 + authenticated.GET("/auth/me", h.Auth.GetCurrentUser) + + // 用户接口 + user := authenticated.Group("/user") + { + user.GET("/profile", h.User.GetProfile) + user.PUT("/password", h.User.ChangePassword) + } + + // API Key管理 + keys := authenticated.Group("/keys") + { + keys.GET("", h.APIKey.List) + keys.GET("/:id", h.APIKey.GetByID) + keys.POST("", h.APIKey.Create) + keys.PUT("/:id", h.APIKey.Update) + keys.DELETE("/:id", h.APIKey.Delete) + } + + // 用户可用分组(非管理员接口) + groups := authenticated.Group("/groups") + { + groups.GET("/available", h.APIKey.GetAvailableGroups) + } + + // 使用记录 + usage := authenticated.Group("/usage") + { + usage.GET("", h.Usage.List) + usage.GET("/:id", h.Usage.GetByID) + usage.GET("/stats", h.Usage.Stats) + // User dashboard endpoints + usage.GET("/dashboard/stats", h.Usage.DashboardStats) + usage.GET("/dashboard/trend", h.Usage.DashboardTrend) + usage.GET("/dashboard/models", h.Usage.DashboardModels) + usage.POST("/dashboard/api-keys-usage", h.Usage.DashboardApiKeysUsage) + } + + // 卡密兑换 + redeem := authenticated.Group("/redeem") + { + redeem.POST("", h.Redeem.Redeem) + redeem.GET("/history", h.Redeem.GetHistory) + } + + // 用户订阅 + subscriptions := authenticated.Group("/subscriptions") + { + subscriptions.GET("", h.Subscription.List) + subscriptions.GET("/active", h.Subscription.GetActive) + subscriptions.GET("/progress", h.Subscription.GetProgress) + subscriptions.GET("/summary", h.Subscription.GetSummary) + } + } + + // 管理员接口 + admin := v1.Group("/admin") + admin.Use(middleware.JWTAuth(s.Auth, repos.User), middleware.AdminOnly()) + { + // 仪表盘 + dashboard := admin.Group("/dashboard") + { + dashboard.GET("/stats", h.Admin.Dashboard.GetStats) + dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics) + dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend) + dashboard.GET("/models", h.Admin.Dashboard.GetModelStats) + dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetApiKeyUsageTrend) + dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) + dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) + dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchApiKeysUsage) + } + + // 用户管理 + users := admin.Group("/users") + { + users.GET("", h.Admin.User.List) + users.GET("/:id", h.Admin.User.GetByID) + users.POST("", h.Admin.User.Create) + users.PUT("/:id", h.Admin.User.Update) + users.DELETE("/:id", h.Admin.User.Delete) + users.POST("/:id/balance", h.Admin.User.UpdateBalance) + users.GET("/:id/api-keys", h.Admin.User.GetUserAPIKeys) + users.GET("/:id/usage", h.Admin.User.GetUserUsage) + } + + // 分组管理 + groups := admin.Group("/groups") + { + groups.GET("", h.Admin.Group.List) + groups.GET("/all", h.Admin.Group.GetAll) + groups.GET("/:id", h.Admin.Group.GetByID) + groups.POST("", h.Admin.Group.Create) + groups.PUT("/:id", h.Admin.Group.Update) + groups.DELETE("/:id", h.Admin.Group.Delete) + groups.GET("/:id/stats", h.Admin.Group.GetStats) + groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) + } + + // 账号管理 + accounts := admin.Group("/accounts") + { + accounts.GET("", h.Admin.Account.List) + accounts.GET("/:id", h.Admin.Account.GetByID) + accounts.POST("", h.Admin.Account.Create) + accounts.PUT("/:id", h.Admin.Account.Update) + accounts.DELETE("/:id", h.Admin.Account.Delete) + accounts.POST("/:id/test", h.Admin.Account.Test) + accounts.POST("/:id/refresh", h.Admin.Account.Refresh) + accounts.GET("/:id/stats", h.Admin.Account.GetStats) + accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) + accounts.GET("/:id/usage", h.Admin.Account.GetUsage) + accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) + accounts.POST("/:id/schedulable", h.Admin.Account.SetSchedulable) + accounts.POST("/batch", h.Admin.Account.BatchCreate) + + // OAuth routes + accounts.POST("/generate-auth-url", h.Admin.OAuth.GenerateAuthURL) + accounts.POST("/generate-setup-token-url", h.Admin.OAuth.GenerateSetupTokenURL) + accounts.POST("/exchange-code", h.Admin.OAuth.ExchangeCode) + accounts.POST("/exchange-setup-token-code", h.Admin.OAuth.ExchangeSetupTokenCode) + accounts.POST("/cookie-auth", h.Admin.OAuth.CookieAuth) + accounts.POST("/setup-token-cookie-auth", h.Admin.OAuth.SetupTokenCookieAuth) + } + + // 代理管理 + proxies := admin.Group("/proxies") + { + proxies.GET("", h.Admin.Proxy.List) + proxies.GET("/all", h.Admin.Proxy.GetAll) + proxies.GET("/:id", h.Admin.Proxy.GetByID) + proxies.POST("", h.Admin.Proxy.Create) + proxies.PUT("/:id", h.Admin.Proxy.Update) + proxies.DELETE("/:id", h.Admin.Proxy.Delete) + proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) + proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts) + proxies.POST("/batch", h.Admin.Proxy.BatchCreate) + } + + // 卡密管理 + codes := admin.Group("/redeem-codes") + { + codes.GET("", h.Admin.Redeem.List) + codes.GET("/stats", h.Admin.Redeem.GetStats) + codes.GET("/export", h.Admin.Redeem.Export) + codes.GET("/:id", h.Admin.Redeem.GetByID) + codes.POST("/generate", h.Admin.Redeem.Generate) + codes.DELETE("/:id", h.Admin.Redeem.Delete) + codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete) + codes.POST("/:id/expire", h.Admin.Redeem.Expire) + } + + // 系统设置 + adminSettings := admin.Group("/settings") + { + adminSettings.GET("", h.Admin.Setting.GetSettings) + adminSettings.PUT("", h.Admin.Setting.UpdateSettings) + adminSettings.POST("/test-smtp", h.Admin.Setting.TestSmtpConnection) + adminSettings.POST("/send-test-email", h.Admin.Setting.SendTestEmail) + } + + // 系统管理 + system := admin.Group("/system") + { + system.GET("/version", h.Admin.System.GetVersion) + system.GET("/check-updates", h.Admin.System.CheckUpdates) + system.POST("/update", h.Admin.System.PerformUpdate) + system.POST("/rollback", h.Admin.System.Rollback) + system.POST("/restart", h.Admin.System.RestartService) + } + + // 订阅管理 + subscriptions := admin.Group("/subscriptions") + { + subscriptions.GET("", h.Admin.Subscription.List) + subscriptions.GET("/:id", h.Admin.Subscription.GetByID) + subscriptions.GET("/:id/progress", h.Admin.Subscription.GetProgress) + subscriptions.POST("/assign", h.Admin.Subscription.Assign) + subscriptions.POST("/bulk-assign", h.Admin.Subscription.BulkAssign) + subscriptions.POST("/:id/extend", h.Admin.Subscription.Extend) + subscriptions.DELETE("/:id", h.Admin.Subscription.Revoke) + } + + // 分组下的订阅列表 + admin.GET("/groups/:id/subscriptions", h.Admin.Subscription.ListByGroup) + + // 用户下的订阅列表 + admin.GET("/users/:id/subscriptions", h.Admin.Subscription.ListByUser) + + // 使用记录管理 + usage := admin.Group("/usage") + { + usage.GET("", h.Admin.Usage.List) + usage.GET("/stats", h.Admin.Usage.Stats) + usage.GET("/search-users", h.Admin.Usage.SearchUsers) + usage.GET("/search-api-keys", h.Admin.Usage.SearchApiKeys) + } + } + } + + // API网关(Claude API兼容) + gateway := r.Group("/v1") + gateway.Use(middleware.ApiKeyAuthWithSubscription(s.ApiKey, s.Subscription)) + { + gateway.POST("/messages", h.Gateway.Messages) + gateway.GET("/models", h.Gateway.Models) + gateway.GET("/usage", h.Gateway.Usage) + } +} diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index ae6a47ba..7a531b06 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -92,9 +92,9 @@ func createTestPayload() map[string]interface{} { "metadata": map[string]string{ "user_id": generateSessionString(), }, - "max_tokens": 1024, + "max_tokens": 1024, "temperature": 1, - "stream": true, + "stream": true, } } @@ -310,5 +310,5 @@ func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { func (s *AccountTestService) sendErrorAndEnd(c *gin.Context, errorMsg string) error { log.Printf("Account test error: %s", errorMsg) s.sendEvent(c, TestEvent{Type: "error", Error: errorMsg}) - return fmt.Errorf(errorMsg) + return fmt.Errorf("%s", errorMsg) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index af8b67f5..988bd7cd 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -191,24 +191,17 @@ type adminServiceImpl struct { } // NewAdminService creates a new AdminService -func NewAdminService(repos *repository.Repositories) AdminService { +func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService { return &adminServiceImpl{ - userRepo: repos.User, - groupRepo: repos.Group, - accountRepo: repos.Account, - proxyRepo: repos.Proxy, - apiKeyRepo: repos.ApiKey, - redeemCodeRepo: repos.RedeemCode, - usageLogRepo: repos.UsageLog, - userSubRepo: repos.UserSubscription, - } -} - -// SetBillingCacheService 设置计费缓存服务(用于缓存失效) -// 注意:AdminService是接口,需要类型断言 -func SetAdminServiceBillingCache(adminService AdminService, billingCacheService *BillingCacheService) { - if impl, ok := adminService.(*adminServiceImpl); ok { - impl.billingCacheService = billingCacheService + userRepo: repos.User, + groupRepo: repos.Group, + accountRepo: repos.Account, + proxyRepo: repos.Proxy, + apiKeyRepo: repos.ApiKey, + redeemCodeRepo: repos.RedeemCode, + usageLogRepo: repos.UsageLog, + userSubRepo: repos.UserSubscription, + billingCacheService: billingCacheService, } } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 2a063bd1..a93a8450 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -16,13 +16,13 @@ import ( ) var ( - ErrInvalidCredentials = errors.New("invalid email or password") - ErrUserNotActive = errors.New("user is not active") - ErrEmailExists = errors.New("email already exists") - ErrInvalidToken = errors.New("invalid token") - ErrTokenExpired = errors.New("token has expired") - ErrEmailVerifyRequired = errors.New("email verification is required") - ErrRegDisabled = errors.New("registration is currently disabled") + ErrInvalidCredentials = errors.New("invalid email or password") + ErrUserNotActive = errors.New("user is not active") + ErrEmailExists = errors.New("email already exists") + ErrInvalidToken = errors.New("invalid token") + ErrTokenExpired = errors.New("token has expired") + ErrEmailVerifyRequired = errors.New("email verification is required") + ErrRegDisabled = errors.New("registration is currently disabled") ) // JWTClaims JWT载荷数据 @@ -44,33 +44,24 @@ type AuthService struct { } // NewAuthService 创建认证服务实例 -func NewAuthService(userRepo *repository.UserRepository, cfg *config.Config) *AuthService { +func NewAuthService( + userRepo *repository.UserRepository, + cfg *config.Config, + settingService *SettingService, + emailService *EmailService, + turnstileService *TurnstileService, + emailQueueService *EmailQueueService, +) *AuthService { return &AuthService{ - userRepo: userRepo, - cfg: cfg, + userRepo: userRepo, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, } } -// SetSettingService 设置系统设置服务(用于检查注册开关和邮件验证) -func (s *AuthService) SetSettingService(settingService *SettingService) { - s.settingService = settingService -} - -// SetEmailService 设置邮件服务(用于邮件验证) -func (s *AuthService) SetEmailService(emailService *EmailService) { - s.emailService = emailService -} - -// SetTurnstileService 设置Turnstile服务(用于验证码校验) -func (s *AuthService) SetTurnstileService(turnstileService *TurnstileService) { - s.turnstileService = turnstileService -} - -// SetEmailQueueService 设置邮件队列服务(用于异步发送邮件) -func (s *AuthService) SetEmailQueueService(emailQueueService *EmailQueueService) { - s.emailQueueService = emailQueueService -} - // Register 用户注册,返回token和用户 func (s *AuthService) Register(ctx context.Context, email, password string) (string, *model.User, error) { return s.RegisterWithVerification(ctx, email, password, "") diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 5601a9c1..d9c45364 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -57,20 +57,22 @@ type RedeemService struct { } // NewRedeemService 创建兑换码服务实例 -func NewRedeemService(redeemRepo *repository.RedeemCodeRepository, userRepo *repository.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client) *RedeemService { +func NewRedeemService( + redeemRepo *repository.RedeemCodeRepository, + userRepo *repository.UserRepository, + subscriptionService *SubscriptionService, + rdb *redis.Client, + billingCacheService *BillingCacheService, +) *RedeemService { return &RedeemService{ redeemRepo: redeemRepo, userRepo: userRepo, subscriptionService: subscriptionService, rdb: rdb, + billingCacheService: billingCacheService, } } -// SetBillingCacheService 设置计费缓存服务(用于缓存失效) -func (s *RedeemService) SetBillingCacheService(billingCacheService *BillingCacheService) { - s.billingCacheService = billingCacheService -} - // GenerateRandomCode 生成随机兑换码 func (s *RedeemService) GenerateRandomCode() (string, error) { // 生成16字节随机数据 diff --git a/backend/internal/service/service.go b/backend/internal/service/service.go index dcb964b9..10b98354 100644 --- a/backend/internal/service/service.go +++ b/backend/internal/service/service.go @@ -1,12 +1,5 @@ package service -import ( - "sub2api/internal/config" - "sub2api/internal/repository" - - "github.com/redis/go-redis/v9" -) - // Services 服务集合容器 type Services struct { Auth *AuthService @@ -34,106 +27,3 @@ type Services struct { Concurrency *ConcurrencyService Identity *IdentityService } - -// NewServices 创建所有服务实例 -func NewServices(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config) *Services { - // 初始化价格服务 - pricingService := NewPricingService(cfg) - if err := pricingService.Initialize(); err != nil { - // 价格服务初始化失败不应阻止启动,使用回退价格 - println("[Service] Warning: Pricing service initialization failed:", err.Error()) - } - - // 初始化计费服务(依赖价格服务) - billingService := NewBillingService(cfg, pricingService) - - // 初始化其他服务 - authService := NewAuthService(repos.User, cfg) - userService := NewUserService(repos.User, cfg) - apiKeyService := NewApiKeyService(repos.ApiKey, repos.User, repos.Group, repos.UserSubscription, rdb, cfg) - groupService := NewGroupService(repos.Group) - accountService := NewAccountService(repos.Account, repos.Group) - proxyService := NewProxyService(repos.Proxy) - usageService := NewUsageService(repos.UsageLog, repos.User) - - // 初始化订阅服务 (RedeemService 依赖) - subscriptionService := NewSubscriptionService(repos) - - // 初始化兑换服务 (依赖订阅服务) - redeemService := NewRedeemService(repos.RedeemCode, repos.User, subscriptionService, rdb) - - // 初始化Admin服务 - adminService := NewAdminService(repos) - - // 初始化OAuth服务(GatewayService依赖) - oauthService := NewOAuthService(repos.Proxy) - - // 初始化限流服务 - rateLimitService := NewRateLimitService(repos, cfg) - - // 初始化计费缓存服务 - billingCacheService := NewBillingCacheService(rdb, repos.User, repos.UserSubscription) - - // 初始化账号使用量服务 - accountUsageService := NewAccountUsageService(repos, oauthService) - - // 初始化账号测试服务 - accountTestService := NewAccountTestService(repos, oauthService) - - // 初始化身份指纹服务 - identityService := NewIdentityService(rdb) - - // 初始化Gateway服务 - gatewayService := NewGatewayService(repos, rdb, cfg, oauthService, billingService, rateLimitService, billingCacheService, identityService) - - // 初始化设置服务 - settingService := NewSettingService(repos.Setting, cfg) - emailService := NewEmailService(repos.Setting, rdb) - - // 初始化邮件队列服务 - emailQueueService := NewEmailQueueService(emailService, 3) - - // 初始化Turnstile服务 - turnstileService := NewTurnstileService(settingService) - - // 设置Auth服务的依赖(用于注册开关和邮件验证) - authService.SetSettingService(settingService) - authService.SetEmailService(emailService) - authService.SetTurnstileService(turnstileService) - authService.SetEmailQueueService(emailQueueService) - - // 初始化并发控制服务 - concurrencyService := NewConcurrencyService(rdb) - - // 注入计费缓存服务到需要失效缓存的服务 - redeemService.SetBillingCacheService(billingCacheService) - subscriptionService.SetBillingCacheService(billingCacheService) - SetAdminServiceBillingCache(adminService, billingCacheService) - - return &Services{ - Auth: authService, - User: userService, - ApiKey: apiKeyService, - Group: groupService, - Account: accountService, - Proxy: proxyService, - Redeem: redeemService, - Usage: usageService, - Pricing: pricingService, - Billing: billingService, - BillingCache: billingCacheService, - Admin: adminService, - Gateway: gatewayService, - OAuth: oauthService, - RateLimit: rateLimitService, - AccountUsage: accountUsageService, - AccountTest: accountTestService, - Setting: settingService, - Email: emailService, - EmailQueue: emailQueueService, - Turnstile: turnstileService, - Subscription: subscriptionService, - Concurrency: concurrencyService, - Identity: identityService, - } -} diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 16b17410..1e1bd3c6 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -28,13 +28,11 @@ type SubscriptionService struct { } // NewSubscriptionService 创建订阅服务 -func NewSubscriptionService(repos *repository.Repositories) *SubscriptionService { - return &SubscriptionService{repos: repos} -} - -// SetBillingCacheService 设置计费缓存服务(用于缓存失效) -func (s *SubscriptionService) SetBillingCacheService(billingCacheService *BillingCacheService) { - s.billingCacheService = billingCacheService +func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService { + return &SubscriptionService{ + repos: repos, + billingCacheService: billingCacheService, + } } // AssignSubscriptionInput 分配订阅输入 @@ -88,6 +86,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass // 如果用户已有同分组的订阅: // - 未过期:从当前过期时间累加天数 // - 已过期:从当前时间开始计算新的过期时间,并激活订阅 +// // 如果没有订阅:创建新订阅 func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) { // 检查分组是否存在且为订阅类型 @@ -191,15 +190,15 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass now := time.Now() sub := &model.UserSubscription{ - UserID: input.UserID, - GroupID: input.GroupID, - StartsAt: now, - ExpiresAt: now.AddDate(0, 0, validityDays), - Status: model.SubscriptionStatusActive, + UserID: input.UserID, + GroupID: input.GroupID, + StartsAt: now, + ExpiresAt: now.AddDate(0, 0, validityDays), + Status: model.SubscriptionStatusActive, AssignedAt: now, - Notes: input.Notes, - CreatedAt: now, - UpdatedAt: now, + Notes: input.Notes, + CreatedAt: now, + UpdatedAt: now, } // 只有当 AssignedBy > 0 时才设置(0 表示系统分配,如兑换码) if input.AssignedBy > 0 { @@ -225,17 +224,17 @@ type BulkAssignSubscriptionInput struct { // BulkAssignResult 批量分配结果 type BulkAssignResult struct { - SuccessCount int - FailedCount int + SuccessCount int + FailedCount int Subscriptions []model.UserSubscription - Errors []string + Errors []string } // BulkAssignSubscription 批量分配订阅 func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input *BulkAssignSubscriptionInput) (*BulkAssignResult, error) { result := &BulkAssignResult{ Subscriptions: make([]model.UserSubscription, 0), - Errors: make([]string, 0), + Errors: make([]string, 0), } for _, userID := range input.UserIDs { @@ -417,10 +416,10 @@ func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID in // SubscriptionProgress 订阅进度 type SubscriptionProgress struct { - ID int64 `json:"id"` - GroupName string `json:"group_name"` - ExpiresAt time.Time `json:"expires_at"` - ExpiresInDays int `json:"expires_in_days"` + ID int64 `json:"id"` + GroupName string `json:"group_name"` + ExpiresAt time.Time `json:"expires_at"` + ExpiresInDays int `json:"expires_in_days"` Daily *UsageWindowProgress `json:"daily,omitempty"` Weekly *UsageWindowProgress `json:"weekly,omitempty"` Monthly *UsageWindowProgress `json:"monthly,omitempty"` @@ -428,13 +427,13 @@ type SubscriptionProgress struct { // UsageWindowProgress 使用窗口进度 type UsageWindowProgress struct { - LimitUSD float64 `json:"limit_usd"` - UsedUSD float64 `json:"used_usd"` - RemainingUSD float64 `json:"remaining_usd"` - Percentage float64 `json:"percentage"` - WindowStart time.Time `json:"window_start"` - ResetsAt time.Time `json:"resets_at"` - ResetsInSeconds int64 `json:"resets_in_seconds"` + LimitUSD float64 `json:"limit_usd"` + UsedUSD float64 `json:"used_usd"` + RemainingUSD float64 `json:"remaining_usd"` + Percentage float64 `json:"percentage"` + WindowStart time.Time `json:"window_start"` + ResetsAt time.Time `json:"resets_at"` + ResetsInSeconds int64 `json:"resets_in_seconds"` } // GetSubscriptionProgress 获取订阅使用进度 @@ -464,12 +463,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc limit := *group.DailyLimitUSD resetsAt := sub.DailyWindowStart.Add(24 * time.Hour) progress.Daily = &UsageWindowProgress{ - LimitUSD: limit, - UsedUSD: sub.DailyUsageUSD, - RemainingUSD: limit - sub.DailyUsageUSD, - Percentage: (sub.DailyUsageUSD / limit) * 100, - WindowStart: *sub.DailyWindowStart, - ResetsAt: resetsAt, + LimitUSD: limit, + UsedUSD: sub.DailyUsageUSD, + RemainingUSD: limit - sub.DailyUsageUSD, + Percentage: (sub.DailyUsageUSD / limit) * 100, + WindowStart: *sub.DailyWindowStart, + ResetsAt: resetsAt, ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), } if progress.Daily.RemainingUSD < 0 { @@ -488,12 +487,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc limit := *group.WeeklyLimitUSD resetsAt := sub.WeeklyWindowStart.Add(7 * 24 * time.Hour) progress.Weekly = &UsageWindowProgress{ - LimitUSD: limit, - UsedUSD: sub.WeeklyUsageUSD, - RemainingUSD: limit - sub.WeeklyUsageUSD, - Percentage: (sub.WeeklyUsageUSD / limit) * 100, - WindowStart: *sub.WeeklyWindowStart, - ResetsAt: resetsAt, + LimitUSD: limit, + UsedUSD: sub.WeeklyUsageUSD, + RemainingUSD: limit - sub.WeeklyUsageUSD, + Percentage: (sub.WeeklyUsageUSD / limit) * 100, + WindowStart: *sub.WeeklyWindowStart, + ResetsAt: resetsAt, ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), } if progress.Weekly.RemainingUSD < 0 { @@ -512,12 +511,12 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc limit := *group.MonthlyLimitUSD resetsAt := sub.MonthlyWindowStart.Add(30 * 24 * time.Hour) progress.Monthly = &UsageWindowProgress{ - LimitUSD: limit, - UsedUSD: sub.MonthlyUsageUSD, - RemainingUSD: limit - sub.MonthlyUsageUSD, - Percentage: (sub.MonthlyUsageUSD / limit) * 100, - WindowStart: *sub.MonthlyWindowStart, - ResetsAt: resetsAt, + LimitUSD: limit, + UsedUSD: sub.MonthlyUsageUSD, + RemainingUSD: limit - sub.MonthlyUsageUSD, + Percentage: (sub.MonthlyUsageUSD / limit) * 100, + WindowStart: *sub.MonthlyWindowStart, + ResetsAt: resetsAt, ResetsInSeconds: int64(time.Until(resetsAt).Seconds()), } if progress.Monthly.RemainingUSD < 0 { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go new file mode 100644 index 00000000..9fbe670f --- /dev/null +++ b/backend/internal/service/wire.go @@ -0,0 +1,54 @@ +package service + +import ( + "sub2api/internal/config" + + "github.com/google/wire" +) + +// ProvidePricingService creates and initializes PricingService +func ProvidePricingService(cfg *config.Config) (*PricingService, error) { + svc := NewPricingService(cfg) + if err := svc.Initialize(); err != nil { + // 价格服务初始化失败不应阻止启动,使用回退价格 + println("[Service] Warning: Pricing service initialization failed:", err.Error()) + } + return svc, nil +} + +// ProvideEmailQueueService creates EmailQueueService with default worker count +func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { + return NewEmailQueueService(emailService, 3) +} + +// ProviderSet is the Wire provider set for all services +var ProviderSet = wire.NewSet( + // Core services + NewAuthService, + NewUserService, + NewApiKeyService, + NewGroupService, + NewAccountService, + NewProxyService, + NewRedeemService, + NewUsageService, + ProvidePricingService, + NewBillingService, + NewBillingCacheService, + NewAdminService, + NewGatewayService, + NewOAuthService, + NewRateLimitService, + NewAccountUsageService, + NewAccountTestService, + NewSettingService, + NewEmailService, + ProvideEmailQueueService, + NewTurnstileService, + NewSubscriptionService, + NewConcurrencyService, + NewIdentityService, + + // Provide the Services container struct + wire.Struct(new(Services), "*"), +)