Merge branch 'dev-release'
This commit is contained in:
@@ -375,36 +375,19 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
// Use raw SQL for atomic increment to avoid race conditions
|
||||
// First get current value
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldQuotaUsed).
|
||||
Only(ctx)
|
||||
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||
Where(apikey.DeletedAtIsNil()).
|
||||
AddQuotaUsed(amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newValue := m.QuotaUsed + amount
|
||||
|
||||
// Update with new value
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetQuotaUsed(newValue).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected == 0 {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
return newValue, nil
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
|
||||
@@ -4,11 +4,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
// --- IncrementQuotaUsed ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||
user := s.mustCreateUser("incr-basic@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||
|
||||
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||
|
||||
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
user := s.mustCreateUser("incr-deleted@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
u, err := client.User.Create().
|
||||
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||
SetPasswordHash("hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||
Name: "Concurrent",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||
t.Cleanup(func() {
|
||||
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||
})
|
||||
|
||||
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||
const goroutines = 10
|
||||
const increment = 1.0
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, e := range errs {
|
||||
require.NoError(t, e, "goroutine %d failed", i)
|
||||
}
|
||||
|
||||
// 验证最终结果
|
||||
got, err := repo.GetByID(ctx, k.ID)
|
||||
require.NoError(t, err, "GetByID")
|
||||
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,15 @@ const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
func jitteredTTL() time.Duration {
|
||||
jitter := time.Duration(rand.Int63n(int64(2*billingCacheJitter))) - billingCacheJitter
|
||||
return billingCacheTTL + jitter
|
||||
}
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
@@ -82,14 +90,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -163,16 +172,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
pipe.Expire(ctx, key, jitteredTTL())
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
|
||||
@@ -183,7 +183,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -24,6 +24,22 @@ import (
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
"day": "YYYY-MM-DD",
|
||||
"week": "IYYY-IW",
|
||||
"month": "YYYY-MM",
|
||||
}
|
||||
|
||||
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
|
||||
func safeDateFormat(granularity string) string {
|
||||
if f, ok := dateFormatWhitelist[granularity]; ok {
|
||||
return f
|
||||
}
|
||||
return "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
@@ -564,7 +580,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -810,19 +826,19 @@ func resolveUsageStatsTimezone() string {
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -908,10 +924,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_keys AS (
|
||||
@@ -966,10 +979,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
|
||||
|
||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_users AS (
|
||||
@@ -1228,10 +1238,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
@@ -1369,13 +1376,22 @@ type UsageStats = usagestats.UsageStats
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||
result := make(map[int64]*BatchUserUsageStats)
|
||||
if len(userIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range userIDs {
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
@@ -1383,10 +1399,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
query := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1443,13 +1459,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
@@ -1457,10 +1482,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
query := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1)
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1516,10 +1541,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
|
||||
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
s.Require().NotNil(stats[user1.ID])
|
||||
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeDateFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
granularity string
|
||||
expected string
|
||||
}{
|
||||
// 合法值
|
||||
{"hour", "hour", "YYYY-MM-DD HH24:00"},
|
||||
{"day", "day", "YYYY-MM-DD"},
|
||||
{"week", "week", "IYYY-IW"},
|
||||
{"month", "month", "YYYY-MM"},
|
||||
|
||||
// 非法值回退到默认
|
||||
{"空字符串", "", "YYYY-MM-DD"},
|
||||
{"未知粒度 year", "year", "YYYY-MM-DD"},
|
||||
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
|
||||
|
||||
// 恶意字符串
|
||||
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
|
||||
{"带引号", "day'", "YYYY-MM-DD"},
|
||||
{"带括号", "day)", "YYYY-MM-DD"},
|
||||
{"Unicode", "日", "YYYY-MM-DD"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := safeDateFormat(tc.granularity)
|
||||
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user