refactor(backend): 引入端口接口模式

This commit is contained in:
Forest
2025-12-19 21:26:19 +08:00
parent 7fd94ab78b
commit e99b344b2b
45 changed files with 627 additions and 323 deletions

View File

@@ -21,7 +21,7 @@ import (
"sub2api/internal/config"
"sub2api/internal/model"
"sub2api/internal/pkg/claude"
"sub2api/internal/repository"
"sub2api/internal/service/ports"
"github.com/gin-gonic/gin"
"github.com/redis/go-redis/v9"
@@ -78,7 +78,10 @@ type ForwardResult struct {
// GatewayService handles API gateway operations
type GatewayService struct {
repos *repository.Repositories
accountRepo ports.AccountRepository
usageLogRepo ports.UsageLogRepository
userRepo ports.UserRepository
userSubRepo ports.UserSubscriptionRepository
rdb *redis.Client
cfg *config.Config
oauthService *OAuthService
@@ -90,7 +93,19 @@ type GatewayService struct {
}
// NewGatewayService creates a new GatewayService
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
func NewGatewayService(
accountRepo ports.AccountRepository,
usageLogRepo ports.UsageLogRepository,
userRepo ports.UserRepository,
userSubRepo ports.UserSubscriptionRepository,
rdb *redis.Client,
cfg *config.Config,
oauthService *OAuthService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
identityService *IdentityService,
) *GatewayService {
// 计算响应头超时时间
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
if responseHeaderTimeout == 0 {
@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c
// 注意:不设置整体 Timeout让流式响应可以无限时间传输
}
return &GatewayService{
repos: repos,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
rdb: rdb,
cfg: cfg,
oauthService: oauthService,
@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
if sessionHash != "" {
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
if err == nil && accountID > 0 {
account, err := s.repos.Account.GetByID(ctx, accountID)
account, err := s.accountRepo.GetByID(ctx, accountID)
// 使用IsSchedulable代替IsActive确保限流/过载账号不会被选中
// 同时检查模型支持
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
var accounts []model.Account
var err error
if groupID != nil {
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
} else {
accounts, err = s.repos.Account.ListSchedulable(ctx)
accounts, err = s.accountRepo.ListSchedulable(ctx)
}
if err != nil {
return nil, fmt.Errorf("query accounts failed: %w", err)
@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
}
if err := s.repos.Account.Update(ctx, account); err != nil {
if err := s.accountRepo.Update(ctx, account); err != nil {
log.Printf("Failed to update account credentials: %v", err)
}
@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
// 更新账号最后使用时间
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
log.Printf("Update last used failed: %v", err)
}