refactor(backend): 引入端口接口模式
This commit is contained in:
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -41,12 +42,12 @@ type UpdateAccountRequest struct {
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo *repository.AccountRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo *repository.AccountRepository, groupRepo *repository.GroupRepository) *AccountService {
|
||||
func NewAccountService(accountRepo ports.AccountRepository, groupRepo ports.GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -108,7 +109,7 @@ func (s *AccountService) GetByID(ctx context.Context, id int64) (*model.Account,
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params repository.PaginationParams) ([]model.Account, *repository.PaginationResult, error) {
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -37,15 +37,15 @@ type TestEvent struct {
|
||||
|
||||
// AccountTestService handles account testing operations
|
||||
type AccountTestService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAccountTestService creates a new AccountTestService
|
||||
func NewAccountTestService(repos *repository.Repositories, oauthService *OAuthService) *AccountTestService {
|
||||
func NewAccountTestService(accountRepo ports.AccountRepository, oauthService *OAuthService) *AccountTestService {
|
||||
return &AccountTestService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
@@ -105,7 +105,7 @@ func (s *AccountTestService) TestAccountConnection(c *gin.Context, accountID int
|
||||
ctx := c.Request.Context()
|
||||
|
||||
// Get account
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, "Account not found")
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// usageCache 用于缓存usage数据
|
||||
@@ -35,10 +35,10 @@ type WindowStats struct {
|
||||
|
||||
// UsageProgress 使用量进度
|
||||
type UsageProgress struct {
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
Utilization float64 `json:"utilization"` // 使用率百分比 (0-100+,100表示100%)
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
}
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
@@ -67,15 +67,17 @@ type ClaudeUsageResponse struct {
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
type AccountUsageService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
oauthService *OAuthService
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthService) *AccountUsageService {
|
||||
func NewAccountUsageService(accountRepo ports.AccountRepository, usageLogRepo ports.UsageLogRepository, oauthService *OAuthService) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
oauthService: oauthService,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 30 * time.Second,
|
||||
@@ -88,7 +90,7 @@ func NewAccountUsageService(repos *repository.Repositories, oauthService *OAuthS
|
||||
// Setup Token账号: 根据session_window推算5h窗口,7d数据不可用(没有profile scope)
|
||||
// API Key账号: 不支持usage查询
|
||||
func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*UsageInfo, error) {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account failed: %w", err)
|
||||
}
|
||||
@@ -148,7 +150,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
|
||||
stats, err := s.repos.UsageLog.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get window stats for account %d: %v", account.ID, err)
|
||||
return
|
||||
@@ -163,7 +165,7 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *model
|
||||
|
||||
// GetTodayStats 获取账号今日统计
|
||||
func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64) (*WindowStats, error) {
|
||||
stats, err := s.repos.UsageLog.GetAccountTodayStats(ctx, accountID)
|
||||
stats, err := s.usageLogRepo.GetAccountTodayStats(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get today stats failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
"gorm.io/gorm"
|
||||
@@ -179,35 +180,45 @@ type ProxyTestResult struct {
|
||||
|
||||
// adminServiceImpl implements AdminService
|
||||
type adminServiceImpl struct {
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
accountRepo *repository.AccountRepository
|
||||
proxyRepo *repository.ProxyRepository
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
redeemCodeRepo *repository.RedeemCodeRepository
|
||||
usageLogRepo *repository.UsageLogRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
accountRepo ports.AccountRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
redeemCodeRepo ports.RedeemCodeRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewAdminService creates a new AdminService
|
||||
func NewAdminService(repos *repository.Repositories, billingCacheService *BillingCacheService) AdminService {
|
||||
func NewAdminService(
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
accountRepo ports.AccountRepository,
|
||||
proxyRepo ports.ProxyRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
redeemCodeRepo ports.RedeemCodeRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
) AdminService {
|
||||
return &adminServiceImpl{
|
||||
userRepo: repos.User,
|
||||
groupRepo: repos.Group,
|
||||
accountRepo: repos.Account,
|
||||
proxyRepo: repos.Proxy,
|
||||
apiKeyRepo: repos.ApiKey,
|
||||
redeemCodeRepo: repos.RedeemCode,
|
||||
usageLogRepo: repos.UsageLog,
|
||||
userSubRepo: repos.UserSubscription,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
accountRepo: accountRepo,
|
||||
proxyRepo: proxyRepo,
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
redeemCodeRepo: redeemCodeRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
|
||||
// User management implementations
|
||||
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]model.User, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -376,7 +387,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -397,7 +408,7 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64,
|
||||
|
||||
// Group management implementations
|
||||
func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]model.Group, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -568,7 +579,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]model.ApiKey, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -578,7 +589,7 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p
|
||||
|
||||
// Account management implementations
|
||||
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]model.Account, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
accounts, result, err := s.accountRepo.ListWithFilters(ctx, params, platform, accountType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -696,7 +707,7 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64,
|
||||
|
||||
// Proxy management implementations
|
||||
func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]model.Proxy, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
proxies, result, err := s.proxyRepo.ListWithFilters(ctx, params, protocol, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -781,7 +792,7 @@ func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, po
|
||||
|
||||
// Redeem code management implementations
|
||||
func (s *adminServiceImpl) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]model.RedeemCode, int64, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
codes, result, err := s.redeemCodeRepo.ListWithFilters(ctx, params, codeType, status, search)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
||||
@@ -8,8 +8,9 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/timezone"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -17,12 +18,12 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
ErrApiKeyNotFound = errors.New("api key not found")
|
||||
ErrGroupNotAllowed = errors.New("user is not allowed to bind this group")
|
||||
ErrApiKeyExists = errors.New("api key already exists")
|
||||
ErrApiKeyTooShort = errors.New("api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = errors.New("api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = errors.New("too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -47,20 +48,20 @@ type UpdateApiKeyRequest struct {
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo *repository.ApiKeyRepository
|
||||
userRepo *repository.UserRepository
|
||||
groupRepo *repository.GroupRepository
|
||||
userSubRepo *repository.UserSubscriptionRepository
|
||||
apiKeyRepo ports.ApiKeyRepository
|
||||
userRepo ports.UserRepository
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo *repository.ApiKeyRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
groupRepo *repository.GroupRepository,
|
||||
userSubRepo *repository.UserSubscriptionRepository,
|
||||
apiKeyRepo ports.ApiKeyRepository,
|
||||
userRepo ports.UserRepository,
|
||||
groupRepo ports.GroupRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
@@ -237,7 +238,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.ApiKey, *repository.PaginationResult, error) {
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"log"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
@@ -35,7 +35,7 @@ type JWTClaims struct {
|
||||
|
||||
// AuthService 认证服务
|
||||
type AuthService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
settingService *SettingService
|
||||
emailService *EmailService
|
||||
@@ -45,7 +45,7 @@ type AuthService struct {
|
||||
|
||||
// NewAuthService 创建认证服务实例
|
||||
func NewAuthService(
|
||||
userRepo *repository.UserRepository,
|
||||
userRepo ports.UserRepository,
|
||||
cfg *config.Config,
|
||||
settingService *SettingService,
|
||||
emailService *EmailService,
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
@@ -81,12 +81,12 @@ type subscriptionCacheData struct {
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
rdb *redis.Client
|
||||
userRepo *repository.UserRepository
|
||||
subRepo *repository.UserSubscriptionRepository
|
||||
userRepo ports.UserRepository
|
||||
subRepo ports.UserSubscriptionRepository
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo *repository.UserRepository, subRepo *repository.UserSubscriptionRepository) *BillingCacheService {
|
||||
func NewBillingCacheService(rdb *redis.Client, userRepo ports.UserRepository, subRepo ports.UserSubscriptionRepository) *BillingCacheService {
|
||||
return &BillingCacheService{
|
||||
rdb: rdb,
|
||||
userRepo: userRepo,
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"net/smtp"
|
||||
"strconv"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -25,9 +25,9 @@ var (
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
verifyCodeKeyPrefix = "email_verify:"
|
||||
verifyCodeTTL = 15 * time.Minute
|
||||
verifyCodeCooldown = 1 * time.Minute
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
@@ -51,12 +51,12 @@ type SmtpConfig struct {
|
||||
|
||||
// EmailService 邮件服务
|
||||
type EmailService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
settingRepo ports.SettingRepository
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewEmailService 创建邮件服务实例
|
||||
func NewEmailService(settingRepo *repository.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
func NewEmailService(settingRepo ports.SettingRepository, rdb *redis.Client) *EmailService {
|
||||
return &EmailService{
|
||||
settingRepo: settingRepo,
|
||||
rdb: rdb,
|
||||
|
||||
@@ -21,7 +21,7 @@ import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -78,7 +78,10 @@ type ForwardResult struct {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
@@ -90,7 +93,19 @@ type GatewayService struct {
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
|
||||
func NewGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
) *GatewayService {
|
||||
// 计算响应头超时时间
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c
|
||||
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
|
||||
}
|
||||
return &GatewayService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
} else {
|
||||
accounts, err = s.repos.Account.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 更新账号最后使用时间
|
||||
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
log.Printf("Update last used failed: %v", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -34,11 +35,11 @@ type UpdateGroupRequest struct {
|
||||
|
||||
// GroupService 分组管理服务
|
||||
type GroupService struct {
|
||||
groupRepo *repository.GroupRepository
|
||||
groupRepo ports.GroupRepository
|
||||
}
|
||||
|
||||
// NewGroupService 创建分组服务实例
|
||||
func NewGroupService(groupRepo *repository.GroupRepository) *GroupService {
|
||||
func NewGroupService(groupRepo ports.GroupRepository) *GroupService {
|
||||
return &GroupService{
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
@@ -84,7 +85,7 @@ func (s *GroupService) GetByID(ctx context.Context, id int64) (*model.Group, err
|
||||
}
|
||||
|
||||
// List 获取分组列表
|
||||
func (s *GroupService) List(ctx context.Context, params repository.PaginationParams) ([]model.Group, *repository.PaginationResult, error) {
|
||||
func (s *GroupService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error) {
|
||||
groups, pagination, err := s.groupRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list groups: %w", err)
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/oauth"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
@@ -20,11 +20,11 @@ import (
|
||||
// OAuthService handles OAuth authentication flows
|
||||
type OAuthService struct {
|
||||
sessionStore *oauth.SessionStore
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
}
|
||||
|
||||
// NewOAuthService creates a new OAuth service
|
||||
func NewOAuthService(proxyRepo *repository.ProxyRepository) *OAuthService {
|
||||
func NewOAuthService(proxyRepo ports.ProxyRepository) *OAuthService {
|
||||
return &OAuthService{
|
||||
sessionStore: oauth.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
@@ -459,7 +459,7 @@ func (s *OAuthService) RefreshAccountToken(ctx context.Context, account *model.A
|
||||
// createReqClient creates a req client with Chrome impersonation and optional proxy
|
||||
func (s *OAuthService) createReqClient(proxyURL string) *req.Client {
|
||||
client := req.C().
|
||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
||||
ImpersonateChrome(). // Impersonate Chrome browser to bypass Cloudflare
|
||||
SetTimeout(60 * time.Second)
|
||||
|
||||
// Set proxy if specified
|
||||
|
||||
35
backend/internal/service/ports/account.go
Normal file
35
backend/internal/service/ports/account.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *model.Account) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Account, error)
|
||||
Update(ctx context.Context, account *model.Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]model.Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
ListActive(ctx context.Context) ([]model.Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]model.Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]model.Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]model.Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
}
|
||||
24
backend/internal/service/ports/api_key.go
Normal file
24
backend/internal/service/ports/api_key.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *model.ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*model.ApiKey, error)
|
||||
GetByKey(ctx context.Context, key string) (*model.ApiKey, error)
|
||||
Update(ctx context.Context, key *model.ApiKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]model.ApiKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
28
backend/internal/service/ports/group.go
Normal file
28
backend/internal/service/ports/group.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type GroupRepository interface {
|
||||
Create(ctx context.Context, group *model.Group) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Group, error)
|
||||
Update(ctx context.Context, group *model.Group) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]model.Group, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Group, error)
|
||||
ListActiveByPlatform(ctx context.Context, platform string) ([]model.Group, error)
|
||||
|
||||
ExistsByName(ctx context.Context, name string) (bool, error)
|
||||
GetAccountCount(ctx context.Context, groupID int64) (int64, error)
|
||||
DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
|
||||
DB() *gorm.DB
|
||||
}
|
||||
23
backend/internal/service/ports/proxy.go
Normal file
23
backend/internal/service/ports/proxy.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type ProxyRepository interface {
|
||||
Create(ctx context.Context, proxy *model.Proxy) error
|
||||
GetByID(ctx context.Context, id int64) (*model.Proxy, error)
|
||||
Update(ctx context.Context, proxy *model.Proxy) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]model.Proxy, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context) ([]model.Proxy, error)
|
||||
ListActiveWithAccountCount(ctx context.Context) ([]model.ProxyWithAccountCount, error)
|
||||
|
||||
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
|
||||
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
|
||||
}
|
||||
22
backend/internal/service/ports/redeem_code.go
Normal file
22
backend/internal/service/ports/redeem_code.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type RedeemCodeRepository interface {
|
||||
Create(ctx context.Context, code *model.RedeemCode) error
|
||||
CreateBatch(ctx context.Context, codes []model.RedeemCode) error
|
||||
GetByID(ctx context.Context, id int64) (*model.RedeemCode, error)
|
||||
GetByCode(ctx context.Context, code string) (*model.RedeemCode, error)
|
||||
Update(ctx context.Context, code *model.RedeemCode) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
Use(ctx context.Context, id, userID int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]model.RedeemCode, *pagination.PaginationResult, error)
|
||||
ListByUser(ctx context.Context, userID int64, limit int) ([]model.RedeemCode, error)
|
||||
}
|
||||
17
backend/internal/service/ports/setting.go
Normal file
17
backend/internal/service/ports/setting.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
)
|
||||
|
||||
type SettingRepository interface {
|
||||
Get(ctx context.Context, key string) (*model.Setting, error)
|
||||
GetValue(ctx context.Context, key string) (string, error)
|
||||
Set(ctx context.Context, key, value string) error
|
||||
GetMultiple(ctx context.Context, keys []string) (map[string]string, error)
|
||||
SetMultiple(ctx context.Context, settings map[string]string) error
|
||||
GetAll(ctx context.Context) (map[string]string, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
28
backend/internal/service/ports/usage_log.go
Normal file
28
backend/internal/service/ports/usage_log.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *model.UsageLog) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]model.UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error)
|
||||
GetAccountTodayStats(ctx context.Context, accountID int64) (*usagestats.AccountStats, error)
|
||||
}
|
||||
25
backend/internal/service/ports/user.go
Normal file
25
backend/internal/service/ports/user.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
GetByID(ctx context.Context, id int64) (*model.User, error)
|
||||
GetByEmail(ctx context.Context, email string) (*model.User, error)
|
||||
Update(ctx context.Context, user *model.User) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]model.User, *pagination.PaginationResult, error)
|
||||
|
||||
UpdateBalance(ctx context.Context, id int64, amount float64) error
|
||||
DeductBalance(ctx context.Context, id int64, amount float64) error
|
||||
UpdateConcurrency(ctx context.Context, id int64, amount int) error
|
||||
ExistsByEmail(ctx context.Context, email string) (bool, error)
|
||||
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
36
backend/internal/service/ports/user_subscription.go
Normal file
36
backend/internal/service/ports/user_subscription.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package ports
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepository interface {
|
||||
Create(ctx context.Context, sub *model.UserSubscription) error
|
||||
GetByID(ctx context.Context, id int64) (*model.UserSubscription, error)
|
||||
GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error)
|
||||
Update(ctx context.Context, sub *model.UserSubscription) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListActiveByUserID(ctx context.Context, userID int64) ([]model.UserSubscription, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error)
|
||||
|
||||
ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error)
|
||||
ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error
|
||||
UpdateStatus(ctx context.Context, subscriptionID int64, status string) error
|
||||
UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error
|
||||
|
||||
ActivateWindows(ctx context.Context, id int64, start time.Time) error
|
||||
ResetDailyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetWeeklyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
ResetMonthlyUsage(ctx context.Context, id int64, newWindowStart time.Time) error
|
||||
IncrementUsage(ctx context.Context, id int64, costUSD float64) error
|
||||
|
||||
BatchUpdateExpiredStatus(ctx context.Context) (int64, error)
|
||||
}
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -37,11 +38,11 @@ type UpdateProxyRequest struct {
|
||||
|
||||
// ProxyService 代理管理服务
|
||||
type ProxyService struct {
|
||||
proxyRepo *repository.ProxyRepository
|
||||
proxyRepo ports.ProxyRepository
|
||||
}
|
||||
|
||||
// NewProxyService 创建代理服务实例
|
||||
func NewProxyService(proxyRepo *repository.ProxyRepository) *ProxyService {
|
||||
func NewProxyService(proxyRepo ports.ProxyRepository) *ProxyService {
|
||||
return &ProxyService{
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
@@ -80,7 +81,7 @@ func (s *ProxyService) GetByID(ctx context.Context, id int64) (*model.Proxy, err
|
||||
}
|
||||
|
||||
// List 获取代理列表
|
||||
func (s *ProxyService) List(ctx context.Context, params repository.PaginationParams) ([]model.Proxy, *repository.PaginationResult, error) {
|
||||
func (s *ProxyService) List(ctx context.Context, params pagination.PaginationParams) ([]model.Proxy, *pagination.PaginationResult, error) {
|
||||
proxies, pagination, err := s.proxyRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list proxies: %w", err)
|
||||
|
||||
@@ -9,20 +9,20 @@ import (
|
||||
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
// RateLimitService 处理限流和过载状态管理
|
||||
type RateLimitService struct {
|
||||
repos *repository.Repositories
|
||||
cfg *config.Config
|
||||
accountRepo ports.AccountRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(repos *repository.Repositories, cfg *config.Config) *RateLimitService {
|
||||
func NewRateLimitService(accountRepo ports.AccountRepository, cfg *config.Config) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
repos: repos,
|
||||
cfg: cfg,
|
||||
accountRepo: accountRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *mod
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
func (s *RateLimitService) handleAuthError(ctx context.Context, account *model.Account, errorMsg string) {
|
||||
if err := s.repos.Account.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
|
||||
log.Printf("SetError failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -77,7 +77,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if resetTimestamp == "" {
|
||||
// 没有重置时间,使用默认5分钟
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -88,7 +88,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
if err != nil {
|
||||
log.Printf("Parse reset timestamp failed: %v", err)
|
||||
resetAt := time.Now().Add(5 * time.Minute)
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
return
|
||||
@@ -97,7 +97,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
resetAt := time.Unix(ts, 0)
|
||||
|
||||
// 标记限流状态
|
||||
if err := s.repos.Account.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (s *RateLimitService) handle429(ctx context.Context, account *model.Account
|
||||
// 根据重置时间反推5h窗口
|
||||
windowEnd := resetAt
|
||||
windowStart := resetAt.Add(-5 * time.Hour)
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
@@ -121,7 +121,7 @@ func (s *RateLimitService) handle529(ctx context.Context, account *model.Account
|
||||
}
|
||||
|
||||
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
|
||||
if err := s.repos.Account.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
|
||||
log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
|
||||
return
|
||||
}
|
||||
@@ -152,13 +152,13 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
|
||||
}
|
||||
|
||||
if err := s.repos.Account.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
|
||||
log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
|
||||
if status == "allowed" && account.IsRateLimited() {
|
||||
if err := s.repos.Account.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.ClearRateLimit(ctx, account.ID); err != nil {
|
||||
log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
@@ -166,5 +166,5 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *mod
|
||||
|
||||
// ClearRateLimit 清除账号的限流状态
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.repos.Account.ClearRateLimit(ctx, accountID)
|
||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -8,7 +8,8 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -49,8 +50,8 @@ type RedeemCodeResponse struct {
|
||||
|
||||
// RedeemService 兑换码服务
|
||||
type RedeemService struct {
|
||||
redeemRepo *repository.RedeemCodeRepository
|
||||
userRepo *repository.UserRepository
|
||||
redeemRepo ports.RedeemCodeRepository
|
||||
userRepo ports.UserRepository
|
||||
subscriptionService *SubscriptionService
|
||||
rdb *redis.Client
|
||||
billingCacheService *BillingCacheService
|
||||
@@ -58,8 +59,8 @@ type RedeemService struct {
|
||||
|
||||
// NewRedeemService 创建兑换码服务实例
|
||||
func NewRedeemService(
|
||||
redeemRepo *repository.RedeemCodeRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
redeemRepo ports.RedeemCodeRepository,
|
||||
userRepo ports.UserRepository,
|
||||
subscriptionService *SubscriptionService,
|
||||
rdb *redis.Client,
|
||||
billingCacheService *BillingCacheService,
|
||||
@@ -337,7 +338,7 @@ func (s *RedeemService) GetByCode(ctx context.Context, code string) (*model.Rede
|
||||
}
|
||||
|
||||
// List 获取兑换码列表(管理员功能)
|
||||
func (s *RedeemService) List(ctx context.Context, params repository.PaginationParams) ([]model.RedeemCode, *repository.PaginationResult, error) {
|
||||
func (s *RedeemService) List(ctx context.Context, params pagination.PaginationParams) ([]model.RedeemCode, *pagination.PaginationResult, error) {
|
||||
codes, pagination, err := s.redeemRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list redeem codes: %w", err)
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"strconv"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@@ -18,12 +18,12 @@ var (
|
||||
|
||||
// SettingService 系统设置服务
|
||||
type SettingService struct {
|
||||
settingRepo *repository.SettingRepository
|
||||
settingRepo ports.SettingRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewSettingService 创建系统设置服务实例
|
||||
func NewSettingService(settingRepo *repository.SettingRepository, cfg *config.Config) *SettingService {
|
||||
func NewSettingService(settingRepo ports.SettingRepository, cfg *config.Config) *SettingService {
|
||||
return &SettingService{
|
||||
settingRepo: settingRepo,
|
||||
cfg: cfg,
|
||||
|
||||
@@ -7,7 +7,8 @@ import (
|
||||
"time"
|
||||
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -23,14 +24,16 @@ var (
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
type SubscriptionService struct {
|
||||
repos *repository.Repositories
|
||||
groupRepo ports.GroupRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
billingCacheService *BillingCacheService
|
||||
}
|
||||
|
||||
// NewSubscriptionService 创建订阅服务
|
||||
func NewSubscriptionService(repos *repository.Repositories, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
func NewSubscriptionService(groupRepo ports.GroupRepository, userSubRepo ports.UserSubscriptionRepository, billingCacheService *BillingCacheService) *SubscriptionService {
|
||||
return &SubscriptionService{
|
||||
repos: repos,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
billingCacheService: billingCacheService,
|
||||
}
|
||||
}
|
||||
@@ -47,7 +50,7 @@ type AssignSubscriptionInput struct {
|
||||
// AssignSubscription 分配订阅给用户(不允许重复分配)
|
||||
func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -56,7 +59,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
}
|
||||
|
||||
// 检查是否已存在订阅
|
||||
exists, err := s.repos.UserSubscription.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
exists, err := s.userSubRepo.ExistsByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,7 +93,7 @@ func (s *SubscriptionService) AssignSubscription(ctx context.Context, input *Ass
|
||||
// 如果没有订阅:创建新订阅
|
||||
func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*model.UserSubscription, bool, error) {
|
||||
// 检查分组是否存在且为订阅类型
|
||||
group, err := s.repos.Group.GetByID(ctx, input.GroupID)
|
||||
group, err := s.groupRepo.GetByID(ctx, input.GroupID)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("group not found: %w", err)
|
||||
}
|
||||
@@ -99,7 +102,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 查询是否已有订阅
|
||||
existingSub, err := s.repos.UserSubscription.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
existingSub, err := s.userSubRepo.GetByUserIDAndGroupID(ctx, input.UserID, input.GroupID)
|
||||
if err != nil {
|
||||
// 不存在记录是正常情况,其他错误需要返回
|
||||
existingSub = nil
|
||||
@@ -124,13 +127,13 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 更新过期时间
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, existingSub.ID, newExpiresAt); err != nil {
|
||||
return nil, false, fmt.Errorf("extend subscription: %w", err)
|
||||
}
|
||||
|
||||
// 如果订阅已过期或被暂停,恢复为active状态
|
||||
if existingSub.Status != model.SubscriptionStatusActive {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, existingSub.ID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, false, fmt.Errorf("update subscription status: %w", err)
|
||||
}
|
||||
}
|
||||
@@ -142,7 +145,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
newNotes += "\n"
|
||||
}
|
||||
newNotes += input.Notes
|
||||
if err := s.repos.UserSubscription.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
if err := s.userSubRepo.UpdateNotes(ctx, existingSub.ID, newNotes); err != nil {
|
||||
// 备注更新失败不影响主流程
|
||||
}
|
||||
}
|
||||
@@ -158,7 +161,7 @@ func (s *SubscriptionService) AssignOrExtendSubscription(ctx context.Context, in
|
||||
}
|
||||
|
||||
// 返回更新后的订阅
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, existingSub.ID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, existingSub.ID)
|
||||
return sub, true, err // true 表示是续期
|
||||
}
|
||||
|
||||
@@ -205,12 +208,12 @@ func (s *SubscriptionService) createSubscription(ctx context.Context, input *Ass
|
||||
sub.AssignedBy = &input.AssignedBy
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Create(ctx, sub); err != nil {
|
||||
if err := s.userSubRepo.Create(ctx, sub); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 重新获取完整订阅信息(包含关联)
|
||||
return s.repos.UserSubscription.GetByID(ctx, sub.ID)
|
||||
return s.userSubRepo.GetByID(ctx, sub.ID)
|
||||
}
|
||||
|
||||
// BulkAssignSubscriptionInput 批量分配订阅输入
|
||||
@@ -260,12 +263,12 @@ func (s *SubscriptionService) BulkAssignSubscription(ctx context.Context, input
|
||||
// RevokeSubscription 撤销订阅
|
||||
func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscriptionID int64) error {
|
||||
// 先获取订阅信息用于失效缓存
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.repos.UserSubscription.Delete(ctx, subscriptionID); err != nil {
|
||||
if err := s.userSubRepo.Delete(ctx, subscriptionID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,20 +287,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
if err := s.repos.UserSubscription.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 如果订阅已过期,恢复为active状态
|
||||
if sub.Status == model.SubscriptionStatusExpired {
|
||||
if err := s.repos.UserSubscription.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
if err := s.userSubRepo.UpdateStatus(ctx, subscriptionID, model.SubscriptionStatusActive); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@@ -312,17 +315,17 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
}()
|
||||
}
|
||||
|
||||
return s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
return s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取订阅
|
||||
func (s *SubscriptionService) GetByID(ctx context.Context, id int64) (*model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.GetByID(ctx, id)
|
||||
return s.userSubRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveSubscription 获取用户对特定分组的有效订阅
|
||||
func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID, groupID int64) (*model.UserSubscription, error) {
|
||||
sub, err := s.repos.UserSubscription.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
sub, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
@@ -331,24 +334,24 @@ func (s *SubscriptionService) GetActiveSubscription(ctx context.Context, userID,
|
||||
|
||||
// ListUserSubscriptions 获取用户的所有订阅
|
||||
func (s *SubscriptionService) ListUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListActiveUserSubscriptions 获取用户的所有有效订阅
|
||||
func (s *SubscriptionService) ListActiveUserSubscriptions(ctx context.Context, userID int64) ([]model.UserSubscription, error) {
|
||||
return s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
return s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
// ListGroupSubscriptions 获取分组的所有订阅
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.ListByGroupID(ctx, groupID, params)
|
||||
func (s *SubscriptionService) ListGroupSubscriptions(ctx context.Context, groupID int64, page, pageSize int) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.ListByGroupID(ctx, groupID, params)
|
||||
}
|
||||
|
||||
// List 获取所有订阅(分页,支持筛选)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *repository.PaginationResult, error) {
|
||||
params := repository.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.repos.UserSubscription.List(ctx, params, userID, groupID, status)
|
||||
func (s *SubscriptionService) List(ctx context.Context, page, pageSize int, userID, groupID *int64, status string) ([]model.UserSubscription, *pagination.PaginationResult, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
return s.userSubRepo.List(ctx, params, userID, groupID, status)
|
||||
}
|
||||
|
||||
// CheckAndActivateWindow 检查并激活窗口(首次使用时)
|
||||
@@ -358,7 +361,7 @@ func (s *SubscriptionService) CheckAndActivateWindow(ctx context.Context, sub *m
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
return s.repos.UserSubscription.ActivateWindows(ctx, sub.ID, now)
|
||||
return s.userSubRepo.ActivateWindows(ctx, sub.ID, now)
|
||||
}
|
||||
|
||||
// CheckAndResetWindows 检查并重置过期的窗口
|
||||
@@ -367,7 +370,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
|
||||
|
||||
// 日窗口重置(24小时)
|
||||
if sub.NeedsDailyReset() {
|
||||
if err := s.repos.UserSubscription.ResetDailyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetDailyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.DailyWindowStart = &now
|
||||
@@ -376,7 +379,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
|
||||
|
||||
// 周窗口重置(7天)
|
||||
if sub.NeedsWeeklyReset() {
|
||||
if err := s.repos.UserSubscription.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetWeeklyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.WeeklyWindowStart = &now
|
||||
@@ -385,7 +388,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *mod
|
||||
|
||||
// 月窗口重置(30天)
|
||||
if sub.NeedsMonthlyReset() {
|
||||
if err := s.repos.UserSubscription.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
|
||||
if err := s.userSubRepo.ResetMonthlyUsage(ctx, sub.ID, now); err != nil {
|
||||
return err
|
||||
}
|
||||
sub.MonthlyWindowStart = &now
|
||||
@@ -411,7 +414,7 @@ func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *model.U
|
||||
|
||||
// RecordUsage 记录使用量到订阅
|
||||
func (s *SubscriptionService) RecordUsage(ctx context.Context, subscriptionID int64, costUSD float64) error {
|
||||
return s.repos.UserSubscription.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
return s.userSubRepo.IncrementUsage(ctx, subscriptionID, costUSD)
|
||||
}
|
||||
|
||||
// SubscriptionProgress 订阅进度
|
||||
@@ -438,14 +441,14 @@ type UsageWindowProgress struct {
|
||||
|
||||
// GetSubscriptionProgress 获取订阅使用进度
|
||||
func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subscriptionID int64) (*SubscriptionProgress, error) {
|
||||
sub, err := s.repos.UserSubscription.GetByID(ctx, subscriptionID)
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
group := sub.Group
|
||||
if group == nil {
|
||||
group, err = s.repos.Group.GetByID(ctx, sub.GroupID)
|
||||
group, err = s.groupRepo.GetByID(ctx, sub.GroupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -535,7 +538,7 @@ func (s *SubscriptionService) GetSubscriptionProgress(ctx context.Context, subsc
|
||||
|
||||
// GetUserSubscriptionsWithProgress 获取用户所有订阅及进度
|
||||
func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Context, userID int64) ([]SubscriptionProgress, error) {
|
||||
subs, err := s.repos.UserSubscription.ListActiveByUserID(ctx, userID)
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -554,7 +557,7 @@ func (s *SubscriptionService) GetUserSubscriptionsWithProgress(ctx context.Conte
|
||||
|
||||
// UpdateExpiredSubscriptions 更新过期订阅状态(定时任务调用)
|
||||
func (s *SubscriptionService) UpdateExpiredSubscriptions(ctx context.Context) (int64, error) {
|
||||
return s.repos.UserSubscription.BatchUpdateExpiredStatus(ctx)
|
||||
return s.userSubRepo.BatchUpdateExpiredStatus(ctx)
|
||||
}
|
||||
|
||||
// ValidateSubscription 验证订阅是否有效
|
||||
@@ -567,7 +570,7 @@ func (s *SubscriptionService) ValidateSubscription(ctx context.Context, sub *mod
|
||||
}
|
||||
if sub.IsExpired() {
|
||||
// 更新状态
|
||||
_ = s.repos.UserSubscription.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
_ = s.userSubRepo.UpdateStatus(ctx, sub.ID, model.SubscriptionStatusExpired)
|
||||
return ErrSubscriptionExpired
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -5,7 +5,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@@ -41,24 +42,24 @@ type CreateUsageLogRequest struct {
|
||||
|
||||
// UsageStats 使用统计
|
||||
type UsageStats struct {
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
TotalRequests int64 `json:"total_requests"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
TotalCacheTokens int64 `json:"total_cache_tokens"`
|
||||
TotalTokens int64 `json:"total_tokens"`
|
||||
TotalCost float64 `json:"total_cost"`
|
||||
TotalActualCost float64 `json:"total_actual_cost"`
|
||||
AverageDurationMs float64 `json:"average_duration_ms"`
|
||||
}
|
||||
|
||||
// UsageService 使用统计服务
|
||||
type UsageService struct {
|
||||
usageRepo *repository.UsageLogRepository
|
||||
userRepo *repository.UserRepository
|
||||
usageRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo *repository.UsageLogRepository, userRepo *repository.UserRepository) *UsageService {
|
||||
func NewUsageService(usageRepo ports.UsageLogRepository, userRepo ports.UserRepository) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
@@ -127,7 +128,7 @@ func (s *UsageService) GetByID(ctx context.Context, id int64) (*model.UsageLog,
|
||||
}
|
||||
|
||||
// ListByUser 获取用户的使用日志列表
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByUser(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -136,7 +137,7 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params repo
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -145,7 +146,7 @@ func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params
|
||||
}
|
||||
|
||||
// ListByAccount 获取账号的使用日志列表
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params repository.PaginationParams) ([]model.UsageLog, *repository.PaginationResult, error) {
|
||||
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]model.UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAccount(ctx, accountID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
@@ -233,15 +234,15 @@ func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int
|
||||
}
|
||||
|
||||
result = append(result, map[string]interface{}{
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
"date": date,
|
||||
"total_requests": stats.TotalRequests,
|
||||
"total_input_tokens": stats.TotalInputTokens,
|
||||
"total_output_tokens": stats.TotalOutputTokens,
|
||||
"total_cache_tokens": stats.TotalCacheTokens,
|
||||
"total_tokens": stats.TotalTokens,
|
||||
"total_cost": stats.TotalCost,
|
||||
"total_actual_cost": stats.TotalActualCost,
|
||||
"average_duration_ms": stats.AverageDurationMs,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -6,16 +6,17 @@ import (
|
||||
"fmt"
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/pkg/pagination"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrPasswordIncorrect = errors.New("current password is incorrect")
|
||||
ErrInsufficientPerms = errors.New("insufficient permissions")
|
||||
)
|
||||
|
||||
// UpdateProfileRequest 更新用户资料请求
|
||||
@@ -32,12 +33,12 @@ type ChangePasswordRequest struct {
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRepo ports.UserRepository
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(userRepo *repository.UserRepository, cfg *config.Config) *UserService {
|
||||
func NewUserService(userRepo ports.UserRepository, cfg *config.Config) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
cfg: cfg,
|
||||
@@ -133,7 +134,7 @@ func (s *UserService) GetByID(ctx context.Context, id int64) (*model.User, error
|
||||
}
|
||||
|
||||
// List 获取用户列表(管理员功能)
|
||||
func (s *UserService) List(ctx context.Context, params repository.PaginationParams) ([]model.User, *repository.PaginationResult, error) {
|
||||
func (s *UserService) List(ctx context.Context, params pagination.PaginationParams) ([]model.User, *pagination.PaginationResult, error) {
|
||||
users, pagination, err := s.userRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
|
||||
Reference in New Issue
Block a user