fix(usage): 使用日志事务和幂等性修复

- UsageLogRepository.Create 返回 inserted 标志
- UsageService 使用事务保证原子性
- 避免重复扣费(幂等重试场景)
- 更新依赖注入和测试
This commit is contained in:
ianshaw
2026-01-03 17:10:32 -08:00
parent 7eda43c99e
commit 71bf5b9e77
6 changed files with 55 additions and 20 deletions

View File

@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
usageService := service.NewUsageService(usageLogRepository, userRepository)
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
billingCache := repository.NewBillingCache(redisClient)
@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
httpUpstream := repository.NewHTTPUpstream(configConfig)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)

View File

@@ -61,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
return requestCount / 5, tokenCount / 5, nil
}
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
if log == nil {
return nil
return false, nil
}
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL保证与其他更新同事务。
// 无事务时回退到默认的 *sql.DB 执行器。
sqlq := r.sql
if tx := dbent.TxFromContext(ctx); tx != nil {
sqlq = tx.Client()
}
createdAt := log.CreatedAt
@@ -152,18 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
firstToken,
createdAt,
}
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
return err
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
return false, err
}
log.RateMultiplier = rateMultiplier
return false, nil
} else {
return err
return false, err
}
}
log.RateMultiplier = rateMultiplier
return nil
return true, nil
}
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {

View File

@@ -380,7 +380,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
usageRepo := newStubUsageLogRepo()
usageService := service.NewUsageService(usageRepo, userRepo)
usageService := service.NewUsageService(usageRepo, userRepo, client)
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)

View File

@@ -12,7 +12,9 @@ import (
)
type UsageLogRepository interface {
Create(ctx context.Context, log *UsageLog) error
// Create creates a usage log and returns whether it was actually inserted.
// inserted is false when the insert was skipped due to conflict (idempotent retries).
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error

View File

@@ -1026,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
_ = s.usageLogRepo.Create(ctx, usageLog)
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}

View File

@@ -2,9 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
@@ -54,20 +56,34 @@ type UsageStats struct {
type UsageService struct {
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, fmt.Errorf("begin transaction: %w", err)
}
txCtx := ctx
if err == nil {
defer tx.Rollback()
txCtx = dbent.NewTxContext(ctx, tx)
}
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
_, err = s.userRepo.GetByID(txCtx, req.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageRepo.Create(txCtx, usageLog)
if err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
return usageLog, nil
}