Files
sub2api/backend/internal/service/usage_service.go
IanShaw027 a4953785d9 fix(lint): 修复所有 Go 命名规范问题
- 全局替换 ApiKey → APIKey(类型、字段、方法、变量)
- 修复所有 initialism 命名(API, SMTP, HTML, URL 等)
- 添加所有缺失的包注释
- 修复导出符号的注释格式

主要修改:
- ApiKey → APIKey(所有出现的地方)
- ApiKeyID → APIKeyID
- ApiKeyIDs → APIKeyIDs
- TestSmtpConnection → TestSMTPConnection
- HtmlURL → HTMLURL
- 添加 20+ 个包注释
- 修复 10+ 个导出符号注释格式

验证结果:
- ✓ golangci-lint: 0 issues
- ✓ 单元测试: 通过
- ✓ 集成测试: 通过
2026-01-04 19:28:20 +08:00

322 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"context"
"errors"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
// ErrUsageLogNotFound is the not-found error value for usage logs. It is built
// with infraerrors.NotFound using the code "USAGE_LOG_NOT_FOUND".
var ErrUsageLogNotFound = infraerrors.NotFound("USAGE_LOG_NOT_FOUND", "usage log not found")
// CreateUsageLogRequest is the input for UsageService.Create. Its fields are
// copied one-to-one into the UsageLog that gets persisted.
type CreateUsageLogRequest struct {
	// Identifiers the usage record is attributed to.
	UserID    int64  `json:"user_id"`
	APIKeyID  int64  `json:"api_key_id"`
	AccountID int64  `json:"account_id"`
	RequestID string `json:"request_id"`
	Model     string `json:"model"`

	// Token counters.
	InputTokens           int `json:"input_tokens"`
	OutputTokens          int `json:"output_tokens"`
	CacheCreationTokens   int `json:"cache_creation_tokens"`
	CacheReadTokens       int `json:"cache_read_tokens"`
	CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
	CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`

	// Cost fields. ActualCost is the amount UsageService.Create deducts from
	// the user's balance when it is greater than zero.
	InputCost         float64 `json:"input_cost"`
	OutputCost        float64 `json:"output_cost"`
	CacheCreationCost float64 `json:"cache_creation_cost"`
	CacheReadCost     float64 `json:"cache_read_cost"`
	TotalCost         float64 `json:"total_cost"`
	ActualCost        float64 `json:"actual_cost"`
	RateMultiplier    float64 `json:"rate_multiplier"`

	// Request metadata. DurationMs is a pointer and may be nil.
	Stream     bool `json:"stream"`
	DurationMs *int `json:"duration_ms"`
}
// UsageStats is the aggregated usage summary returned by the GetStatsBy*
// methods of UsageService.
type UsageStats struct {
	// Request and token totals.
	TotalRequests     int64 `json:"total_requests"`
	TotalInputTokens  int64 `json:"total_input_tokens"`
	TotalOutputTokens int64 `json:"total_output_tokens"`
	TotalCacheTokens  int64 `json:"total_cache_tokens"`
	TotalTokens       int64 `json:"total_tokens"`

	// Cost totals.
	TotalCost       float64 `json:"total_cost"`
	TotalActualCost float64 `json:"total_actual_cost"`

	// Average duration; milliseconds, going by the field name.
	AverageDurationMs float64 `json:"average_duration_ms"`
}
// UsageService provides usage-log and usage-statistics operations on top of
// the usage log and user repositories.
type UsageService struct {
	// usageRepo stores and queries usage logs.
	usageRepo UsageLogRepository
	// userRepo is used by Create to look up the user and update the balance.
	userRepo UserRepository
	// entClient is used by Create to open database transactions.
	entClient *dbent.Client
}
// NewUsageService returns a UsageService wired to the given repositories and
// ent client. The arguments are stored as-is; nothing is validated here.
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
	svc := new(UsageService)
	svc.usageRepo = usageRepo
	svc.userRepo = userRepo
	svc.entClient = entClient
	return svc
}
// Create records a usage log for req and deducts req.ActualCost from the
// user's balance.
//
// The log insert and the balance deduction run inside one database
// transaction so they succeed or fail together, avoiding double charging or a
// missed charge. If starting a transaction reports dbent.ErrTxStarted, no new
// transaction is opened and all repository calls use ctx as given
// (presumably the caller already owns a transaction; confirm against callers).
//
// The returned UsageLog is the value built from req and is returned even when
// the repository reports that nothing was inserted. Errors from each step are
// wrapped with a short prefix naming the step.
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
	tx, txErr := s.entClient.Tx(ctx)

	opCtx := ctx
	switch {
	case txErr == nil:
		// We own the transaction: roll back on every exit path. The rollback
		// error is deliberately ignored (it also runs after a successful commit).
		defer func() { _ = tx.Rollback() }()
		opCtx = dbent.NewTxContext(ctx, tx)
	case !errors.Is(txErr, dbent.ErrTxStarted):
		return nil, fmt.Errorf("begin transaction: %w", txErr)
	}

	// Make sure the user exists before recording anything.
	if _, err := s.userRepo.GetByID(opCtx, req.UserID); err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Build the usage log as a field-by-field copy of the request.
	entry := &UsageLog{
		UserID:                req.UserID,
		APIKeyID:              req.APIKeyID,
		AccountID:             req.AccountID,
		RequestID:             req.RequestID,
		Model:                 req.Model,
		InputTokens:           req.InputTokens,
		OutputTokens:          req.OutputTokens,
		CacheCreationTokens:   req.CacheCreationTokens,
		CacheReadTokens:       req.CacheReadTokens,
		CacheCreation5mTokens: req.CacheCreation5mTokens,
		CacheCreation1hTokens: req.CacheCreation1hTokens,
		InputCost:             req.InputCost,
		OutputCost:            req.OutputCost,
		CacheCreationCost:     req.CacheCreationCost,
		CacheReadCost:         req.CacheReadCost,
		TotalCost:             req.TotalCost,
		ActualCost:            req.ActualCost,
		RateMultiplier:        req.RateMultiplier,
		Stream:                req.Stream,
		DurationMs:            req.DurationMs,
	}

	created, err := s.usageRepo.Create(opCtx, entry)
	if err != nil {
		return nil, fmt.Errorf("create usage log: %w", err)
	}

	// Charge the user only when the repository reports an insert and there is
	// a positive cost to deduct.
	if created && req.ActualCost > 0 {
		if err := s.userRepo.UpdateBalance(opCtx, req.UserID, -req.ActualCost); err != nil {
			return nil, fmt.Errorf("update user balance: %w", err)
		}
	}

	// Commit only when a transaction handle was obtained above.
	if tx != nil {
		if err := tx.Commit(); err != nil {
			return nil, fmt.Errorf("commit transaction: %w", err)
		}
	}

	return entry, nil
}
// GetByID returns the usage log with the given id. A repository error is
// wrapped with the "get usage log" prefix.
func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) {
	entry, err := s.usageRepo.GetByID(ctx, id)
	if err != nil {
		return nil, fmt.Errorf("get usage log: %w", err)
	}

	return entry, nil
}
// ListByUser returns one page of usage logs for the given user together with
// the pagination result reported by the repository. On failure both results
// are nil and the error is wrapped with the "list usage logs" prefix.
func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
	entries, page, err := s.usageRepo.ListByUser(ctx, userID, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list usage logs: %w", err)
	}

	return entries, page, nil
}
// ListByAPIKey returns one page of usage logs for the given API key together
// with the pagination result reported by the repository. On failure both
// results are nil and the error is wrapped with the "list usage logs" prefix.
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
	entries, page, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list usage logs: %w", err)
	}

	return entries, page, nil
}
// ListByAccount returns one page of usage logs for the given account together
// with the pagination result reported by the repository. On failure both
// results are nil and the error is wrapped with the "list usage logs" prefix.
func (s *UsageService) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
	entries, page, err := s.usageRepo.ListByAccount(ctx, accountID, params)
	if err != nil {
		return nil, nil, fmt.Errorf("list usage logs: %w", err)
	}

	return entries, page, nil
}
// GetStatsByUser returns aggregated usage statistics for a user over the
// range startTime..endTime. The repository aggregate is copied field by field
// into a UsageStats; a repository error is wrapped with "get user stats".
func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTime, endTime time.Time) (*UsageStats, error) {
	agg, err := s.usageRepo.GetUserStatsAggregated(ctx, userID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get user stats: %w", err)
	}

	summary := UsageStats{
		TotalRequests:     agg.TotalRequests,
		TotalInputTokens:  agg.TotalInputTokens,
		TotalOutputTokens: agg.TotalOutputTokens,
		TotalCacheTokens:  agg.TotalCacheTokens,
		TotalTokens:       agg.TotalTokens,
		TotalCost:         agg.TotalCost,
		TotalActualCost:   agg.TotalActualCost,
		AverageDurationMs: agg.AverageDurationMs,
	}
	return &summary, nil
}
// GetStatsByAPIKey returns aggregated usage statistics for an API key over
// the range startTime..endTime. The repository aggregate is copied field by
// field into a UsageStats; a repository error is wrapped with
// "get api key stats".
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
	agg, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get api key stats: %w", err)
	}

	summary := UsageStats{
		TotalRequests:     agg.TotalRequests,
		TotalInputTokens:  agg.TotalInputTokens,
		TotalOutputTokens: agg.TotalOutputTokens,
		TotalCacheTokens:  agg.TotalCacheTokens,
		TotalTokens:       agg.TotalTokens,
		TotalCost:         agg.TotalCost,
		TotalActualCost:   agg.TotalActualCost,
		AverageDurationMs: agg.AverageDurationMs,
	}
	return &summary, nil
}
// GetStatsByAccount returns aggregated usage statistics for an account over
// the range startTime..endTime. The repository aggregate is copied field by
// field into a UsageStats; a repository error is wrapped with
// "get account stats".
func (s *UsageService) GetStatsByAccount(ctx context.Context, accountID int64, startTime, endTime time.Time) (*UsageStats, error) {
	agg, err := s.usageRepo.GetAccountStatsAggregated(ctx, accountID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get account stats: %w", err)
	}

	summary := UsageStats{
		TotalRequests:     agg.TotalRequests,
		TotalInputTokens:  agg.TotalInputTokens,
		TotalOutputTokens: agg.TotalOutputTokens,
		TotalCacheTokens:  agg.TotalCacheTokens,
		TotalTokens:       agg.TotalTokens,
		TotalCost:         agg.TotalCost,
		TotalActualCost:   agg.TotalActualCost,
		AverageDurationMs: agg.AverageDurationMs,
	}
	return &summary, nil
}
// GetStatsByModel returns aggregated usage statistics for a model name over
// the range startTime..endTime. The repository aggregate is copied field by
// field into a UsageStats; a repository error is wrapped with
// "get model stats".
func (s *UsageService) GetStatsByModel(ctx context.Context, modelName string, startTime, endTime time.Time) (*UsageStats, error) {
	agg, err := s.usageRepo.GetModelStatsAggregated(ctx, modelName, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get model stats: %w", err)
	}

	summary := UsageStats{
		TotalRequests:     agg.TotalRequests,
		TotalInputTokens:  agg.TotalInputTokens,
		TotalOutputTokens: agg.TotalOutputTokens,
		TotalCacheTokens:  agg.TotalCacheTokens,
		TotalTokens:       agg.TotalTokens,
		TotalCost:         agg.TotalCost,
		TotalActualCost:   agg.TotalActualCost,
		AverageDurationMs: agg.AverageDurationMs,
	}
	return &summary, nil
}
// GetDailyStats returns per-day usage statistics for a user over the most
// recent `days` days. The window ends at time.Now() and starts `days` calendar
// days earlier; days is not validated here. A repository error is wrapped
// with the "get daily stats" prefix.
func (s *UsageService) GetDailyStats(ctx context.Context, userID int64, days int) ([]map[string]any, error) {
	until := time.Now()
	since := until.AddDate(0, 0, -days)

	daily, err := s.usageRepo.GetDailyStatsAggregated(ctx, userID, since, until)
	if err != nil {
		return nil, fmt.Errorf("get daily stats: %w", err)
	}

	return daily, nil
}
// Delete removes the usage log with the given id (an administrator operation;
// use with care). A repository error is wrapped with the "delete usage log"
// prefix.
func (s *UsageService) Delete(ctx context.Context, id int64) error {
	err := s.usageRepo.Delete(ctx, id)
	if err != nil {
		return fmt.Errorf("delete usage log: %w", err)
	}

	return nil
}
// GetUserDashboardStats returns the dashboard summary statistics for a single
// user, as produced by the usage repository. A repository error is wrapped
// with the "get user dashboard stats" prefix.
func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) {
	dashboard, err := s.usageRepo.GetUserDashboardStats(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("get user dashboard stats: %w", err)
	}

	return dashboard, nil
}
// GetUserUsageTrendByUserID returns the usage trend data points for a single
// user over startTime..endTime. granularity is passed through to the
// repository unchanged. A repository error is wrapped with the
// "get user usage trend" prefix.
func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) {
	points, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity)
	if err != nil {
		return nil, fmt.Errorf("get user usage trend: %w", err)
	}

	return points, nil
}
// GetUserModelStats returns per-model usage statistics for a single user over
// startTime..endTime. A repository error is wrapped with the
// "get user model stats" prefix.
func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) {
	perModel, err := s.usageRepo.GetUserModelStats(ctx, userID, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get user model stats: %w", err)
	}

	return perModel, nil
}
// GetBatchAPIKeyUsageStats returns today/total actual_cost statistics for the
// given API keys as a map from API key ID to its stats. A repository error is
// wrapped with the "get batch api key usage stats" prefix.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
	byKey, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
	if err != nil {
		return nil, fmt.Errorf("get batch api key usage stats: %w", err)
	}

	return byKey, nil
}
// ListWithFilters returns one page of usage logs matching the admin filters,
// together with the pagination result reported by the repository. On failure
// both results are nil and the error is wrapped with the
// "list usage logs with filters" prefix.
func (s *UsageService) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]UsageLog, *pagination.PaginationResult, error) {
	entries, page, err := s.usageRepo.ListWithFilters(ctx, params, filters)
	if err != nil {
		return nil, nil, fmt.Errorf("list usage logs with filters: %w", err)
	}

	return entries, page, nil
}
// GetGlobalStats returns global usage statistics for startTime..endTime, as
// produced by the usage repository. A repository error is wrapped with the
// "get global usage stats" prefix.
func (s *UsageService) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*usagestats.UsageStats, error) {
	global, err := s.usageRepo.GetGlobalStats(ctx, startTime, endTime)
	if err != nil {
		return nil, fmt.Errorf("get global usage stats: %w", err)
	}

	return global, nil
}