diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index d7783f40..94db1128 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -57,32 +57,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { usageService := service.NewUsageService(usageLogRepository, userRepository) usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(db) - accountRepository := repository.NewAccountRepository(db) - proxyRepository := repository.NewProxyRepository(db) - repositories := &repository.Repositories{ - User: userRepository, - ApiKey: apiKeyRepository, - Group: groupRepository, - Account: accountRepository, - Proxy: proxyRepository, - RedeemCode: redeemCodeRepository, - UsageLog: usageLogRepository, - Setting: settingRepository, - UserSubscription: userSubscriptionRepository, - } billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository) - subscriptionService := service.NewSubscriptionService(repositories, billingCacheService) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - adminService := service.NewAdminService(repositories, billingCacheService) + accountRepository := repository.NewAccountRepository(db) + proxyRepository := repository.NewProxyRepository(db) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService) dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) oAuthService := service.NewOAuthService(proxyRepository) - rateLimitService := service.NewRateLimitService(repositories, configConfig) - accountUsageService := service.NewAccountUsageService(repositories, oAuthService) - accountTestService := service.NewAccountTestService(repositories, oAuthService) + rateLimitService := service.NewRateLimitService(accountRepository, configConfig) + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService) + accountTestService := service.NewAccountTestService(accountRepository, oAuthService) accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService) oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService) proxyHandler := admin.NewProxyHandler(adminService) @@ -98,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } billingService := service.NewBillingService(configConfig, pricingService) identityService := service.NewIdentityService(client) - gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) + gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService) concurrencyService := service.NewConcurrencyService(client) gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) @@ -132,6 +121,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { Concurrency: concurrencyService, Identity: identityService, } + repositories := &repository.Repositories{ + User: userRepository, + ApiKey: apiKeyRepository, + Group: groupRepository, + Account: accountRepository, + Proxy: proxyRepository, + RedeemCode: redeemCodeRepository, + UsageLog: usageLogRepository, + Setting: settingRepository, + UserSubscription: userSubscriptionRepository, + } engine := server.ProvideRouter(configConfig, handlers, services, repositories) httpServer := server.ProvideHTTPServer(configConfig, engine) v := provideCleanup(db, client, services) diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 0d83e848..f68e4b3e 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -4,15 +4,15 @@ import ( "strconv" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/response" - "sub2api/internal/repository" "sub2api/internal/service" "github.com/gin-gonic/gin" ) -// toResponsePagination converts repository.PaginationResult to response.PaginationResult -func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult { +// toResponsePagination converts pagination.PaginationResult to response.PaginationResult +func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult { if p == nil { return nil } diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index a0154c12..9241abb6 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -4,6 +4,7 @@ import ( "strconv" "time" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/response" "sub2api/internal/pkg/timezone" "sub2api/internal/repository" @@ -14,10 +15,10 @@ import ( // UsageHandler handles admin usage-related requests type UsageHandler struct { - usageRepo *repository.UsageLogRepository - apiKeyRepo *repository.ApiKeyRepository - usageService *service.UsageService - adminService service.AdminService + usageRepo *repository.UsageLogRepository + apiKeyRepo *repository.ApiKeyRepository + usageService *service.UsageService + adminService service.AdminService } // NewUsageHandler creates a new admin usage handler @@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) { endTime = &t } - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} filters := repository.UsageLogFilters{ UserID: userID, ApiKeyID: apiKeyID, diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 3e4a5733..5505fd88 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -4,8 +4,8 @@ import ( "strconv" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/response" - "sub2api/internal/repository" "sub2api/internal/service" "github.com/gin-gonic/gin" @@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) { } page, pageSize := response.ParsePagination(c) - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params) if err != nil { diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 1a6f2614..26bd8da4 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -5,6 +5,7 @@ import ( "time" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/response" "sub2api/internal/pkg/timezone" "sub2api/internal/repository" @@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) { apiKeyID = id } - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} var records []model.UsageLog - var result *repository.PaginationResult + var result *pagination.PaginationResult var err error if apiKeyID > 0 { @@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) { } // Verify ownership of all requested API keys - userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000}) + userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000}) if err != nil { response.InternalError(c, "Failed to verify API key ownership") return diff --git a/backend/internal/pkg/pagination/pagination.go b/backend/internal/pkg/pagination/pagination.go new file mode 100644 index 00000000..12ff321e --- /dev/null +++ b/backend/internal/pkg/pagination/pagination.go @@ -0,0 +1,42 @@ +package pagination + +// PaginationParams 分页参数 +type PaginationParams struct { + Page int + PageSize int +} + +// PaginationResult 分页结果 +type PaginationResult struct { + Total int64 + Page int + PageSize int + Pages int +} + +// DefaultPagination 默认分页参数 +func DefaultPagination() PaginationParams { + return PaginationParams{ + Page: 1, + PageSize: 20, + } +} + +// Offset 计算偏移量 +func (p PaginationParams) Offset() int { + if p.Page < 1 { + p.Page = 1 + } + return (p.Page - 1) * p.PageSize +} + +// Limit 获取限制数 +func (p PaginationParams) Limit() int { + if p.PageSize < 1 { + return 20 + } + if p.PageSize > 100 { + return 100 + } + return p.PageSize +} diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index 42acbd35..0739bb25 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in }) } -// PaginationResult 分页结果(与repository.PaginationResult兼容) +// PaginationResult 分页结果(与pagination.PaginationResult兼容) type PaginationResult struct { Total int64 Page int diff --git a/backend/internal/pkg/usagestats/account_stats.go b/backend/internal/pkg/usagestats/account_stats.go new file mode 100644 index 00000000..ed77dd27 --- /dev/null +++ b/backend/internal/pkg/usagestats/account_stats.go @@ -0,0 +1,8 @@ +package usagestats + +// AccountStats 账号使用统计 +type AccountStats struct { + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` + Cost float64 `json:"cost"` +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 7f0fd7f2..a746d412 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "time" "gorm.io/gorm" @@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error } -func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) { +func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "", "") } // ListWithFilters lists accounts with optional filtering by platform, type, status, and search query -func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) { +func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) { var accounts []model.Account var total int64 @@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati pages++ } - return accounts, &PaginationResult{ + return accounts, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -226,7 +227,7 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA now := time.Now() return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id). Updates(map[string]interface{}{ - "rate_limited_at": now, + "rate_limited_at": now, "rate_limit_reset_at": resetAt, }).Error } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 66ed8cdb..2c957368 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "gorm.io/gorm" ) @@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error } -func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { +func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { var keys []model.ApiKey var total int64 @@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param pages++ } - return keys, &PaginationResult{ + return keys, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e return count > 0, err } -func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) { +func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { var keys []model.ApiKey var total int64 @@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par pages++ } - return keys, &PaginationResult{ + return keys, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 19e72a15..d39ec8d3 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "gorm.io/gorm" ) @@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error } -func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) { +func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", nil) } // ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive -func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) { +func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) { var groups []model.Group var total int64 @@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination pages++ } - return groups, &PaginationResult{ + return groups, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index 1cd22cc8..20117f47 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "gorm.io/gorm" ) @@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error } -func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) { +func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists proxies with optional filtering by protocol, status, and search query -func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) { +func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) { var proxies []model.Proxy var total int64 @@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination pages++ } - return proxies, &PaginationResult{ + return proxies, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), diff --git a/backend/internal/repository/redeem_code_repo.go b/backend/internal/repository/redeem_code_repo.go index 90a44ba2..f971a9f3 100644 --- a/backend/internal/repository/redeem_code_repo.go +++ b/backend/internal/repository/redeem_code_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "time" "gorm.io/gorm" @@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error } -func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) { +func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists redeem codes with optional filtering by type, status, and search query -func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) { +func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) { var codes []model.RedeemCode var total int64 @@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin pages++ } - return codes, &PaginationResult{ + return codes, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), diff --git a/backend/internal/repository/repository.go b/backend/internal/repository/repository.go index 0e880064..82cc46f0 100644 --- a/backend/internal/repository/repository.go +++ b/backend/internal/repository/repository.go @@ -12,44 +12,3 @@ type Repositories struct { Setting *SettingRepository UserSubscription *UserSubscriptionRepository } - -// PaginationParams 分页参数 -type PaginationParams struct { - Page int - PageSize int -} - -// PaginationResult 分页结果 -type PaginationResult struct { - Total int64 - Page int - PageSize int - Pages int -} - -// DefaultPagination 默认分页参数 -func DefaultPagination() PaginationParams { - return PaginationParams{ - Page: 1, - PageSize: 20, - } -} - -// Offset 计算偏移量 -func (p PaginationParams) Offset() int { - if p.Page < 1 { - p.Page = 1 - } - return (p.Page - 1) * p.PageSize -} - -// Limit 获取限制数 -func (p PaginationParams) Limit() int { - if p.PageSize < 1 { - return 20 - } - if p.PageSize > 100 { - return 100 - } - return p.PageSize -} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index e8eab78e..080b975e 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3,7 +3,9 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/timezone" + "sub2api/internal/pkg/usagestats" "time" "gorm.io/gorm" @@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag return &log, nil } -func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param pages++ } - return logs, &PaginationResult{ + return logs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param }, nil } -func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p pages++ } - return logs, &PaginationResult{ + return logs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS return &stats, nil } -func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, pages++ } - return logs, &PaginationResult{ + return logs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, }, nil } -func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime). @@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID return logs, nil, err } -func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime). @@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe return logs, nil, err } -func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime). @@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco return logs, nil, err } -func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog err := r.db.WithContext(ctx). Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime). @@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error } -// AccountStats 账号使用统计 -type AccountStats struct { - Requests int64 `json:"requests"` - Tokens int64 `json:"tokens"` - Cost float64 `json:"cost"` -} - // GetAccountTodayStats 获取账号今日统计 -func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) { +func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) { today := timezone.Today() var stats struct { @@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID return nil, err } - return &AccountStats{ + return &usagestats.AccountStats{ Requests: stats.Requests, Tokens: stats.Tokens, Cost: stats.Cost, @@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID } // GetAccountWindowStats 获取账号时间窗口内的统计 -func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) { +func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { var stats struct { Requests int64 `gorm:"column:requests"` Tokens int64 `gorm:"column:tokens"` @@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI return nil, err } - return &AccountStats{ + return &usagestats.AccountStats{ Requests: stats.Requests, Tokens: stats.Tokens, Cost: stats.Cost, @@ -500,11 +495,11 @@ func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTi // ApiKeyUsageTrendPoint represents API key usage trend data point type ApiKeyUsageTrendPoint struct { - Date string `json:"date"` - ApiKeyID int64 `json:"api_key_id"` - KeyName string `json:"key_name"` - Requests int64 `json:"requests"` - Tokens int64 `json:"tokens"` + Date string `json:"date"` + ApiKeyID int64 `json:"api_key_id"` + KeyName string `json:"key_name"` + Requests int64 `json:"requests"` + Tokens int64 `json:"tokens"` } // GetApiKeyUsageTrend returns usage trend data grouped by API key and date @@ -780,7 +775,7 @@ type UsageLogFilters struct { } // ListWithFilters lists usage logs with optional filters (for admin) -func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) { +func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) { var logs []model.UsageLog var total int64 @@ -816,7 +811,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat pages++ } - return logs, &PaginationResult{ + return logs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -838,7 +833,7 @@ type UsageStats struct { // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats struct { - UserID int64 `json:"user_id"` + UserID int64 `json:"user_id"` TodayActualCost float64 `json:"today_actual_cost"` TotalActualCost float64 `json:"total_actual_cost"` } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 46de5390..6867f796 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "gorm.io/gorm" ) @@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&model.User{}, id).Error } -func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) { +func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { return r.ListWithFilters(ctx, params, "", "", "") } // ListWithFilters lists users with optional filtering by status, role, and search query -func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) { +func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) { var users []model.User var total int64 @@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP pages++ } - return users, &PaginationResult{ + return users, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -127,4 +128,3 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID)) return result.RowsAffected, result.Error } - diff --git a/backend/internal/repository/user_subscription_repo.go b/backend/internal/repository/user_subscription_repo.go index 49840a86..7f2b74e4 100644 --- a/backend/internal/repository/user_subscription_repo.go +++ b/backend/internal/repository/user_subscription_repo.go @@ -5,6 +5,7 @@ import ( "time" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "gorm.io/gorm" ) @@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use } // ListByGroupID 获取分组的所有订阅(分页) -func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) { +func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []model.UserSubscription var total int64 @@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID pages++ } - return subs, &PaginationResult{ + return subs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), @@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID } // List 获取所有订阅(分页,支持筛选) -func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) { +func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { var subs []model.UserSubscription var total int64 @@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination pages++ } - return subs, &PaginationResult{ + return subs, &pagination.PaginationResult{ Total: total, Page: params.Page, PageSize: params.Limit(), diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index d2b56a1c..a5909b9b 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -1,6 +1,8 @@ package repository import ( + "sub2api/internal/service/ports" + "github.com/google/wire" ) @@ -16,4 +18,15 @@ var ProviderSet = wire.NewSet( NewSettingRepository, NewUserSubscriptionRepository, wire.Struct(new(Repositories), "*"), + + // Bind concrete repositories to service port interfaces + wire.Bind(new(ports.UserRepository), new(*UserRepository)), + wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)), + wire.Bind(new(ports.GroupRepository), new(*GroupRepository)), + wire.Bind(new(ports.AccountRepository), new(*AccountRepository)), + wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)), + wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)), + wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)), + wire.Bind(new(ports.SettingRepository), new(*SettingRepository)), + wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)), ) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index b9bf7329..6e6b31fb 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -5,7 +5,8 @@ import ( "errors" "fmt" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "gorm.io/gorm" ) @@ -41,12 +42,12 @@ type UpdateAccountRequest struct { // AccountService 账号管理服务 type AccountService struct { - accountRepo *repository.AccountRepository - groupRepo *repository.GroupRepository + accountRepo ports.AccountRepository + groupRepo ports.GroupRepository } // NewAccountService 创建账号服务实例 -func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService { +func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService { return &AccountService{ accountRepo: accountRepo, groupRepo: groupRepo, @@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account, } // List 获取账号列表 -func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) { +func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) { accounts, pagination, err := s.accountRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list accounts: %w", err) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 3ab659f6..8cfc00f8 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -16,7 +16,7 @@ import ( "time" "sub2api/internal/pkg/claude" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -37,15 +37,15 @@ type TestEvent struct { // AccountTestService handles account testing operations type AccountTestService struct { - repos *repository.Repositories + accountRepo ports.AccountRepository oauthService *OAuthService httpClient *http.Client } // NewAccountTestService creates a new AccountTestService -func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService { +func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService { return &AccountTestService{ - repos: repos, + accountRepo: accountRepo, oauthService: oauthService, httpClient: &http.Client{ Timeout: 60 * time.Second, @@ -105,7 +105,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int ctx := c.Request.Context() // Get account - account, err := s.repos.Account.GetByID(ctx, accountID) + account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { return s.sendErrorAndEnd(c, "Account not found") } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index fe227fe6..b27830c4 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -12,7 +12,7 @@ import ( "time" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" ) // usageCache 用于缓存usage数据 @@ -35,10 +35,10 @@ type WindowStats struct { // UsageProgress 使用量进度 type UsageProgress struct { - Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%) - ResetsAt *time.Time `json:"resets_at"` // 重置时间 - RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 - WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) + Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%) + ResetsAt *time.Time `json:"resets_at"` // 重置时间 + RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数 + WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量) } // UsageInfo 账号使用量信息 @@ -67,15 +67,17 @@ type ClaudeUsageResponse struct { // AccountUsageService 账号使用量查询服务 type AccountUsageService struct { - repos *repository.Repositories + accountRepo ports.AccountRepository + usageLogRepo ports.UsageLogRepository oauthService *OAuthService httpClient *http.Client } // NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService { +func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService { return &AccountUsageService{ - repos: repos, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, oauthService: oauthService, httpClient: &http.Client{ Timeout: 30 * time.Second, @@ -88,7 +90,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS // Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope) // API Key账号: 不支持usage查询 func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) { - account, err := s.repos.Account.GetByID(ctx, accountID) + account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { return nil, fmt.Errorf("get account failed: %w", err) } @@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model startTime = time.Now().Add(-5 * time.Hour) } - stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime) + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) if err != nil { log.Printf("Failed to get window stats for account %d: %v", account.ID, err) return @@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model // GetTodayStats 获取账号今日统计 func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) { - stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID) + stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID) if err != nil { return nil, fmt.Errorf("get today stats failed: %w", err) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 988bd7cd..4962dafd 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -13,7 +13,8 @@ import ( "time" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "golang.org/x/net/proxy" "gorm.io/gorm" @@ -179,35 +180,45 @@ type ProxyTestResult struct { // adminServiceImpl implements AdminService type adminServiceImpl struct { - userRepo *repository.UserRepository - groupRepo *repository.GroupRepository - accountRepo *repository.AccountRepository - proxyRepo *repository.ProxyRepository - apiKeyRepo *repository.ApiKeyRepository - redeemCodeRepo *repository.RedeemCodeRepository - usageLogRepo *repository.UsageLogRepository - userSubRepo *repository.UserSubscriptionRepository + userRepo ports.UserRepository + groupRepo ports.GroupRepository + accountRepo ports.AccountRepository + proxyRepo ports.ProxyRepository + apiKeyRepo ports.ApiKeyRepository + redeemCodeRepo ports.RedeemCodeRepository + usageLogRepo ports.UsageLogRepository + userSubRepo ports.UserSubscriptionRepository billingCacheService *BillingCacheService } // NewAdminService creates a new AdminService -func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService { +func NewAdminService( + userRepo ports.UserRepository, + groupRepo ports.GroupRepository, + accountRepo ports.AccountRepository, + proxyRepo ports.ProxyRepository, + apiKeyRepo ports.ApiKeyRepository, + redeemCodeRepo ports.RedeemCodeRepository, + usageLogRepo ports.UsageLogRepository, + userSubRepo ports.UserSubscriptionRepository, + billingCacheService *BillingCacheService, +) AdminService { return &adminServiceImpl{ - userRepo: repos.User, - groupRepo: repos.Group, - accountRepo: repos.Account, - proxyRepo: repos.Proxy, - apiKeyRepo: repos.ApiKey, - redeemCodeRepo: repos.RedeemCode, - usageLogRepo: repos.UsageLog, - userSubRepo: repos.UserSubscription, + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + usageLogRepo: usageLogRepo, + userSubRepo: userSubRepo, billingCacheService: billingCacheService, } } // User management implementations func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search) if err != nil { return nil, 0, err @@ -376,7 +387,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, } func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) if err != nil { return nil, 0, err @@ -397,7 +408,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, // Group management implementations func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) if err != nil { return nil, 0, err @@ -568,7 +579,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { } func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params) if err != nil { return nil, 0, err @@ -578,7 +589,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p // Account management implementations func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search) if err != nil { return nil, 0, err @@ -696,7 +707,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, // Proxy management implementations func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search) if err != nil { return nil, 0, err @@ -781,7 +792,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po // Redeem code management implementations func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} + params := pagination.PaginationParams{Page: page, PageSize: pageSize} codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search) if err != nil { return nil, 0, err diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index b0830d3e..1e98888d 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -8,8 +8,9 @@ import ( "fmt" "sub2api/internal/config" "sub2api/internal/model" + "sub2api/internal/pkg/pagination" "sub2api/internal/pkg/timezone" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "time" "github.com/redis/go-redis/v9" @@ -17,12 +18,12 @@ import ( ) var ( - ErrApiKeyNotFound = errors.New("api key not found") - ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") - ErrApiKeyExists = errors.New("api key already exists") - ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") - ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") - ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") + ErrApiKeyNotFound = errors.New("api key not found") + ErrGroupNotAllowed = errors.New("user is not allowed to bind this group") + ErrApiKeyExists = errors.New("api key already exists") + ErrApiKeyTooShort = errors.New("api key must be at least 16 characters") + ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens") + ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later") ) const ( @@ -47,20 +48,20 @@ type UpdateApiKeyRequest struct { // ApiKeyService API Key服务 type ApiKeyService struct { - apiKeyRepo *repository.ApiKeyRepository - userRepo *repository.UserRepository - groupRepo *repository.GroupRepository - userSubRepo *repository.UserSubscriptionRepository + apiKeyRepo ports.ApiKeyRepository + userRepo ports.UserRepository + groupRepo ports.GroupRepository + userSubRepo ports.UserSubscriptionRepository rdb *redis.Client cfg *config.Config } // NewApiKeyService 创建API Key服务实例 func NewApiKeyService( - apiKeyRepo *repository.ApiKeyRepository, - userRepo *repository.UserRepository, - groupRepo *repository.GroupRepository, - userSubRepo *repository.UserSubscriptionRepository, + apiKeyRepo ports.ApiKeyRepository, + userRepo ports.UserRepository, + groupRepo ports.GroupRepository, + userSubRepo ports.UserSubscriptionRepository, rdb *redis.Client, cfg *config.Config, ) *ApiKeyService { @@ -237,7 +238,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK } // List 获取用户的API Key列表 -func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) { +func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) { keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params) if err != nil { return nil, nil, fmt.Errorf("list api keys: %w", err) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index a93a8450..def9f417 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -7,7 +7,7 @@ import ( "log" "sub2api/internal/config" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "time" "github.com/golang-jwt/jwt/v5" @@ -35,7 +35,7 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - userRepo *repository.UserRepository + userRepo ports.UserRepository cfg *config.Config settingService *SettingService emailService *EmailService @@ -45,7 +45,7 @@ type AuthService struct { // NewAuthService 创建认证服务实例 func NewAuthService( - userRepo *repository.UserRepository, + userRepo ports.UserRepository, cfg *config.Config, settingService *SettingService, emailService *EmailService, diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index 0921a153..39384da6 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -9,7 +9,7 @@ import ( "time" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "github.com/redis/go-redis/v9" ) @@ -81,12 +81,12 @@ type subscriptionCacheData struct { // 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查 type BillingCacheService struct { rdb *redis.Client - userRepo *repository.UserRepository - subRepo *repository.UserSubscriptionRepository + userRepo ports.UserRepository + subRepo ports.UserSubscriptionRepository } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService { +func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService { return &BillingCacheService{ rdb: rdb, userRepo: userRepo, diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index c1fa69f0..01cd98b4 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -11,7 +11,7 @@ import ( "net/smtp" "strconv" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "time" "github.com/redis/go-redis/v9" @@ -25,9 +25,9 @@ var ( ) const ( - verifyCodeKeyPrefix = "email_verify:" - verifyCodeTTL = 15 * time.Minute - verifyCodeCooldown = 1 * time.Minute + verifyCodeKeyPrefix = "email_verify:" + verifyCodeTTL = 15 * time.Minute + verifyCodeCooldown = 1 * time.Minute maxVerifyCodeAttempts = 5 ) @@ -51,12 +51,12 @@ type SmtpConfig struct { // EmailService 邮件服务 type EmailService struct { - settingRepo *repository.SettingRepository + settingRepo ports.SettingRepository rdb *redis.Client } // NewEmailService 创建邮件服务实例 -func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService { +func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService { return &EmailService{ settingRepo: settingRepo, rdb: rdb, diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5991a2db..aa490ea1 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -21,7 +21,7 @@ import ( "sub2api/internal/config" "sub2api/internal/model" "sub2api/internal/pkg/claude" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "github.com/gin-gonic/gin" "github.com/redis/go-redis/v9" @@ -78,7 +78,10 @@ type ForwardResult struct { // GatewayService handles API gateway operations type GatewayService struct { - repos *repository.Repositories + accountRepo ports.AccountRepository + usageLogRepo ports.UsageLogRepository + userRepo ports.UserRepository + userSubRepo ports.UserSubscriptionRepository rdb *redis.Client cfg *config.Config oauthService *OAuthService @@ -90,7 +93,19 @@ type GatewayService struct { } // NewGatewayService creates a new GatewayService -func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService { +func NewGatewayService( + accountRepo ports.AccountRepository, + usageLogRepo ports.UsageLogRepository, + userRepo ports.UserRepository, + userSubRepo ports.UserSubscriptionRepository, + rdb *redis.Client, + cfg *config.Config, + oauthService *OAuthService, + billingService *BillingService, + rateLimitService *RateLimitService, + billingCacheService *BillingCacheService, + identityService *IdentityService, +) *GatewayService { // 计算响应头超时时间 responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second if responseHeaderTimeout == 0 { @@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c // 注意:不设置整体 Timeout,让流式响应可以无限时间传输 } return &GatewayService{ - repos: repos, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, rdb: rdb, cfg: cfg, oauthService: oauthService, @@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int if sessionHash != "" { accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64() if err == nil && accountID > 0 { - account, err := s.repos.Account.GetByID(ctx, accountID) + account, err := s.accountRepo.GetByID(ctx, accountID) // 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中 // 同时检查模型支持 if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { @@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int var accounts []model.Account var err error if groupID != nil { - accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID) + accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID) } else { - accounts, err = s.repos.Account.ListSchedulable(ctx) + accounts, err = s.accountRepo.ListSchedulable(ctx) } if err != nil { return nil, fmt.Errorf("query accounts failed: %w", err) @@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou account.Credentials["refresh_token"] = tokenInfo.RefreshToken } - if err := s.repos.Account.Update(ctx, account); err != nil { + if err := s.accountRepo.Update(ctx, account); err != nil { log.Printf("Failed to update account credentials: %v", err) } @@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A account.Credentials["refresh_token"] = tokenInfo.RefreshToken } - if err := s.repos.Account.Update(ctx, account); err != nil { + if err := s.accountRepo.Update(ctx, account); err != nil { log.Printf("Failed to update account credentials: %v", err) } @@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu usageLog.SubscriptionID = &subscription.ID } - if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil { + if err := s.usageLogRepo.Create(ctx, usageLog); err != nil { log.Printf("Create usage log failed: %v", err) } @@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if isSubscriptionBilling { // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if cost.TotalCost > 0 { - if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { + if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { log.Printf("Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 @@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } else { // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if cost.ActualCost > 0 { - if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { + if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { log.Printf("Deduct balance failed: %v", err) } // 异步更新余额缓存 @@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } // 更新账号最后使用时间 - if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil { + if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil { log.Printf("Update last used failed: %v", err) } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 9c74c626..59222740 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -5,7 +5,8 @@ import ( "errors" "fmt" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "gorm.io/gorm" ) @@ -34,11 +35,11 @@ type UpdateGroupRequest struct { // GroupService 分组管理服务 type GroupService struct { - groupRepo *repository.GroupRepository + groupRepo ports.GroupRepository } // NewGroupService 创建分组服务实例 -func NewGroupService(groupRepo *repository.GroupRepository) *GroupService { +func NewGroupService(groupRepo ports.GroupRepository) *GroupService { return &GroupService{ groupRepo: groupRepo, } @@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err } // List 获取分组列表 -func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) { +func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) { groups, pagination, err := s.groupRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list groups: %w", err) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 7e616fc4..829e267d 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -12,7 +12,7 @@ import ( "sub2api/internal/model" "sub2api/internal/pkg/oauth" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "github.com/imroc/req/v3" ) @@ -20,11 +20,11 @@ import ( // OAuthService handles OAuth authentication flows type OAuthService struct { sessionStore *oauth.SessionStore - proxyRepo *repository.ProxyRepository + proxyRepo ports.ProxyRepository } // NewOAuthService creates a new OAuth service -func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService { +func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService { return &OAuthService{ sessionStore: oauth.NewSessionStore(), proxyRepo: proxyRepo, @@ -459,7 +459,7 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A // createReqClient creates a req client with Chrome impersonation and optional proxy func (s *OAuthService) createReqClient(proxyURL string) *req.Client { client := req.C(). - ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare + ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare SetTimeout(60 * time.Second) // Set proxy if specified diff --git a/backend/internal/service/ports/account.go b/backend/internal/service/ports/account.go new file mode 100644 index 00000000..95b597d1 --- /dev/null +++ b/backend/internal/service/ports/account.go @@ -0,0 +1,35 @@ +package ports + +import ( + "context" + "time" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type AccountRepository interface { + Create(ctx context.Context, account *model.Account) error + GetByID(ctx context.Context, id int64) (*model.Account, error) + Update(ctx context.Context, account *model.Account) error + Delete(ctx context.Context, id int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) + ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error) + ListActive(ctx context.Context) ([]model.Account, error) + ListByPlatform(ctx context.Context, platform string) ([]model.Account, error) + + UpdateLastUsed(ctx context.Context, id int64) error + SetError(ctx context.Context, id int64, errorMsg string) error + SetSchedulable(ctx context.Context, id int64, schedulable bool) error + BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error + + ListSchedulable(ctx context.Context) ([]model.Account, error) + ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error) + + SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error + SetOverloaded(ctx context.Context, id int64, until time.Time) error + ClearRateLimit(ctx context.Context, id int64) error + UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error +} diff --git a/backend/internal/service/ports/api_key.go b/backend/internal/service/ports/api_key.go new file mode 100644 index 00000000..6440971b --- /dev/null +++ b/backend/internal/service/ports/api_key.go @@ -0,0 +1,24 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type ApiKeyRepository interface { + Create(ctx context.Context, key *model.ApiKey) error + GetByID(ctx context.Context, id int64) (*model.ApiKey, error) + GetByKey(ctx context.Context, key string) (*model.ApiKey, error) + Update(ctx context.Context, key *model.ApiKey) error + Delete(ctx context.Context, id int64) error + + ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) + CountByUserID(ctx context.Context, userID int64) (int64, error) + ExistsByKey(ctx context.Context, key string) (bool, error) + ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) + SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error) + ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) + CountByGroupID(ctx context.Context, groupID int64) (int64, error) +} diff --git a/backend/internal/service/ports/group.go b/backend/internal/service/ports/group.go new file mode 100644 index 00000000..e0e102b6 --- /dev/null +++ b/backend/internal/service/ports/group.go @@ -0,0 +1,28 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" + + "gorm.io/gorm" +) + +type GroupRepository interface { + Create(ctx context.Context, group *model.Group) error + GetByID(ctx context.Context, id int64) (*model.Group, error) + Update(ctx context.Context, group *model.Group) error + Delete(ctx context.Context, id int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) + ListActive(ctx context.Context) ([]model.Group, error) + ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error) + + ExistsByName(ctx context.Context, name string) (bool, error) + GetAccountCount(ctx context.Context, groupID int64) (int64, error) + DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) + + DB() *gorm.DB +} diff --git a/backend/internal/service/ports/proxy.go b/backend/internal/service/ports/proxy.go new file mode 100644 index 00000000..b97ed561 --- /dev/null +++ b/backend/internal/service/ports/proxy.go @@ -0,0 +1,23 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type ProxyRepository interface { + Create(ctx context.Context, proxy *model.Proxy) error + GetByID(ctx context.Context, id int64) (*model.Proxy, error) + Update(ctx context.Context, proxy *model.Proxy) error + Delete(ctx context.Context, id int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) + ListActive(ctx context.Context) ([]model.Proxy, error) + ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error) + + ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) + CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) +} diff --git a/backend/internal/service/ports/redeem_code.go b/backend/internal/service/ports/redeem_code.go new file mode 100644 index 00000000..27f41b59 --- /dev/null +++ b/backend/internal/service/ports/redeem_code.go @@ -0,0 +1,22 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type RedeemCodeRepository interface { + Create(ctx context.Context, code *model.RedeemCode) error + CreateBatch(ctx context.Context, codes []model.RedeemCode) error + GetByID(ctx context.Context, id int64) (*model.RedeemCode, error) + GetByCode(ctx context.Context, code string) (*model.RedeemCode, error) + Update(ctx context.Context, code *model.RedeemCode) error + Delete(ctx context.Context, id int64) error + Use(ctx context.Context, id, userID int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) + ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error) +} diff --git a/backend/internal/service/ports/setting.go b/backend/internal/service/ports/setting.go new file mode 100644 index 00000000..4ce0fe46 --- /dev/null +++ b/backend/internal/service/ports/setting.go @@ -0,0 +1,17 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" +) + +type SettingRepository interface { + Get(ctx context.Context, key string) (*model.Setting, error) + GetValue(ctx context.Context, key string) (string, error) + Set(ctx context.Context, key, value string) error + GetMultiple(ctx context.Context, keys []string) (map[string]string, error) + SetMultiple(ctx context.Context, settings map[string]string) error + GetAll(ctx context.Context) (map[string]string, error) + Delete(ctx context.Context, key string) error +} diff --git a/backend/internal/service/ports/usage_log.go b/backend/internal/service/ports/usage_log.go new file mode 100644 index 00000000..d8ac8a37 --- /dev/null +++ b/backend/internal/service/ports/usage_log.go @@ -0,0 +1,28 @@ +package ports + +import ( + "context" + "time" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" + "sub2api/internal/pkg/usagestats" +) + +type UsageLogRepository interface { + Create(ctx context.Context, log *model.UsageLog) error + GetByID(ctx context.Context, id int64) (*model.UsageLog, error) + Delete(ctx context.Context, id int64) error + + ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) + + ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) + ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) + + GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) + GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) +} diff --git a/backend/internal/service/ports/user.go b/backend/internal/service/ports/user.go new file mode 100644 index 00000000..44dcec8c --- /dev/null +++ b/backend/internal/service/ports/user.go @@ -0,0 +1,25 @@ +package ports + +import ( + "context" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type UserRepository interface { + Create(ctx context.Context, user *model.User) error + GetByID(ctx context.Context, id int64) (*model.User, error) + GetByEmail(ctx context.Context, email string) (*model.User, error) + Update(ctx context.Context, user *model.User) error + Delete(ctx context.Context, id int64) error + + List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) + + UpdateBalance(ctx context.Context, id int64, amount float64) error + DeductBalance(ctx context.Context, id int64, amount float64) error + UpdateConcurrency(ctx context.Context, id int64, amount int) error + ExistsByEmail(ctx context.Context, email string) (bool, error) + RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) +} diff --git a/backend/internal/service/ports/user_subscription.go b/backend/internal/service/ports/user_subscription.go new file mode 100644 index 00000000..ba41dd89 --- /dev/null +++ b/backend/internal/service/ports/user_subscription.go @@ -0,0 +1,36 @@ +package ports + +import ( + "context" + "time" + + "sub2api/internal/model" + "sub2api/internal/pkg/pagination" +) + +type UserSubscriptionRepository interface { + Create(ctx context.Context, sub *model.UserSubscription) error + GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) + GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) + GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) + Update(ctx context.Context, sub *model.UserSubscription) error + Delete(ctx context.Context, id int64) error + + ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) + ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error) + ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) + List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) + + ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) + ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error + UpdateStatus(ctx context.Context, subscriptionID int64, status string) error + UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error + + ActivateWindows(ctx context.Context, id int64, start time.Time) error + ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error + ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error + ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error + IncrementUsage(ctx context.Context, id int64, costUSD float64) error + + BatchUpdateExpiredStatus(ctx context.Context) (int64, error) +} diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index 92cba53a..a6a1f924 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -5,7 +5,8 @@ import ( "errors" "fmt" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "gorm.io/gorm" ) @@ -37,11 +38,11 @@ type UpdateProxyRequest struct { // ProxyService 代理管理服务 type ProxyService struct { - proxyRepo *repository.ProxyRepository + proxyRepo ports.ProxyRepository } // NewProxyService 创建代理服务实例 -func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService { +func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService { return &ProxyService{ proxyRepo: proxyRepo, } @@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err } // List 获取代理列表 -func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) { +func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) { proxies, pagination, err := s.proxyRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list proxies: %w", err) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 43feef30..41d3736f 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -9,20 +9,20 @@ import ( "sub2api/internal/config" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" ) // RateLimitService 处理限流和过载状态管理 type RateLimitService struct { - repos *repository.Repositories - cfg *config.Config + accountRepo ports.AccountRepository + cfg *config.Config } // NewRateLimitService 创建RateLimitService实例 -func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService { +func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService { return &RateLimitService{ - repos: repos, - cfg: cfg, + accountRepo: accountRepo, + cfg: cfg, } } @@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) { - if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil { + if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { log.Printf("SetError failed for account %d: %v", account.ID, err) return } @@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account if resetTimestamp == "" { // 没有重置时间,使用默认5分钟 resetAt := time.Now().Add(5 * time.Minute) - if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) } return @@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account if err != nil { log.Printf("Parse reset timestamp failed: %v", err) resetAt := time.Now().Add(5 * time.Minute) - if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) } return @@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account resetAt := time.Unix(ts, 0) // 标记限流状态 - if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil { + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) return } @@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account // 根据重置时间反推5h窗口 windowEnd := resetAt windowStart := resetAt.Add(-5 * time.Hour) - if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil { log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) } @@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account } until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute) - if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil { + if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil { log.Printf("SetOverloaded failed for account %d: %v", account.ID, err) return } @@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status) } - if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { + if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil { log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err) } // 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态 if status == "allowed" && account.IsRateLimited() { - if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil { + if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil { log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err) } } @@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod // ClearRateLimit 清除账号的限流状态 func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error { - return s.repos.Account.ClearRateLimit(ctx, accountID) + return s.accountRepo.ClearRateLimit(ctx, accountID) } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index d9c45364..72137f1b 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -8,7 +8,8 @@ import ( "fmt" "strings" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "time" "github.com/redis/go-redis/v9" @@ -49,8 +50,8 @@ type RedeemCodeResponse struct { // RedeemService 兑换码服务 type RedeemService struct { - redeemRepo *repository.RedeemCodeRepository - userRepo *repository.UserRepository + redeemRepo ports.RedeemCodeRepository + userRepo ports.UserRepository subscriptionService *SubscriptionService rdb *redis.Client billingCacheService *BillingCacheService @@ -58,8 +59,8 @@ type RedeemService struct { // NewRedeemService 创建兑换码服务实例 func NewRedeemService( - redeemRepo *repository.RedeemCodeRepository, - userRepo *repository.UserRepository, + redeemRepo ports.RedeemCodeRepository, + userRepo ports.UserRepository, subscriptionService *SubscriptionService, rdb *redis.Client, billingCacheService *BillingCacheService, @@ -337,7 +338,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede } // List 获取兑换码列表(管理员功能) -func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) { +func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) { codes, pagination, err := s.redeemRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list redeem codes: %w", err) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 013b07f6..afd6a3a9 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,7 +7,7 @@ import ( "strconv" "sub2api/internal/config" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/service/ports" "gorm.io/gorm" ) @@ -18,12 +18,12 @@ var ( // SettingService 系统设置服务 type SettingService struct { - settingRepo *repository.SettingRepository + settingRepo ports.SettingRepository cfg *config.Config } // NewSettingService 创建系统设置服务实例 -func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService { +func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService { return &SettingService{ settingRepo: settingRepo, cfg: cfg, diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index 1e1bd3c6..590b09e8 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -7,7 +7,8 @@ import ( "time" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" ) var ( @@ -23,14 +24,16 @@ var ( // SubscriptionService 订阅服务 type SubscriptionService struct { - repos *repository.Repositories + groupRepo ports.GroupRepository + userSubRepo ports.UserSubscriptionRepository billingCacheService *BillingCacheService } // NewSubscriptionService 创建订阅服务 -func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService { +func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService { return &SubscriptionService{ - repos: repos, + groupRepo: groupRepo, + userSubRepo: userSubRepo, billingCacheService: billingCacheService, } } @@ -47,7 +50,7 @@ type AssignSubscriptionInput struct { // AssignSubscription 分配订阅给用户(不允许重复分配) func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) { // 检查分组是否存在且为订阅类型 - group, err := s.repos.Group.GetByID(ctx, input.GroupID) + group, err := s.groupRepo.GetByID(ctx, input.GroupID) if err != nil { return nil, fmt.Errorf("group not found: %w", err) } @@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass } // 检查是否已存在订阅 - exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID) + exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID) if err != nil { return nil, err } @@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass // 如果没有订阅:创建新订阅 func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) { // 检查分组是否存在且为订阅类型 - group, err := s.repos.Group.GetByID(ctx, input.GroupID) + group, err := s.groupRepo.GetByID(ctx, input.GroupID) if err != nil { return nil, false, fmt.Errorf("group not found: %w", err) } @@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 查询是否已有订阅 - existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID) + existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID) if err != nil { // 不存在记录是正常情况,其他错误需要返回 existingSub = nil @@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 更新过期时间 - if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil { + if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil { return nil, false, fmt.Errorf("extend subscription: %w", err) } // 如果订阅已过期或被暂停,恢复为active状态 if existingSub.Status != model.SubscriptionStatusActive { - if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil { + if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil { return nil, false, fmt.Errorf("update subscription status: %w", err) } } @@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in newNotes += "\n" } newNotes += input.Notes - if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { + if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil { // 备注更新失败不影响主流程 } } @@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in } // 返回更新后的订阅 - sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID) + sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID) return sub, true, err // true 表示是续期 } @@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass sub.AssignedBy = &input.AssignedBy } - if err := s.repos.UserSubscription.Create(ctx, sub); err != nil { + if err := s.userSubRepo.Create(ctx, sub); err != nil { return nil, err } // 重新获取完整订阅信息(包含关联) - return s.repos.UserSubscription.GetByID(ctx, sub.ID) + return s.userSubRepo.GetByID(ctx, sub.ID) } // BulkAssignSubscriptionInput 批量分配订阅输入 @@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input // RevokeSubscription 撤销订阅 func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error { // 先获取订阅信息用于失效缓存 - sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) if err != nil { return err } - if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil { + if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil { return err } @@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti // ExtendSubscription 延长订阅 func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) { - sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) if err != nil { return nil, ErrSubscriptionNotFound } // 计算新的过期时间 newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) - if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { + if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { return nil, err } // 如果订阅已过期,恢复为active状态 if sub.Status == model.SubscriptionStatusExpired { - if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil { + if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil { return nil, err } } @@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti }() } - return s.repos.UserSubscription.GetByID(ctx, subscriptionID) + return s.userSubRepo.GetByID(ctx, subscriptionID) } // GetByID 根据ID获取订阅 func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) { - return s.repos.UserSubscription.GetByID(ctx, id) + return s.userSubRepo.GetByID(ctx, id) } // GetActiveSubscription 获取用户对特定分组的有效订阅 func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) { - sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID) + sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID) if err != nil { return nil, ErrSubscriptionNotFound } @@ -331,24 +334,24 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, // ListUserSubscriptions 获取用户的所有订阅 func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { - return s.repos.UserSubscription.ListByUserID(ctx, userID) + return s.userSubRepo.ListByUserID(ctx, userID) } // ListActiveUserSubscriptions 获取用户的所有有效订阅 func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) { - return s.repos.UserSubscription.ListActiveByUserID(ctx, userID) + return s.userSubRepo.ListActiveByUserID(ctx, userID) } // ListGroupSubscriptions 获取分组的所有订阅 -func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} - return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params) +func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + return s.userSubRepo.ListByGroupID(ctx, groupID, params) } // List 获取所有订阅(分页,支持筛选) -func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) { - params := repository.PaginationParams{Page: page, PageSize: pageSize} - return s.repos.UserSubscription.List(ctx, params, userID, groupID, status) +func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + return s.userSubRepo.List(ctx, params, userID, groupID, status) } // CheckAndActivateWindow 检查并激活窗口(首次使用时) @@ -358,7 +361,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m } now := time.Now() - return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now) + return s.userSubRepo.ActivateWindows(ctx, sub.ID, now) } // CheckAndResetWindows 检查并重置过期的窗口 @@ -367,7 +370,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod // 日窗口重置(24小时) if sub.NeedsDailyReset() { - if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil { + if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, now); err != nil { return err } sub.DailyWindowStart = &now @@ -376,7 +379,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod // 周窗口重置(7天) if sub.NeedsWeeklyReset() { - if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil { + if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, now); err != nil { return err } sub.WeeklyWindowStart = &now @@ -385,7 +388,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod // 月窗口重置(30天) if sub.NeedsMonthlyReset() { - if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil { + if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, now); err != nil { return err } sub.MonthlyWindowStart = &now @@ -411,7 +414,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U // RecordUsage 记录使用量到订阅 func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error { - return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD) + return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD) } // SubscriptionProgress 订阅进度 @@ -438,14 +441,14 @@ type UsageWindowProgress struct { // GetSubscriptionProgress 获取订阅使用进度 func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) { - sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID) + sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) if err != nil { return nil, ErrSubscriptionNotFound } group := sub.Group if group == nil { - group, err = s.repos.Group.GetByID(ctx, sub.GroupID) + group, err = s.groupRepo.GetByID(ctx, sub.GroupID) if err != nil { return nil, err } @@ -535,7 +538,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc // GetUserSubscriptionsWithProgress 获取用户所有订阅及进度 func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) { - subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID) + subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID) if err != nil { return nil, err } @@ -554,7 +557,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte // UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用) func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) { - return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx) + return s.userSubRepo.BatchUpdateExpiredStatus(ctx) } // ValidateSubscription 验证订阅是否有效 @@ -567,7 +570,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod } if sub.IsExpired() { // 更新状态 - _ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired) + _ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired) return ErrSubscriptionExpired } return nil diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 214ba26a..97d13950 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -5,7 +5,8 @@ import ( "errors" "fmt" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "time" "gorm.io/gorm" @@ -41,24 +42,24 @@ type CreateUsageLogRequest struct { // UsageStats 使用统计 type UsageStats struct { - TotalRequests int64 `json:"total_requests"` - TotalInputTokens int64 `json:"total_input_tokens"` - TotalOutputTokens int64 `json:"total_output_tokens"` - TotalCacheTokens int64 `json:"total_cache_tokens"` - TotalTokens int64 `json:"total_tokens"` - TotalCost float64 `json:"total_cost"` - TotalActualCost float64 `json:"total_actual_cost"` - AverageDurationMs float64 `json:"average_duration_ms"` + TotalRequests int64 `json:"total_requests"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` + TotalCacheTokens int64 `json:"total_cache_tokens"` + TotalTokens int64 `json:"total_tokens"` + TotalCost float64 `json:"total_cost"` + TotalActualCost float64 `json:"total_actual_cost"` + AverageDurationMs float64 `json:"average_duration_ms"` } // UsageService 使用统计服务 type UsageService struct { - usageRepo *repository.UsageLogRepository - userRepo *repository.UserRepository + usageRepo ports.UsageLogRepository + userRepo ports.UserRepository } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService { +func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService { return &UsageService{ usageRepo: usageRepo, userRepo: userRepo, @@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog, } // ListByUser 获取用户的使用日志列表 -func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { +func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo } // ListByApiKey 获取API Key的使用日志列表 -func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { +func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params } // ListByAccount 获取账号的使用日志列表 -func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) { +func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) { logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params) if err != nil { return nil, nil, fmt.Errorf("list usage logs: %w", err) @@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int } result = append(result, map[string]interface{}{ - "date": date, - "total_requests": stats.TotalRequests, - "total_input_tokens": stats.TotalInputTokens, - "total_output_tokens": stats.TotalOutputTokens, - "total_cache_tokens": stats.TotalCacheTokens, - "total_tokens": stats.TotalTokens, - "total_cost": stats.TotalCost, - "total_actual_cost": stats.TotalActualCost, - "average_duration_ms": stats.AverageDurationMs, + "date": date, + "total_requests": stats.TotalRequests, + "total_input_tokens": stats.TotalInputTokens, + "total_output_tokens": stats.TotalOutputTokens, + "total_cache_tokens": stats.TotalCacheTokens, + "total_tokens": stats.TotalTokens, + "total_cost": stats.TotalCost, + "total_actual_cost": stats.TotalActualCost, + "average_duration_ms": stats.AverageDurationMs, }) } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index bc0af756..07bfac8e 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -6,16 +6,17 @@ import ( "fmt" "sub2api/internal/config" "sub2api/internal/model" - "sub2api/internal/repository" + "sub2api/internal/pkg/pagination" + "sub2api/internal/service/ports" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) var ( - ErrUserNotFound = errors.New("user not found") - ErrPasswordIncorrect = errors.New("current password is incorrect") - ErrInsufficientPerms = errors.New("insufficient permissions") + ErrUserNotFound = errors.New("user not found") + ErrPasswordIncorrect = errors.New("current password is incorrect") + ErrInsufficientPerms = errors.New("insufficient permissions") ) // UpdateProfileRequest 更新用户资料请求 @@ -32,12 +33,12 @@ type ChangePasswordRequest struct { // UserService 用户服务 type UserService struct { - userRepo *repository.UserRepository + userRepo ports.UserRepository cfg *config.Config } // NewUserService 创建用户服务实例 -func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService { +func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService { return &UserService{ userRepo: userRepo, cfg: cfg, @@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error } // List 获取用户列表(管理员功能) -func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) { +func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) { users, pagination, err := s.userRepo.List(ctx, params) if err != nil { return nil, nil, fmt.Errorf("list users: %w", err)