diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index cdf63bd3..75e84e13 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository) + usageService := service.NewUsageService(usageLogRepository, userRepository, client) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) billingCache := repository.NewBillingCache(redisClient) @@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository) antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) httpUpstream := repository.NewHTTPUpstream(configConfig) - antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream) + antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 51abdbb0..4056431d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -61,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int return requestCount / 5, tokenCount / 5, nil } -func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error { +func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) { if log == nil { - return nil + return false, nil + } + + // 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。 + // 无事务时回退到默认的 *sql.DB 执行器。 + sqlq := r.sql + if tx := dbent.TxFromContext(ctx); tx != nil { + sqlq = tx.Client() } createdAt := log.CreatedAt @@ -152,18 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) firstToken, createdAt, } - if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil { + if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) && requestID != "" { selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2" - if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil { - return err + if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil { + return false, err } + log.RateMultiplier = rateMultiplier + return false, nil } else { - return err + return false, err } } log.RateMultiplier = rateMultiplier - return nil + return true, nil } func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index b62ed618..1cf05dc7 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -380,7 +380,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() - usageService := service.NewUsageService(usageRepo, userRepo) + usageService := service.NewUsageService(usageRepo, userRepo, client) settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index c4220c0c..eb992cd4 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -12,7 +12,9 @@ import ( ) type UsageLogRepository interface { - Create(ctx context.Context, log *UsageLog) error + // Create creates a usage log and returns whether it was actually inserted. + // inserted is false when the insert was skipped due to conflict (idempotent retries). + Create(ctx context.Context, log *UsageLog) (inserted bool, err error) GetByID(ctx context.Context, id int64) (*UsageLog, error) Delete(ctx context.Context, id int64) error diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6da587ad..63166cd2 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1026,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } - _ = s.usageLogRepo.Create(ctx, usageLog) - + inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } + shouldBill := inserted || err != nil + // Deduct based on billing type if isSubscriptionBilling { - if cost.TotalCost > 0 { + if shouldBill && cost.TotalCost > 0 { _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } else { - if cost.ActualCost > 0 { + if shouldBill && cost.ActualCost > 0 { _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) } diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index e1e97671..4a0f6e56 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -2,9 +2,11 @@ package service import ( "context" + "errors" "fmt" "time" + dbent "github.com/Wei-Shaw/sub2api/ent" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" @@ -54,20 +56,34 @@ type UsageStats struct { type UsageService struct { usageRepo UsageLogRepository userRepo UserRepository + entClient *dbent.Client } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService { +func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { return &UsageService{ usageRepo: usageRepo, userRepo: userRepo, + entClient: entClient, } } // Create 创建使用日志 func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) { + // 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。 + tx, err := s.entClient.Tx(ctx) + if err != nil && !errors.Is(err, dbent.ErrTxStarted) { + return nil, fmt.Errorf("begin transaction: %w", err) + } + + txCtx := ctx + if err == nil { + defer tx.Rollback() + txCtx = dbent.NewTxContext(ctx, tx) + } + // 验证用户存在 - _, err := s.userRepo.GetByID(ctx, req.UserID) + _, err = s.userRepo.GetByID(txCtx, req.UserID) if err != nil { return nil, fmt.Errorf("get user: %w", err) } @@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* DurationMs: req.DurationMs, } - if err := s.usageRepo.Create(ctx, usageLog); err != nil { + inserted, err := s.usageRepo.Create(txCtx, usageLog) + if err != nil { return nil, fmt.Errorf("create usage log: %w", err) } // 扣除用户余额 - if req.ActualCost > 0 { - if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil { + if inserted && req.ActualCost > 0 { + if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } } + if tx != nil { + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("commit transaction: %w", err) + } + } + return usageLog, nil }