refactor(backend): 引入端口接口模式
This commit is contained in:
@@ -21,7 +21,7 @@ import (
|
||||
"sub2api/internal/config"
|
||||
"sub2api/internal/model"
|
||||
"sub2api/internal/pkg/claude"
|
||||
"sub2api/internal/repository"
|
||||
"sub2api/internal/service/ports"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/redis/go-redis/v9"
|
||||
@@ -78,7 +78,10 @@ type ForwardResult struct {
|
||||
|
||||
// GatewayService handles API gateway operations
|
||||
type GatewayService struct {
|
||||
repos *repository.Repositories
|
||||
accountRepo ports.AccountRepository
|
||||
usageLogRepo ports.UsageLogRepository
|
||||
userRepo ports.UserRepository
|
||||
userSubRepo ports.UserSubscriptionRepository
|
||||
rdb *redis.Client
|
||||
cfg *config.Config
|
||||
oauthService *OAuthService
|
||||
@@ -90,7 +93,19 @@ type GatewayService struct {
|
||||
}
|
||||
|
||||
// NewGatewayService creates a new GatewayService
|
||||
func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *config.Config, oauthService *OAuthService, billingService *BillingService, rateLimitService *RateLimitService, billingCacheService *BillingCacheService, identityService *IdentityService) *GatewayService {
|
||||
func NewGatewayService(
|
||||
accountRepo ports.AccountRepository,
|
||||
usageLogRepo ports.UsageLogRepository,
|
||||
userRepo ports.UserRepository,
|
||||
userSubRepo ports.UserSubscriptionRepository,
|
||||
rdb *redis.Client,
|
||||
cfg *config.Config,
|
||||
oauthService *OAuthService,
|
||||
billingService *BillingService,
|
||||
rateLimitService *RateLimitService,
|
||||
billingCacheService *BillingCacheService,
|
||||
identityService *IdentityService,
|
||||
) *GatewayService {
|
||||
// 计算响应头超时时间
|
||||
responseHeaderTimeout := time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
if responseHeaderTimeout == 0 {
|
||||
@@ -105,7 +120,10 @@ func NewGatewayService(repos *repository.Repositories, rdb *redis.Client, cfg *c
|
||||
// 注意:不设置整体 Timeout,让流式响应可以无限时间传输
|
||||
}
|
||||
return &GatewayService{
|
||||
repos: repos,
|
||||
accountRepo: accountRepo,
|
||||
usageLogRepo: usageLogRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
rdb: rdb,
|
||||
cfg: cfg,
|
||||
oauthService: oauthService,
|
||||
@@ -274,7 +292,7 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
if sessionHash != "" {
|
||||
accountID, err := s.rdb.Get(ctx, stickySessionPrefix+sessionHash).Int64()
|
||||
if err == nil && accountID > 0 {
|
||||
account, err := s.repos.Account.GetByID(ctx, accountID)
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
// 使用IsSchedulable代替IsActive,确保限流/过载账号不会被选中
|
||||
// 同时检查模型支持
|
||||
if err == nil && account.IsSchedulable() && (requestedModel == "" || account.IsModelSupported(requestedModel)) {
|
||||
@@ -289,9 +307,9 @@ func (s *GatewayService) SelectAccountForModel(ctx context.Context, groupID *int
|
||||
var accounts []model.Account
|
||||
var err error
|
||||
if groupID != nil {
|
||||
accounts, err = s.repos.Account.ListSchedulableByGroupID(ctx, *groupID)
|
||||
accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
|
||||
} else {
|
||||
accounts, err = s.repos.Account.ListSchedulable(ctx)
|
||||
accounts, err = s.accountRepo.ListSchedulable(ctx)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
@@ -378,7 +396,7 @@ func (s *GatewayService) getOAuthToken(ctx context.Context, account *model.Accou
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
@@ -667,7 +685,7 @@ func (s *GatewayService) forceRefreshToken(ctx context.Context, account *model.A
|
||||
account.Credentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
|
||||
if err := s.repos.Account.Update(ctx, account); err != nil {
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
log.Printf("Failed to update account credentials: %v", err)
|
||||
}
|
||||
|
||||
@@ -999,7 +1017,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.repos.UsageLog.Create(ctx, usageLog); err != nil {
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1007,7 +1025,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if err := s.repos.UserSubscription.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
// 异步更新订阅缓存
|
||||
@@ -1022,7 +1040,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if err := s.repos.User.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
// 异步更新余额缓存
|
||||
@@ -1037,7 +1055,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
|
||||
// 更新账号最后使用时间
|
||||
if err := s.repos.Account.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, account.ID); err != nil {
|
||||
log.Printf("Update last used failed: %v", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user