refactor(backend): 引入端口接口模式

This commit is contained in:
Forest
2025-12-19 21:26:19 +08:00
parent 7fd94ab78b
commit e99b344b2b
45 changed files with 627 additions and 323 deletions

View File

@@ -57,32 +57,21 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageHandler := handler.NewUsageHandler(usageService, usageLogRepository, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(db)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
billingCacheService := service.NewBillingCacheService(client, userRepository, userSubscriptionRepository)
subscriptionService := service.NewSubscriptionService(repositories, billingCacheService)
subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, client, billingCacheService)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
adminService := service.NewAdminService(repositories, billingCacheService)
accountRepository := repository.NewAccountRepository(db)
proxyRepository := repository.NewProxyRepository(db)
adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, usageLogRepository, userSubscriptionRepository, billingCacheService)
dashboardHandler := admin.NewDashboardHandler(adminService, usageLogRepository)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
oAuthService := service.NewOAuthService(proxyRepository)
rateLimitService := service.NewRateLimitService(repositories, configConfig)
accountUsageService := service.NewAccountUsageService(repositories, oAuthService)
accountTestService := service.NewAccountTestService(repositories, oAuthService)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, oAuthService)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService)
accountHandler := admin.NewAccountHandler(adminService, oAuthService, rateLimitService, accountUsageService, accountTestService)
oAuthHandler := admin.NewOAuthHandler(oAuthService, adminService)
proxyHandler := admin.NewProxyHandler(adminService)
@@ -98,7 +87,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
}
billingService := service.NewBillingService(configConfig, pricingService)
identityService := service.NewIdentityService(client)
gatewayService := service.NewGatewayService(repositories, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
gatewayService := service.NewGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, client, configConfig, oAuthService, billingService, rateLimitService, billingCacheService, identityService)
concurrencyService := service.NewConcurrencyService(client)
gatewayHandler := handler.NewGatewayHandler(gatewayService, userService, concurrencyService, billingCacheService)
handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
@@ -132,6 +121,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
Concurrency: concurrencyService,
Identity: identityService,
}
repositories := &repository.Repositories{
User: userRepository,
ApiKey: apiKeyRepository,
Group: groupRepository,
Account: accountRepository,
Proxy: proxyRepository,
RedeemCode: redeemCodeRepository,
UsageLog: usageLogRepository,
Setting: settingRepository,
UserSubscription: userSubscriptionRepository,
}
engine := server.ProvideRouter(configConfig, handlers, services, repositories)
httpServer := server.ProvideHTTPServer(configConfig, engine)
v := provideCleanup(db, client, services)

View File

