fix(usage): 使用日志事务和幂等性修复
- UsageLogRepository.Create 返回 inserted 标志 - UsageService 使用事务保证原子性 - 避免重复扣费(幂等重试场景) - 更新依赖注入和测试
This commit is contained in:
@@ -62,7 +62,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig)
|
||||
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
|
||||
usageLogRepository := repository.NewUsageLogRepository(client, db)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository)
|
||||
usageService := service.NewUsageService(usageLogRepository, userRepository, client)
|
||||
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
|
||||
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
|
||||
billingCache := repository.NewBillingCache(redisClient)
|
||||
@@ -100,7 +100,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
|
||||
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
|
||||
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
|
||||
httpUpstream := repository.NewHTTPUpstream(configConfig)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream)
|
||||
antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
|
||||
accountTestService := service.NewAccountTestService(accountRepository, oAuthService, openAIOAuthService, geminiTokenProvider, antigravityGatewayService, httpUpstream)
|
||||
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
|
||||
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
|
||||
|
||||
@@ -61,9 +61,16 @@ func (r *usageLogRepository) getPerformanceStats(ctx context.Context, userID int
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) error {
|
||||
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) (bool, error) {
|
||||
if log == nil {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// 在事务上下文中,使用 tx 绑定的 ExecQuerier 执行原生 SQL,保证与其他更新同事务。
|
||||
// 无事务时回退到默认的 *sql.DB 执行器。
|
||||
sqlq := r.sql
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
sqlq = tx.Client()
|
||||
}
|
||||
|
||||
createdAt := log.CreatedAt
|
||||
@@ -152,18 +159,20 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
firstToken,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) && requestID != "" {
|
||||
selectQuery := "SELECT id, created_at FROM usage_logs WHERE request_id = $1 AND api_key_id = $2"
|
||||
if err := scanSingleRow(ctx, r.sql, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return err
|
||||
if err := scanSingleRow(ctx, sqlq, selectQuery, []any{requestID, log.ApiKeyID}, &log.ID, &log.CreatedAt); err != nil {
|
||||
return false, err
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return false, nil
|
||||
} else {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
log.RateMultiplier = rateMultiplier
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
|
||||
|
||||
@@ -380,7 +380,7 @@ func newContractDeps(t *testing.T) *contractDeps {
|
||||
apiKeyService := service.NewApiKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg)
|
||||
|
||||
usageRepo := newStubUsageLogRepo()
|
||||
usageService := service.NewUsageService(usageRepo, userRepo)
|
||||
usageService := service.NewUsageService(usageRepo, userRepo, client)
|
||||
|
||||
settingRepo := newStubSettingRepo()
|
||||
settingService := service.NewSettingService(settingRepo, cfg)
|
||||
|
||||
@@ -12,7 +12,9 @@ import (
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *UsageLog) error
|
||||
// Create creates a usage log and returns whether it was actually inserted.
|
||||
// inserted is false when the insert was skipped due to conflict (idempotent retries).
|
||||
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
|
||||
GetByID(ctx context.Context, id int64) (*UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
|
||||
@@ -1026,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
@@ -2,9 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
@@ -54,20 +56,34 @@ type UsageStats struct {
|
||||
type UsageService struct {
|
||||
usageRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService {
|
||||
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建使用日志
|
||||
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
|
||||
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer tx.Rollback()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
}
|
||||
|
||||
// 验证用户存在
|
||||
_, err := s.userRepo.GetByID(ctx, req.UserID)
|
||||
_, err = s.userRepo.GetByID(txCtx, req.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
DurationMs: req.DurationMs,
|
||||
}
|
||||
|
||||
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
|
||||
inserted, err := s.usageRepo.Create(txCtx, usageLog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create usage log: %w", err)
|
||||
}
|
||||
|
||||
// 扣除用户余额
|
||||
if req.ActualCost > 0 {
|
||||
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
|
||||
if inserted && req.ActualCost > 0 {
|
||||
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return usageLog, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user