@@ -4,15 +4,15 @@ import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
)
// toResponsePagination converts repository.PaginationResult to response.PaginationResult
func toResponsePagination(p *repository.PaginationResult) *response.PaginationResult {
// toResponsePagination converts pagination.PaginationResult to response.PaginationResult
func toResponsePagination(p *pagination.PaginationResult) *response.PaginationResult {
if p == nil {
return nil
}

View File

@@ -4,6 +4,7 @@ import (
"strconv"
"time"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
@@ -14,10 +15,10 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
usageRepo *repository.UsageLogRepository
apiKeyRepo *repository.ApiKeyRepository
usageService *service.UsageService
adminService service.AdminService
}
// NewUsageHandler creates a new admin usage handler
@@ -82,7 +83,7 @@ func (h *UsageHandler) List(c *gin.Context) {
endTime = &t
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
filters := repository.UsageLogFilters{
UserID: userID,
ApiKeyID: apiKeyID,

View File

@@ -4,8 +4,8 @@ import (
"strconv"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/repository"
"sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -53,7 +53,7 @@ func (h *APIKeyHandler) List(c *gin.Context) {
}
page, pageSize := response.ParsePagination(c)
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := h.apiKeyService.List(c.Request.Context(), user.ID, params)
if err != nil {

View File

@@ -5,6 +5,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/response"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
@@ -68,9 +69,9 @@ func (h *UsageHandler) List(c *gin.Context) {
apiKeyID = id
}
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
var records []model.UsageLog
var result *repository.PaginationResult
var result *pagination.PaginationResult
var err error
if apiKeyID > 0 {
@@ -362,7 +363,7 @@ func (h *UsageHandler) DashboardApiKeysUsage(c *gin.Context) {
}
// Verify ownership of all requested API keys
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, repository.PaginationParams{Page: 1, PageSize: 1000})
userApiKeys, _, err := h.apiKeyService.List(c.Request.Context(), user.ID, pagination.PaginationParams{Page: 1, PageSize: 1000})
if err != nil {
response.InternalError(c, "Failed to verify API key ownership")
return

View File

@@ -0,0 +1,42 @@
package pagination
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -90,7 +90,7 @@ func Paginated(c *gin.Context, items interface{}, total int64, page, pageSize in
})
}
// PaginationResult 分页结果(与repository.PaginationResult兼容
// PaginationResult 分页结果(与pagination.PaginationResult兼容
type PaginationResult struct {
Total int64
Page int

View File

@@ -0,0 +1,8 @@
package usagestats
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
@@ -47,12 +48,12 @@ func (r *AccountRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Account{}, id).Error
}
func (r *AccountRepository) List(ctx context.Context, params PaginationParams) ([]model.Account, *PaginationResult, error) {
func (r *AccountRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "", "")
}
// ListWithFilters lists accounts with optional filtering by platform, type, status, and search query
func (r *AccountRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, accountType, status, search string) ([]model.Account, *PaginationResult, error) {
func (r *AccountRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error) {
var accounts []model.Account
var total int64
@@ -94,7 +95,7 @@ func (r *AccountRepository) ListWithFilters(ctx context.Context, params Paginati
pages++
}
return accounts, &PaginationResult{
return accounts, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -226,7 +227,7 @@ func (r *AccountRepository) SetRateLimited(ctx context.Context, id int64, resetA
now := time.Now()
return r.db.WithContext(ctx).Model(&model.Account{}).Where("id = ?", id).
Updates(map[string]interface{}{
"rate_limited_at": now,
"rate_limited_at": now,
"rate_limit_reset_at": resetAt,
}).Error
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -45,7 +46,7 @@ func (r *ApiKeyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.ApiKey{}, id).Error
}
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -64,7 +65,7 @@ func (r *ApiKeyRepository) ListByUserID(ctx context.Context, userID int64, param
pages++
}
return keys, &PaginationResult{
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -84,7 +85,7 @@ func (r *ApiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, e
return count > 0, err
}
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.ApiKey, *PaginationResult, error) {
func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
var keys []model.ApiKey
var total int64
@@ -103,7 +104,7 @@ func (r *ApiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, par
pages++
}
return keys, &PaginationResult{
return keys, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -36,12 +37,12 @@ func (r *GroupRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Group{}, id).Error
}
func (r *GroupRepository) List(ctx context.Context, params PaginationParams) ([]model.Group, *PaginationResult, error) {
func (r *GroupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", nil)
}
// ListWithFilters lists groups with optional filtering by platform, status, and is_exclusive
func (r *GroupRepository) ListWithFilters(ctx context.Context, params PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *PaginationResult, error) {
func (r *GroupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error) {
var groups []model.Group
var total int64
@@ -77,7 +78,7 @@ func (r *GroupRepository) ListWithFilters(ctx context.Context, params Pagination
pages++
}
return groups, &PaginationResult{
return groups, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -36,12 +37,12 @@ func (r *ProxyRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.Proxy{}, id).Error
}
func (r *ProxyRepository) List(ctx context.Context, params PaginationParams) ([]model.Proxy, *PaginationResult, error) {
func (r *ProxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params PaginationParams, protocol, status, search string) ([]model.Proxy, *PaginationResult, error) {
func (r *ProxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error) {
var proxies []model.Proxy
var total int64
@@ -72,7 +73,7 @@ func (r *ProxyRepository) ListWithFilters(ctx context.Context, params Pagination
pages++
}
return proxies, &PaginationResult{
return proxies, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"time"
"gorm.io/gorm"
@@ -46,12 +47,12 @@ func (r *RedeemCodeRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.RedeemCode{}, id).Error
}
func (r *RedeemCodeRepository) List(ctx context.Context, params PaginationParams) ([]model.RedeemCode, *PaginationResult, error) {
func (r *RedeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists redeem codes with optional filtering by type, status, and search query
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params PaginationParams, codeType, status, search string) ([]model.RedeemCode, *PaginationResult, error) {
func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error) {
var codes []model.RedeemCode
var total int64
@@ -82,7 +83,7 @@ func (r *RedeemCodeRepository) ListWithFilters(ctx context.Context, params Pagin
pages++
}
return codes, &PaginationResult{
return codes, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -12,44 +12,3 @@ type Repositories struct {
Setting *SettingRepository
UserSubscription *UserSubscriptionRepository
}
// PaginationParams 分页参数
type PaginationParams struct {
Page int
PageSize int
}
// PaginationResult 分页结果
type PaginationResult struct {
Total int64
Page int
PageSize int
Pages int
}
// DefaultPagination 默认分页参数
func DefaultPagination() PaginationParams {
return PaginationParams{
Page: 1,
PageSize: 20,
}
}
// Offset 计算偏移量
func (p PaginationParams) Offset() int {
if p.Page < 1 {
p.Page = 1
}
return (p.Page - 1) * p.PageSize
}
// Limit 获取限制数
func (p PaginationParams) Limit() int {
if p.PageSize < 1 {
return 20
}
if p.PageSize > 100 {
return 100
}
return p.PageSize
}

View File

@@ -3,7 +3,9 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/pkg/usagestats"
"time"
"gorm.io/gorm"
@@ -30,7 +32,7 @@ func (r *UsageLogRepository) GetByID(ctx context.Context, id int64) (*model.Usag
return &log, nil
}
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -49,7 +51,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -57,7 +59,7 @@ func (r *UsageLogRepository) ListByUser(ctx context.Context, userID int64, param
}, nil
}
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -76,7 +78,7 @@ func (r *UsageLogRepository) ListByApiKey(ctx context.Context, apiKeyID int64, p
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -270,7 +272,7 @@ func (r *UsageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
return &stats, nil
}
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params PaginationParams) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -289,7 +291,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -297,7 +299,7 @@ func (r *UsageLogRepository) ListByAccount(ctx context.Context, accountID int64,
}, nil
}
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("user_id = ? AND created_at >= ? AND created_at < ?", userID, startTime, endTime).
@@ -306,7 +308,7 @@ func (r *UsageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID
return logs, nil, err
}
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("api_key_id = ? AND created_at >= ? AND created_at < ?", apiKeyID, startTime, endTime).
@@ -315,7 +317,7 @@ func (r *UsageLogRepository) ListByApiKeyAndTimeRange(ctx context.Context, apiKe
return logs, nil, err
}
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("account_id = ? AND created_at >= ? AND created_at < ?", accountID, startTime, endTime).
@@ -324,7 +326,7 @@ func (r *UsageLogRepository) ListByAccountAndTimeRange(ctx context.Context, acco
return logs, nil, err
}
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
err := r.db.WithContext(ctx).
Where("model = ? AND created_at >= ? AND created_at < ?", modelName, startTime, endTime).
@@ -337,15 +339,8 @@ func (r *UsageLogRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.UsageLog{}, id).Error
}
// AccountStats 账号使用统计
type AccountStats struct {
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Cost float64 `json:"cost"`
}
// GetAccountTodayStats 获取账号今日统计
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*AccountStats, error) {
func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error) {
today := timezone.Today()
var stats struct {
@@ -367,7 +362,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
return nil, err
}
return &AccountStats{
return &usagestats.AccountStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
@@ -375,7 +370,7 @@ func (r *UsageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
}
// GetAccountWindowStats 获取账号时间窗口内的统计
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*AccountStats, error) {
func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
var stats struct {
Requests int64 `gorm:"column:requests"`
Tokens int64 `gorm:"column:tokens"`
@@ -395,7 +390,7 @@ func (r *UsageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
return nil, err
}
return &AccountStats{
return &usagestats.AccountStats{
Requests: stats.Requests,
Tokens: stats.Tokens,
Cost: stats.Cost,
@@ -500,11 +495,11 @@ func (r *UsageLogRepository) GetModelStats(ctx context.Context, startTime, endTi
// ApiKeyUsageTrendPoint represents API key usage trend data point
type ApiKeyUsageTrendPoint struct {
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
Date string `json:"date"`
ApiKeyID int64 `json:"api_key_id"`
KeyName string `json:"key_name"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}
// GetApiKeyUsageTrend returns usage trend data grouped by API key and date
@@ -780,7 +775,7 @@ type UsageLogFilters struct {
}
// ListWithFilters lists usage logs with optional filters (for admin)
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *PaginationResult, error) {
func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UsageLogFilters) ([]model.UsageLog, *pagination.PaginationResult, error) {
var logs []model.UsageLog
var total int64
@@ -816,7 +811,7 @@ func (r *UsageLogRepository) ListWithFilters(ctx context.Context, params Paginat
pages++
}
return logs, &PaginationResult{
return logs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -838,7 +833,7 @@ type UsageStats struct {
// BatchUserUsageStats represents usage stats for a single user
type BatchUserUsageStats struct {
UserID int64 `json:"user_id"`
UserID int64 `json:"user_id"`
TodayActualCost float64 `json:"today_actual_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
}

View File

@@ -3,6 +3,7 @@ package repository
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -45,12 +46,12 @@ func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&model.User{}, id).Error
}
func (r *UserRepository) List(ctx context.Context, params PaginationParams) ([]model.User, *PaginationResult, error) {
func (r *UserRepository) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "", "")
}
// ListWithFilters lists users with optional filtering by status, role, and search query
func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationParams, status, role, search string) ([]model.User, *PaginationResult, error) {
func (r *UserRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error) {
var users []model.User
var total int64
@@ -81,7 +82,7 @@ func (r *UserRepository) ListWithFilters(ctx context.Context, params PaginationP
pages++
}
return users, &PaginationResult{
return users, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -127,4 +128,3 @@ func (r *UserRepository) RemoveGroupFromAllowedGroups(ctx context.Context, group
Update("allowed_groups", gorm.Expr("array_remove(allowed_groups, ?)", groupID))
return result.RowsAffected, result.Error
}

View File

@@ -5,6 +5,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
@@ -100,7 +101,7 @@ func (r *UserSubscriptionRepository) ListActiveByUserID(ctx context.Context, use
}
// ListByGroupID 获取分组的所有订阅(分页)
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params PaginationParams) ([]model.UserSubscription, *PaginationResult, error) {
func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -126,7 +127,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
pages++
}
return subs, &PaginationResult{
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
@@ -135,7 +136,7 @@ func (r *UserSubscriptionRepository) ListByGroupID(ctx context.Context, groupID
}
// List 获取所有订阅(分页,支持筛选)
func (r *UserSubscriptionRepository) List(ctx context.Context, params PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *PaginationResult, error) {
func (r *UserSubscriptionRepository) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
var subs []model.UserSubscription
var total int64
@@ -172,7 +173,7 @@ func (r *UserSubscriptionRepository) List(ctx context.Context, params Pagination
pages++
}
return subs, &PaginationResult{
return subs, &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),

View File

@@ -1,6 +1,8 @@
package repository
import (
"sub2api/internal/service/ports"
"github.com/google/wire"
)
@@ -16,4 +18,15 @@ var ProviderSet = wire.NewSet(
NewSettingRepository,
NewUserSubscriptionRepository,
wire.Struct(new(Repositories), "*"),
// Bind concrete repositories to service port interfaces
wire.Bind(new(ports.UserRepository), new(*UserRepository)),
wire.Bind(new(ports.ApiKeyRepository), new(*ApiKeyRepository)),
wire.Bind(new(ports.GroupRepository), new(*GroupRepository)),
wire.Bind(new(ports.AccountRepository), new(*AccountRepository)),
wire.Bind(new(ports.ProxyRepository), new(*ProxyRepository)),
wire.Bind(new(ports.RedeemCodeRepository), new(*RedeemCodeRepository)),
wire.Bind(new(ports.UsageLogRepository), new(*UsageLogRepository)),
wire.Bind(new(ports.SettingRepository), new(*SettingRepository)),
wire.Bind(new(ports.UserSubscriptionRepository), new(*UserSubscriptionRepository)),
)

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
// AccountService 账号管理服务
type AccountService struct {
accountRepo *repository.AccountRepository
groupRepo *repository.GroupRepository
accountRepo ports.AccountRepository
groupRepo ports.GroupRepository
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)

View File

@@ -16,7 +16,7 @@ import (
"time"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
@@ -37,15 +37,15 @@ type TestEvent struct {
// AccountTestService handles account testing operations
type AccountTestService struct {
repos *repository.Repositories
accountRepo ports.AccountRepository
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountTestService creates a new AccountTestService
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService {
return &AccountTestService{
repos: repos,
accountRepo: accountRepo,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 60 * time.Second,
@@ -105,7 +105,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
ctx := c.Request.Context()
// Get account
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return s.sendErrorAndEnd(c, "Account not found")
}

View File

@@ -12,7 +12,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
)
// usageCache 用于缓存usage数据
@@ -35,10 +35,10 @@ type WindowStats struct {
// UsageProgress 使用量进度
type UsageProgress struct {
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+100表示100%)
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
}
// UsageInfo 账号使用量信息
@@ -67,15 +67,17 @@ type ClaudeUsageResponse struct {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
repos *repository.Repositories
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
oauthService *OAuthService
httpClient *http.Client
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService {
return &AccountUsageService{
repos: repos,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
oauthService: oauthService,
httpClient: &http.Client{
Timeout: 30 * time.Second,
@@ -88,7 +90,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
// Setup Token账号: 根据session_window推算5h窗口7d数据不可用没有profile scope
// API Key账号: 不支持usage查询
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get account failed: %w", err)
}
@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
startTime = time.Now().Add(-5 * time.Hour)
}
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
if err != nil {
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
return
@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
// GetTodayStats 获取账号今日统计
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("get today stats failed: %w", err)
}

View File

@@ -13,7 +13,8 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"golang.org/x/net/proxy"
"gorm.io/gorm"
@@ -179,35 +180,45 @@ type ProxyTestResult struct {
// adminServiceImpl implements AdminService
type adminServiceImpl struct {
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
accountRepo *repository.AccountRepository
proxyRepo *repository.ProxyRepository
apiKeyRepo *repository.ApiKeyRepository
redeemCodeRepo *repository.RedeemCodeRepository
usageLogRepo *repository.UsageLogRepository
userSubRepo *repository.UserSubscriptionRepository
userRepo ports.UserRepository
groupRepo ports.GroupRepository
accountRepo ports.AccountRepository
proxyRepo ports.ProxyRepository
apiKeyRepo ports.ApiKeyRepository
redeemCodeRepo ports.RedeemCodeRepository
usageLogRepo ports.UsageLogRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService
}
// NewAdminService creates a new AdminService
func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService {
func NewAdminService(
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
accountRepo ports.AccountRepository,
proxyRepo ports.ProxyRepository,
apiKeyRepo ports.ApiKeyRepository,
redeemCodeRepo ports.RedeemCodeRepository,
usageLogRepo ports.UsageLogRepository,
userSubRepo ports.UserSubscriptionRepository,
billingCacheService *BillingCacheService,
) AdminService {
return &adminServiceImpl{
userRepo: repos.User,
groupRepo: repos.Group,
accountRepo: repos.Account,
proxyRepo: repos.Proxy,
apiKeyRepo: repos.ApiKey,
redeemCodeRepo: repos.RedeemCode,
usageLogRepo: repos.UsageLog,
userSubRepo: repos.UserSubscription,
userRepo: userRepo,
groupRepo: groupRepo,
accountRepo: accountRepo,
proxyRepo: proxyRepo,
apiKeyRepo: apiKeyRepo,
redeemCodeRepo: redeemCodeRepo,
usageLogRepo: usageLogRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
if err != nil {
return nil, 0, err
@@ -376,7 +387,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, 0, err
@@ -397,7 +408,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
// Group management implementations
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
if err != nil {
return nil, 0, err
@@ -568,7 +579,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
return nil, 0, err
@@ -578,7 +589,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
if err != nil {
return nil, 0, err
@@ -696,7 +707,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
// Proxy management implementations
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
if err != nil {
return nil, 0, err
@@ -781,7 +792,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
// Redeem code management implementations
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
if err != nil {
return nil, 0, err

View File

@@ -8,8 +8,9 @@ import (
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/timezone"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -17,12 +18,12 @@ import (
)
var (
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
ErrApiKeyNotFound = errors.New("api key not found")
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
ErrApiKeyExists = errors.New("api key already exists")
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
)
const (
@@ -47,20 +48,20 @@ type UpdateApiKeyRequest struct {
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo *repository.ApiKeyRepository
userRepo *repository.UserRepository
groupRepo *repository.GroupRepository
userSubRepo *repository.UserSubscriptionRepository
apiKeyRepo ports.ApiKeyRepository
userRepo ports.UserRepository
groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo *repository.ApiKeyRepository,
userRepo *repository.UserRepository,
groupRepo *repository.GroupRepository,
userSubRepo *repository.UserSubscriptionRepository,
apiKeyRepo ports.ApiKeyRepository,
userRepo ports.UserRepository,
groupRepo ports.GroupRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
) *ApiKeyService {
@@ -237,7 +238,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)

View File

@@ -7,7 +7,7 @@ import (
"log"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/golang-jwt/jwt/v5"
@@ -35,7 +35,7 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
userRepo *repository.UserRepository
userRepo ports.UserRepository
cfg *config.Config
settingService *SettingService
emailService *EmailService
@@ -45,7 +45,7 @@ type AuthService struct {
// NewAuthService 创建认证服务实例
func NewAuthService(
userRepo *repository.UserRepository,
userRepo ports.UserRepository,
cfg *config.Config,
settingService *SettingService,
emailService *EmailService,

View File

@@ -9,7 +9,7 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"github.com/redis/go-redis/v9"
)
@@ -81,12 +81,12 @@ type subscriptionCacheData struct {
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
rdb *redis.Client
userRepo *repository.UserRepository
subRepo *repository.UserSubscriptionRepository
userRepo ports.UserRepository
subRepo ports.UserSubscriptionRepository
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
return &BillingCacheService{
rdb: rdb,
userRepo: userRepo,

View File

@@ -11,7 +11,7 @@ import (
"net/smtp"
"strconv"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -25,9 +25,9 @@ var (
)
const (
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
verifyCodeKeyPrefix = "email_verify:"
verifyCodeTTL = 15 * time.Minute
verifyCodeCooldown = 1 * time.Minute
maxVerifyCodeAttempts = 5
)
@@ -51,12 +51,12 @@ type SmtpConfig struct {
// EmailService 邮件服务
type EmailService struct {
settingRepo *repository.SettingRepository
settingRepo ports.SettingRepository
rdb *redis.Client
}
// NewEmailService 创建邮件服务实例
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
return &EmailService{
settingRepo: settingRepo,
rdb: rdb,

View File

@@ -21,7 +21,7 @@ import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
@@ -78,7 +78,10 @@ type ForwardResult struct {
// GatewayService handles API gateway operations
type GatewayService struct {
repos *repository.Repositories
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client
cfg *config.Config
oauthService *OAuthService
@@ -90,7 +93,19 @@ type GatewayService struct {
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
func NewGatewayService(
accountRepo ports.AccountRepository,
usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
oauthService *OAuthService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
) *GatewayService {
// 计算响应头超时时间
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c
// 注意:不设置整体 Timeout让流式响应可以无限时间传输
}
return &GatewayService{
repos: repos,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cfg: cfg,
oauthService: oauthService,
@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
if err == nil && accountID > 0 {
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var accounts []model.Account
var err error
if groupID != nil {
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.repos.Account.ListSchedulable(ctx)
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 更新账号最后使用时间
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
log.Printf("Update last used failed: %v", err)
}

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
// GroupService 分组管理服务
type GroupService struct {
groupRepo *repository.GroupRepository
groupRepo ports.GroupRepository
}
// NewGroupService 创建分组服务实例
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
return &GroupService{
groupRepo: groupRepo,
}
@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
}
// List 获取分组列表
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
groups, pagination, err := s.groupRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list groups: %w", err)

View File

@@ -12,7 +12,7 @@ import (
"sub2api/internal/model"
"sub2api/internal/pkg/oauth"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"github.com/imroc/req/v3"
)
@@ -20,11 +20,11 @@ import (
// OAuthService handles OAuth authentication flows
type OAuthService struct {
sessionStore *oauth.SessionStore
proxyRepo *repository.ProxyRepository
proxyRepo ports.ProxyRepository
}
// NewOAuthService creates a new OAuth service
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService {
return &OAuthService{
sessionStore: oauth.NewSessionStore(),
proxyRepo: proxyRepo,
@@ -459,7 +459,7 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
// createReqClient creates a req client with Chrome impersonation and optional proxy
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
client := req.C().
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
SetTimeout(60 * time.Second)
// Set proxy if specified

View File

@@ -0,0 +1,35 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type AccountRepository interface {
Create(ctx context.Context, account *model.Account) error
GetByID(ctx context.Context, id int64) (*model.Account, error)
Update(ctx context.Context, account *model.Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
ListActive(ctx context.Context) ([]model.Account, error)
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
SetError(ctx context.Context, id int64, errorMsg string) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]model.Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
}

View File

@@ -0,0 +1,24 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *model.ApiKey) error
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
Update(ctx context.Context, key *model.ApiKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}

View File

@@ -0,0 +1,28 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"gorm.io/gorm"
)
type GroupRepository interface {
Create(ctx context.Context, group *model.Group) error
GetByID(ctx context.Context, id int64) (*model.Group, error)
Update(ctx context.Context, group *model.Group) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Group, error)
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
ExistsByName(ctx context.Context, name string) (bool, error)
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
DB() *gorm.DB
}

View File

@@ -0,0 +1,23 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type ProxyRepository interface {
Create(ctx context.Context, proxy *model.Proxy) error
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
Update(ctx context.Context, proxy *model.Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
ListActive(ctx context.Context) ([]model.Proxy, error)
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
}

View File

@@ -0,0 +1,22 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type RedeemCodeRepository interface {
Create(ctx context.Context, code *model.RedeemCode) error
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
Update(ctx context.Context, code *model.RedeemCode) error
Delete(ctx context.Context, id int64) error
Use(ctx context.Context, id, userID int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
}

View File

@@ -0,0 +1,17 @@
package ports
import (
"context"
"sub2api/internal/model"
)
type SettingRepository interface {
Get(ctx context.Context, key string) (*model.Setting, error)
GetValue(ctx context.Context, key string) (string, error)
Set(ctx context.Context, key, value string) error
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
SetMultiple(ctx context.Context, settings map[string]string) error
GetAll(ctx context.Context) (map[string]string, error)
Delete(ctx context.Context, key string) error
}

View File

@@ -0,0 +1,28 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
"sub2api/internal/pkg/usagestats"
)
type UsageLogRepository interface {
Create(ctx context.Context, log *model.UsageLog) error
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
}

View File

@@ -0,0 +1,25 @@
package ports
import (
"context"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserRepository interface {
Create(ctx context.Context, user *model.User) error
GetByID(ctx context.Context, id int64) (*model.User, error)
GetByEmail(ctx context.Context, email string) (*model.User, error)
Update(ctx context.Context, user *model.User) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
}

View File

@@ -0,0 +1,36 @@
package ports
import (
"context"
"time"
"sub2api/internal/model"
"sub2api/internal/pkg/pagination"
)
type UserSubscriptionRepository interface {
Create(ctx context.Context, sub *model.UserSubscription) error
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
Update(ctx context.Context, sub *model.UserSubscription) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
ActivateWindows(ctx context.Context, id int64, start time.Time) error
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
}

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
// ProxyService 代理管理服务
type ProxyService struct {
proxyRepo *repository.ProxyRepository
proxyRepo ports.ProxyRepository
}
// NewProxyService 创建代理服务实例
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
return &ProxyService{
proxyRepo: proxyRepo,
}
@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
}
// List 获取代理列表
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
proxies, pagination, err := s.proxyRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list proxies: %w", err)

View File

@@ -9,20 +9,20 @@ import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
)
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
repos *repository.Repositories
cfg *config.Config
accountRepo ports.AccountRepository
cfg *config.Config
}
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
return &RateLimitService{
repos: repos,
cfg: cfg,
accountRepo: accountRepo,
cfg: cfg,
}
}
@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("SetError failed for account %d: %v", account.ID, err)
return
}
@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if resetTimestamp == "" {
// 没有重置时间使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
if err != nil {
log.Printf("Parse reset timestamp failed: %v", err)
resetAt := time.Now().Add(5 * time.Minute)
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
return
@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
resetAt := time.Unix(ts, 0)
// 标记限流状态
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
return
}
@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
// 根据重置时间反推5h窗口
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
}
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
return
}
@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
}
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
}
// 如果状态为allowed且之前有限流说明窗口已重置清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
}
}
@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
// ClearRateLimit 清除账号的限流状态
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.repos.Account.ClearRateLimit(ctx, accountID)
return s.accountRepo.ClearRateLimit(ctx, accountID)
}

View File

@@ -8,7 +8,8 @@ import (
"fmt"
"strings"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"time"
"github.com/redis/go-redis/v9"
@@ -49,8 +50,8 @@ type RedeemCodeResponse struct {
// RedeemService 兑换码服务
type RedeemService struct {
redeemRepo *repository.RedeemCodeRepository
userRepo *repository.UserRepository
redeemRepo ports.RedeemCodeRepository
userRepo ports.UserRepository
subscriptionService *SubscriptionService
rdb *redis.Client
billingCacheService *BillingCacheService
@@ -58,8 +59,8 @@ type RedeemService struct {
// NewRedeemService 创建兑换码服务实例
func NewRedeemService(
redeemRepo *repository.RedeemCodeRepository,
userRepo *repository.UserRepository,
redeemRepo ports.RedeemCodeRepository,
userRepo ports.UserRepository,
subscriptionService *SubscriptionService,
rdb *redis.Client,
billingCacheService *BillingCacheService,
@@ -337,7 +338,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
}
// List 获取兑换码列表(管理员功能)
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
codes, pagination, err := s.redeemRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list redeem codes: %w", err)

View File

@@ -7,7 +7,7 @@ import (
"strconv"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"gorm.io/gorm"
)
@@ -18,12 +18,12 @@ var (
// SettingService 系统设置服务
type SettingService struct {
settingRepo *repository.SettingRepository
settingRepo ports.SettingRepository
cfg *config.Config
}
// NewSettingService 创建系统设置服务实例
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
return &SettingService{
settingRepo: settingRepo,
cfg: cfg,

View File

@@ -7,7 +7,8 @@ import (
"time"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
)
var (
@@ -23,14 +24,16 @@ var (
// SubscriptionService 订阅服务
type SubscriptionService struct {
repos *repository.Repositories
groupRepo ports.GroupRepository
userSubRepo ports.UserSubscriptionRepository
billingCacheService *BillingCacheService
}
// NewSubscriptionService 创建订阅服务
func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService {
func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
return &SubscriptionService{
repos: repos,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
billingCacheService: billingCacheService,
}
}
@@ -47,7 +50,7 @@ type AssignSubscriptionInput struct {
// AssignSubscription 分配订阅给用户(不允许重复分配)
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, fmt.Errorf("group not found: %w", err)
}
@@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
}
// 检查是否已存在订阅
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
return nil, err
}
@@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
// 如果没有订阅:创建新订阅
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
// 检查分组是否存在且为订阅类型
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
if err != nil {
return nil, false, fmt.Errorf("group not found: %w", err)
}
@@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 查询是否已有订阅
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
if err != nil {
// 不存在记录是正常情况,其他错误需要返回
existingSub = nil
@@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 更新过期时间
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
return nil, false, fmt.Errorf("extend subscription: %w", err)
}
// 如果订阅已过期或被暂停恢复为active状态
if existingSub.Status != model.SubscriptionStatusActive {
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
return nil, false, fmt.Errorf("update subscription status: %w", err)
}
}
@@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
newNotes += "\n"
}
newNotes += input.Notes
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
// 备注更新失败不影响主流程
}
}
@@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
}
// 返回更新后的订阅
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
return sub, true, err // true 表示是续期
}
@@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
sub.AssignedBy = &input.AssignedBy
}
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
if err := s.userSubRepo.Create(ctx, sub); err != nil {
return nil, err
}
// 重新获取完整订阅信息(包含关联)
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
return s.userSubRepo.GetByID(ctx, sub.ID)
}
// BulkAssignSubscriptionInput 批量分配订阅输入
@@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
// RevokeSubscription 撤销订阅
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
// 先获取订阅信息用于失效缓存
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return err
}
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
return err
}
@@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
// ExtendSubscription 延长订阅
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
// 计算新的过期时间
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
return nil, err
}
// 如果订阅已过期恢复为active状态
if sub.Status == model.SubscriptionStatusExpired {
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
return nil, err
}
}
@@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
}()
}
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
return s.userSubRepo.GetByID(ctx, subscriptionID)
}
// GetByID 根据ID获取订阅
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
return s.repos.UserSubscription.GetByID(ctx, id)
return s.userSubRepo.GetByID(ctx, id)
}
// GetActiveSubscription 获取用户对特定分组的有效订阅
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
@@ -331,24 +334,24 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
// ListUserSubscriptions 获取用户的所有订阅
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListByUserID(ctx, userID)
return s.userSubRepo.ListByUserID(ctx, userID)
}
// ListActiveUserSubscriptions 获取用户的所有有效订阅
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
return s.userSubRepo.ListActiveByUserID(ctx, userID)
}
// ListGroupSubscriptions 获取分组的所有订阅
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
}
// List 获取所有订阅(分页,支持筛选)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
params := repository.PaginationParams{Page: page, PageSize: pageSize}
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
return s.userSubRepo.List(ctx, params, userID, groupID, status)
}
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
@@ -358,7 +361,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
}
now := time.Now()
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
return s.userSubRepo.ActivateWindows(ctx, sub.ID, now)
}
// CheckAndResetWindows 检查并重置过期的窗口
@@ -367,7 +370,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 日窗口重置24小时
if sub.NeedsDailyReset() {
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.DailyWindowStart = &now
@@ -376,7 +379,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 周窗口重置7天
if sub.NeedsWeeklyReset() {
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.WeeklyWindowStart = &now
@@ -385,7 +388,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
// 月窗口重置30天
if sub.NeedsMonthlyReset() {
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
return err
}
sub.MonthlyWindowStart = &now
@@ -411,7 +414,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
// RecordUsage 记录使用量到订阅
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
}
// SubscriptionProgress 订阅进度
@@ -438,14 +441,14 @@ type UsageWindowProgress struct {
// GetSubscriptionProgress 获取订阅使用进度
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
if err != nil {
return nil, ErrSubscriptionNotFound
}
group := sub.Group
if group == nil {
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
if err != nil {
return nil, err
}
@@ -535,7 +538,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, err
}
@@ -554,7 +557,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
}
// ValidateSubscription 验证订阅是否有效
@@ -567,7 +570,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
}
if sub.IsExpired() {
// 更新状态
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
return ErrSubscriptionExpired
}
return nil

View File

@@ -5,7 +5,8 @@ import (
"errors"
"fmt"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"time"
"gorm.io/gorm"
@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct {
// UsageStats 使用统计
type UsageStats struct {
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
TotalRequests int64 `json:"total_requests"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
TotalCacheTokens int64 `json:"total_cache_tokens"`
TotalTokens int64 `json:"total_tokens"`
TotalCost float64 `json:"total_cost"`
TotalActualCost float64 `json:"total_actual_cost"`
AverageDurationMs float64 `json:"average_duration_ms"`
}
// UsageService 使用统计服务
type UsageService struct {
usageRepo *repository.UsageLogRepository
userRepo *repository.UserRepository
usageRepo ports.UsageLogRepository
userRepo ports.UserRepository
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
}
// ListByUser 获取用户的使用日志列表
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
}
// ListByAccount 获取账号的使用日志列表
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
}
result = append(result, map[string]interface{}{
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
"date": date,
"total_requests": stats.TotalRequests,
"total_input_tokens": stats.TotalInputTokens,
"total_output_tokens": stats.TotalOutputTokens,
"total_cache_tokens": stats.TotalCacheTokens,
"total_tokens": stats.TotalTokens,
"total_cost": stats.TotalCost,
"total_actual_cost": stats.TotalActualCost,
"average_duration_ms": stats.AverageDurationMs,
})
}

View File

@@ -6,16 +6,17 @@ import (
"fmt"
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/repository"
"sub2api/internal/pkg/pagination"
"sub2api/internal/service/ports"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
var (
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
ErrUserNotFound = errors.New("user not found")
ErrPasswordIncorrect = errors.New("current password is incorrect")
ErrInsufficientPerms = errors.New("insufficient permissions")
)
// UpdateProfileRequest 更新用户资料请求
@@ -32,12 +33,12 @@ type ChangePasswordRequest struct {
// UserService 用户服务
type UserService struct {
userRepo *repository.UserRepository
userRepo ports.UserRepository
cfg *config.Config
}
// NewUserService 创建用户服务实例
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
return &UserService{
userRepo: userRepo,
cfg: cfg,
@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
}
// List 获取用户列表(管理员功能)
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
users, pagination, err := s.userRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)