Merge upstream/main: v0.1.85-v0.1.86 updates
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -15,7 +15,6 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -25,6 +24,7 @@ import (
|
||||
dbgroup "github.com/Wei-Shaw/sub2api/ent/group"
|
||||
dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate"
|
||||
dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
account.CreatedAt = created.CreatedAt
|
||||
account.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
}
|
||||
account.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, account.ID)
|
||||
@@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -533,7 +533,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
},
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -568,7 +568,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
||||
}
|
||||
payload := map[string]any{"last_used": lastUsedPayload}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -583,7 +583,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -603,11 +603,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
|
||||
}
|
||||
account, err := r.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err)
|
||||
return
|
||||
}
|
||||
if err := r.schedulerCache.SetAccount(ctx, account); err != nil {
|
||||
log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -631,7 +631,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -648,7 +648,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -721,7 +721,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -829,7 +829,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -876,7 +876,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -890,7 +890,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -909,7 +909,7 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
return nil
|
||||
@@ -928,7 +928,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -944,7 +944,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -968,7 +968,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -992,7 +992,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64)
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1014,7 +1014,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s
|
||||
// 触发调度器缓存更新(仅当窗口时间有变化时)
|
||||
if start != nil || end != nil {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -1029,7 +1029,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
if !schedulable {
|
||||
r.syncSchedulerAccountSnapshot(ctx, id)
|
||||
@@ -1057,7 +1057,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
||||
}
|
||||
if rows > 0 {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
@@ -1093,7 +1093,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1187,7 +1187,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if rows > 0 {
|
||||
payload := map[string]any{"account_ids": ids}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
shouldSync := false
|
||||
if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) {
|
||||
@@ -1560,3 +1560,64 @@ func joinClauses(clauses []string, sep string) string {
|
||||
func itoa(v int) string {
|
||||
return strconv.Itoa(v)
|
||||
}
|
||||
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号。
|
||||
// 该方法限定 platform='sora',避免误查询其他平台的账号。
|
||||
// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。
|
||||
//
|
||||
// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。
|
||||
//
|
||||
// FindByExtraField finds accounts by key-value pairs in the extra field.
|
||||
// Limited to platform='sora' to avoid querying accounts from other platforms.
|
||||
// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index).
|
||||
//
|
||||
// Use case: Finding Sora accounts linked via linked_openai_account_id.
|
||||
func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) {
|
||||
accounts, err := r.client.Account.Query().
|
||||
Where(
|
||||
dbaccount.PlatformEQ("sora"), // 限定平台为 sora
|
||||
dbaccount.DeletedAtIsNil(),
|
||||
func(s *entsql.Selector) {
|
||||
path := sqljson.Path(key)
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)}
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path))
|
||||
}
|
||||
if len(preds) == 1 {
|
||||
s.Where(preds[0])
|
||||
} else {
|
||||
s.Where(entsql.Or(preds...))
|
||||
}
|
||||
case int:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path),
|
||||
))
|
||||
case int64:
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path),
|
||||
))
|
||||
case json.Number:
|
||||
if parsed, err := v.Int64(); err == nil {
|
||||
s.Where(entsql.Or(
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path),
|
||||
sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path),
|
||||
))
|
||||
} else {
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path))
|
||||
}
|
||||
default:
|
||||
s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path))
|
||||
}
|
||||
},
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
|
||||
return r.accountsToService(ctx, accounts)
|
||||
}
|
||||
|
||||
@@ -34,6 +34,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
SetNillableLastUsedAt(key.LastUsedAt).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt)
|
||||
@@ -48,6 +49,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.LastUsedAt = created.LastUsedAt
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
@@ -140,6 +142,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldSoraImagePrice360,
|
||||
group.FieldSoraImagePrice540,
|
||||
group.FieldSoraVideoPricePerRequest,
|
||||
group.FieldSoraVideoPricePerRequestHd,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
@@ -375,36 +381,34 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
// Use raw SQL for atomic increment to avoid race conditions
|
||||
// First get current value
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldQuotaUsed).
|
||||
Only(ctx)
|
||||
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||
Where(apikey.DeletedAtIsNil()).
|
||||
AddQuotaUsed(amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
newValue := m.QuotaUsed + amount
|
||||
|
||||
// Update with new value
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetQuotaUsed(newValue).
|
||||
SetLastUsedAt(usedAt).
|
||||
SetUpdatedAt(usedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
return newValue, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
@@ -419,6 +423,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
@@ -477,6 +482,10 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
|
||||
@@ -4,11 +4,14 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
@@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
// --- IncrementQuotaUsed ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||
user := s.mustCreateUser("incr-basic@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||
|
||||
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||
|
||||
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
user := s.mustCreateUser("incr-deleted@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
u, err := client.User.Create().
|
||||
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||
SetPasswordHash("hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||
Name: "Concurrent",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||
t.Cleanup(func() {
|
||||
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||
})
|
||||
|
||||
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||
const goroutines = 10
|
||||
const increment = 1.0
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, e := range errs {
|
||||
require.NoError(t, e, "goroutine %d failed", i)
|
||||
}
|
||||
|
||||
// 验证最终结果
|
||||
got, err := repo.GetByID(ctx, k.ID)
|
||||
require.NoError(t, err, "GetByID")
|
||||
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||
}
|
||||
|
||||
156
backend/internal/repository/api_key_repo_last_used_unit_test.go
Normal file
156
backend/internal/repository/api_key_repo_last_used_unit_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
return &apiKeyRepository{client: client}, client
|
||||
}
|
||||
|
||||
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
|
||||
t.Helper()
|
||||
u, err := client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
|
||||
|
||||
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-last-used",
|
||||
Name: "CreateWithLastUsed",
|
||||
Status: service.StatusActive,
|
||||
LastUsedAt: &lastUsed,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NotNil(t, key.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
|
||||
|
||||
got, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used",
|
||||
Name: "UpdateLastUsed",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
before, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, before.LastUsedAt)
|
||||
|
||||
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
|
||||
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
|
||||
|
||||
after, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, after.LastUsedAt)
|
||||
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
|
||||
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-deleted",
|
||||
Name: "UpdateLastUsedDeleted",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NoError(t, repo.Delete(ctx, key.ID))
|
||||
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-db-error",
|
||||
Name: "UpdateLastUsedDBError",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
require.NoError(t, client.Close())
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
|
||||
|
||||
first := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "first",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
second := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "second",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, first))
|
||||
err := repo.Create(ctx, second)
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyExists)
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -16,8 +17,19 @@ const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
func jitteredTTL() time.Duration {
|
||||
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
|
||||
if billingCacheJitter <= 0 {
|
||||
return billingCacheTTL
|
||||
}
|
||||
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
|
||||
return billingCacheTTL - jitter
|
||||
}
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
@@ -82,14 +94,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err()
|
||||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -163,16 +176,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, billingCacheTTL)
|
||||
pipe.Expire(ctx, key, jitteredTTL())
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result()
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
|
||||
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
82
backend/internal/repository/billing_cache_jitter_test.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
|
||||
|
||||
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
|
||||
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
|
||||
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
|
||||
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
|
||||
upperBound := billingCacheTTL // 5min
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
|
||||
"TTL 不应低于 %v,实际得到 %v", lowerBound, ttl)
|
||||
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
|
||||
"TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
|
||||
// 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
|
||||
for i := 0; i < 500; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
|
||||
"jitteredTTL 不应超过基础 TTL(上界预期不被打破)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariance(t *testing.T) {
|
||||
// 验证抖动确实产生了不同的值
|
||||
results := make(map[time.Duration]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
ttl := jitteredTTL()
|
||||
results[ttl] = true
|
||||
}
|
||||
|
||||
require.Greater(t, len(results), 1,
|
||||
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
|
||||
}
|
||||
|
||||
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
|
||||
// 验证平均值大约在抖动范围中间
|
||||
var sum time.Duration
|
||||
runs := 1000
|
||||
for i := 0; i < runs; i++ {
|
||||
sum += jitteredTTL()
|
||||
}
|
||||
|
||||
avg := sum / time.Duration(runs)
|
||||
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
|
||||
|
||||
// 允许 ±5s 的误差
|
||||
tolerance := 5 * time.Second
|
||||
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
|
||||
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
|
||||
}
|
||||
|
||||
func TestBillingKeyGeneration(t *testing.T) {
|
||||
t.Run("balance_key", func(t *testing.T) {
|
||||
key := billingBalanceKey(12345)
|
||||
assert.Equal(t, "billing:balance:12345", key)
|
||||
})
|
||||
|
||||
t.Run("sub_key", func(t *testing.T) {
|
||||
key := billingSubKey(100, 200)
|
||||
assert.Equal(t, "billing:sub:100:200", key)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkJitteredTTL(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = jitteredTTL()
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
|
||||
@@ -4,12 +4,12 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
@@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
@@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
|
||||
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
@@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
@@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
@@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
@@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
@@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
log.Printf("[OAuth] Step 3 SUCCESS - Got access token")
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -147,100 +147,6 @@ var (
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||
getAccountsLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local accountID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:account:' .. accountID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'wait:account:' .. accountID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, accountID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
|
||||
getUsersLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local userID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:user:' .. userID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'concurrency:wait:' .. userID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, userID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
@@ -399,29 +305,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, acc := range accounts {
|
||||
args = append(args, acc.ID, acc.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。
|
||||
// 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
|
||||
type accountCmds struct {
|
||||
id int64
|
||||
maxConcurrency int
|
||||
zcardCmd *redis.IntCmd
|
||||
getCmd *redis.StringCmd
|
||||
}
|
||||
cmds := make([]accountCmds, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
ac := accountCmds{
|
||||
id: acc.ID,
|
||||
maxConcurrency: acc.MaxConcurrency,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
getCmd: pipe.Get(ctx, waitKey),
|
||||
}
|
||||
cmds = append(cmds, ac)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.AccountLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||
for _, ac := range cmds {
|
||||
currentConcurrency := int(ac.zcardCmd.Val())
|
||||
waitingCount := 0
|
||||
if v, err := ac.getCmd.Int(); err == nil {
|
||||
waitingCount = v
|
||||
}
|
||||
|
||||
accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[accountID] = &service.AccountLoadInfo{
|
||||
AccountID: accountID,
|
||||
loadRate := 0
|
||||
if ac.maxConcurrency > 0 {
|
||||
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
|
||||
}
|
||||
loadMap[ac.id] = &service.AccountLoadInfo{
|
||||
AccountID: ac.id,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
@@ -436,29 +366,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, u := range users {
|
||||
args = append(args, u.ID, u.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
|
||||
type userCmds struct {
|
||||
id int64
|
||||
maxConcurrency int
|
||||
zcardCmd *redis.IntCmd
|
||||
getCmd *redis.StringCmd
|
||||
}
|
||||
cmds := make([]userCmds, 0, len(users))
|
||||
for _, u := range users {
|
||||
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
uc := userCmds{
|
||||
id: u.ID,
|
||||
maxConcurrency: u.MaxConcurrency,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
getCmd: pipe.Get(ctx, waitKey),
|
||||
}
|
||||
cmds = append(cmds, uc)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.UserLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
|
||||
for _, uc := range cmds {
|
||||
currentConcurrency := int(uc.zcardCmd.Val())
|
||||
waitingCount := 0
|
||||
if v, err := uc.getCmd.Int(); err == nil {
|
||||
waitingCount = v
|
||||
}
|
||||
|
||||
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[userID] = &service.UserLoadInfo{
|
||||
UserID: userID,
|
||||
loadRate := 0
|
||||
if uc.maxConcurrency > 0 {
|
||||
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
|
||||
}
|
||||
loadMap[uc.id] = &service.UserLoadInfo{
|
||||
UserID: uc.id,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
|
||||
@@ -5,6 +5,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// 启动阶段:从配置或数据库中确保系统密钥可用。
|
||||
if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。
|
||||
if err := cfg.Validate(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err)
|
||||
}
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
|
||||
@@ -18,14 +18,21 @@ type githubReleaseClient struct {
|
||||
downloadHTTPClient *http.Client
|
||||
}
|
||||
|
||||
type githubReleaseClientError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// NewGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
|
||||
func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
if proxyURL != "" && !allowDirectOnProxyError {
|
||||
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
|
||||
}
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
@@ -35,6 +42,9 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
if proxyURL != "" && !allowDirectOnProxyError {
|
||||
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
|
||||
}
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
|
||||
@@ -44,6 +54,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||
|
||||
|
||||
@@ -4,11 +4,11 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
@@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
@@ -68,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
groupIn.CreatedAt = created.CreatedAt
|
||||
groupIn.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
@@ -110,6 +114,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
@@ -144,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
}
|
||||
groupIn.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -155,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -183,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -288,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
@@ -398,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
@@ -492,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
237
backend/internal/repository/idempotency_repo.go
Normal file
237
backend/internal/repository/idempotency_repo.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type idempotencyRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
|
||||
return &idempotencyRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
if record == nil {
|
||||
return false, nil
|
||||
}
|
||||
query := `
|
||||
INSERT INTO idempotency_records (
|
||||
scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
var createdAt time.Time
|
||||
var updatedAt time.Time
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{
|
||||
record.Scope,
|
||||
record.IdempotencyKeyHash,
|
||||
record.RequestFingerprint,
|
||||
record.Status,
|
||||
record.LockedUntil,
|
||||
record.ExpiresAt,
|
||||
}, &record.ID, &createdAt, &updatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
record.CreatedAt = createdAt
|
||||
record.UpdatedAt = updatedAt
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT
|
||||
id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
|
||||
response_body, error_reason, locked_until, expires_at, created_at, updated_at
|
||||
FROM idempotency_records
|
||||
WHERE scope = $1 AND idempotency_key_hash = $2
|
||||
`
|
||||
record := &service.IdempotencyRecord{}
|
||||
var responseStatus sql.NullInt64
|
||||
var responseBody sql.NullString
|
||||
var errorReason sql.NullString
|
||||
var lockedUntil sql.NullTime
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
|
||||
&record.ID,
|
||||
&record.Scope,
|
||||
&record.IdempotencyKeyHash,
|
||||
&record.RequestFingerprint,
|
||||
&record.Status,
|
||||
&responseStatus,
|
||||
&responseBody,
|
||||
&errorReason,
|
||||
&lockedUntil,
|
||||
&record.ExpiresAt,
|
||||
&record.CreatedAt,
|
||||
&record.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if responseStatus.Valid {
|
||||
v := int(responseStatus.Int64)
|
||||
record.ResponseStatus = &v
|
||||
}
|
||||
if responseBody.Valid {
|
||||
v := responseBody.String
|
||||
record.ResponseBody = &v
|
||||
}
|
||||
if errorReason.Valid {
|
||||
v := errorReason.String
|
||||
record.ErrorReason = &v
|
||||
}
|
||||
if lockedUntil.Valid {
|
||||
v := lockedUntil.Time
|
||||
record.LockedUntil = &v
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) TryReclaim(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
fromStatus string,
|
||||
now, newLockedUntil, newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
locked_until = $3,
|
||||
error_reason = NULL,
|
||||
updated_at = NOW(),
|
||||
expires_at = $4
|
||||
WHERE id = $1
|
||||
AND status = $5
|
||||
AND (locked_until IS NULL OR locked_until <= $6)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusProcessing,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
fromStatus,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) ExtendProcessingLock(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
requestFingerprint string,
|
||||
newLockedUntil,
|
||||
newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET locked_until = $2,
|
||||
expires_at = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = $4
|
||||
AND request_fingerprint = $5
|
||||
`
|
||||
res, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
id,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
service.IdempotencyStatusProcessing,
|
||||
requestFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
response_status = $3,
|
||||
response_body = $4,
|
||||
error_reason = NULL,
|
||||
locked_until = NULL,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusSucceeded,
|
||||
responseStatus,
|
||||
responseBody,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
error_reason = $3,
|
||||
locked_until = $4,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
errorReason,
|
||||
lockedUntil,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
}
|
||||
query := `
|
||||
WITH victims AS (
|
||||
SELECT id
|
||||
FROM idempotency_records
|
||||
WHERE expires_at <= $1
|
||||
ORDER BY expires_at ASC
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM idempotency_records
|
||||
WHERE id IN (SELECT id FROM victims)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query, now, limit)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
150
backend/internal/repository/idempotency_repo_integration_test.go
Normal file
150
backend/internal/repository/idempotency_repo_integration_test.go
Normal file
@@ -0,0 +1,150 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns.
|
||||
func hashedTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-create"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
require.NotZero(t, record.ID)
|
||||
|
||||
duplicate := &service.IdempotencyRecord{
|
||||
Scope: record.Scope,
|
||||
IdempotencyKeyHash: record.IdempotencyKeyHash,
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err = repo.CreateProcessing(ctx, duplicate)
|
||||
require.NoError(t, err)
|
||||
require.False(t, owner, "same scope+key hash should be de-duplicated")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-reclaim"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(-2*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
newLockedUntil := now.Add(20 * time.Second)
|
||||
reclaimed, err := repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
newLockedUntil,
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
|
||||
require.NotNil(t, got.LockedUntil)
|
||||
require.True(t, got.LockedUntil.After(now))
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(20*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
reclaimed, err = repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
now.Add(40*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, reclaimed, "within lock window should not reclaim")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-success"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
|
||||
require.NotNil(t, got.ResponseStatus)
|
||||
require.Equal(t, 200, *got.ResponseStatus)
|
||||
require.NotNil(t, got.ResponseBody)
|
||||
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
|
||||
require.Nil(t, got.LockedUntil)
|
||||
}
|
||||
|
||||
@@ -48,6 +48,11 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
|
||||
|
||||
// security_secrets table should exist
|
||||
var securitySecretsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
|
||||
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
|
||||
|
||||
// user_allowed_groups table should exist
|
||||
var uagRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
@@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
if strings.TrimSpace(clientID) != "" {
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
|
||||
}
|
||||
|
||||
clientIDs := []string{
|
||||
openai.ClientID,
|
||||
openai.SoraClientID,
|
||||
}
|
||||
seen := make(map[string]struct{}, len(clientIDs))
|
||||
var lastErr error
|
||||
for _, clientID := range clientIDs {
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[clientID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[clientID] = struct{}{}
|
||||
|
||||
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
if err == nil {
|
||||
return tokenResp, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client := createOpenAIReqClient(proxyURL)
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", openai.ClientID)
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
@@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.ClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "invalid_grant")
|
||||
return
|
||||
}
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID != customClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -55,6 +56,10 @@ INSERT INTO ops_error_logs (
|
||||
upstream_error_message,
|
||||
upstream_error_detail,
|
||||
upstream_errors,
|
||||
auth_latency_ms,
|
||||
routing_latency_ms,
|
||||
upstream_latency_ms,
|
||||
response_latency_ms,
|
||||
time_to_first_token_ms,
|
||||
request_body,
|
||||
request_body_truncated,
|
||||
@@ -64,7 +69,7 @@ INSERT INTO ops_error_logs (
|
||||
retry_count,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||
) RETURNING id`
|
||||
|
||||
var id int64
|
||||
@@ -97,6 +102,10 @@ INSERT INTO ops_error_logs (
|
||||
opsNullString(input.UpstreamErrorMessage),
|
||||
opsNullString(input.UpstreamErrorDetail),
|
||||
opsNullString(input.UpstreamErrorsJSON),
|
||||
opsNullInt64(input.AuthLatencyMs),
|
||||
opsNullInt64(input.RoutingLatencyMs),
|
||||
opsNullInt64(input.UpstreamLatencyMs),
|
||||
opsNullInt64(input.ResponseLatencyMs),
|
||||
opsNullInt64(input.TimeToFirstTokenMs),
|
||||
opsNullString(input.RequestBodyJSON),
|
||||
input.RequestBodyTruncated,
|
||||
@@ -930,6 +939,243 @@ WHERE id = $1`
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if len(inputs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stmt, err := tx.PrepareContext(ctx, pq.CopyIn(
|
||||
"ops_system_logs",
|
||||
"created_at",
|
||||
"level",
|
||||
"component",
|
||||
"message",
|
||||
"request_id",
|
||||
"client_request_id",
|
||||
"user_id",
|
||||
"account_id",
|
||||
"platform",
|
||||
"model",
|
||||
"extra",
|
||||
))
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var inserted int64
|
||||
for _, input := range inputs {
|
||||
if input == nil {
|
||||
continue
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
component := strings.TrimSpace(input.Component)
|
||||
level := strings.ToLower(strings.TrimSpace(input.Level))
|
||||
message := strings.TrimSpace(input.Message)
|
||||
if level == "" || message == "" {
|
||||
continue
|
||||
}
|
||||
if component == "" {
|
||||
component = "app"
|
||||
}
|
||||
extra := strings.TrimSpace(input.ExtraJSON)
|
||||
if extra == "" {
|
||||
extra = "{}"
|
||||
}
|
||||
if _, err := stmt.ExecContext(
|
||||
ctx,
|
||||
createdAt.UTC(),
|
||||
level,
|
||||
component,
|
||||
message,
|
||||
opsNullString(input.RequestID),
|
||||
opsNullString(input.ClientRequestID),
|
||||
opsNullInt64(input.UserID),
|
||||
opsNullInt64(input.AccountID),
|
||||
opsNullString(input.Platform),
|
||||
opsNullString(input.Model),
|
||||
extra,
|
||||
); err != nil {
|
||||
_ = stmt.Close()
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
inserted++
|
||||
}
|
||||
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
_ = stmt.Close()
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
if err := stmt.Close(); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return inserted, err
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return inserted, err
|
||||
}
|
||||
return inserted, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogFilter{}
|
||||
}
|
||||
|
||||
page := filter.Page
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
pageSize := filter.PageSize
|
||||
if pageSize <= 0 {
|
||||
pageSize = 50
|
||||
}
|
||||
if pageSize > 200 {
|
||||
pageSize = 200
|
||||
}
|
||||
|
||||
where, args, _ := buildOpsSystemLogsWhere(filter)
|
||||
countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where
|
||||
var total int
|
||||
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
offset := (page - 1) * pageSize
|
||||
argsWithLimit := append(args, pageSize, offset)
|
||||
query := `
|
||||
SELECT
|
||||
l.id,
|
||||
l.created_at,
|
||||
l.level,
|
||||
COALESCE(l.component, ''),
|
||||
COALESCE(l.message, ''),
|
||||
COALESCE(l.request_id, ''),
|
||||
COALESCE(l.client_request_id, ''),
|
||||
l.user_id,
|
||||
l.account_id,
|
||||
COALESCE(l.platform, ''),
|
||||
COALESCE(l.model, ''),
|
||||
COALESCE(l.extra::text, '{}')
|
||||
FROM ops_system_logs l
|
||||
` + where + `
|
||||
ORDER BY l.created_at DESC, l.id DESC
|
||||
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, argsWithLimit...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
logs := make([]*service.OpsSystemLog, 0, pageSize)
|
||||
for rows.Next() {
|
||||
item := &service.OpsSystemLog{}
|
||||
var userID sql.NullInt64
|
||||
var accountID sql.NullInt64
|
||||
var extraRaw string
|
||||
if err := rows.Scan(
|
||||
&item.ID,
|
||||
&item.CreatedAt,
|
||||
&item.Level,
|
||||
&item.Component,
|
||||
&item.Message,
|
||||
&item.RequestID,
|
||||
&item.ClientRequestID,
|
||||
&userID,
|
||||
&accountID,
|
||||
&item.Platform,
|
||||
&item.Model,
|
||||
&extraRaw,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID.Valid {
|
||||
v := userID.Int64
|
||||
item.UserID = &v
|
||||
}
|
||||
if accountID.Valid {
|
||||
v := accountID.Int64
|
||||
item.AccountID = &v
|
||||
}
|
||||
extraRaw = strings.TrimSpace(extraRaw)
|
||||
if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" {
|
||||
extra := make(map[string]any)
|
||||
if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil {
|
||||
item.Extra = extra
|
||||
}
|
||||
}
|
||||
logs = append(logs, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsSystemLogList{
|
||||
Logs: logs,
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogCleanupFilter{}
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
|
||||
if !hasConstraint {
|
||||
return 0, fmt.Errorf("cleanup requires at least one filter condition")
|
||||
}
|
||||
|
||||
query := "DELETE FROM ops_system_logs l " + where
|
||||
res, err := r.db.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
_, err := r.db.ExecContext(ctx, `
|
||||
INSERT INTO ops_system_log_cleanup_audits (
|
||||
created_at,
|
||||
operator_id,
|
||||
conditions,
|
||||
deleted_rows
|
||||
) VALUES ($1,$2,$3,$4)
|
||||
`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows)
|
||||
return err
|
||||
}
|
||||
|
||||
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
clauses := make([]string, 0, 12)
|
||||
args := make([]any, 0, 12)
|
||||
@@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
// Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
|
||||
if phaseFilter != "upstream" {
|
||||
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
|
||||
clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400")
|
||||
}
|
||||
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
@@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
if p := strings.TrimSpace(filter.Platform); p != "" {
|
||||
args = append(args, p)
|
||||
clauses = append(clauses, "platform = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.platform = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
args = append(args, *filter.GroupID)
|
||||
clauses = append(clauses, "group_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.group_id = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
args = append(args, *filter.AccountID)
|
||||
clauses = append(clauses, "account_id = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.account_id = $"+itoa(len(args)))
|
||||
}
|
||||
if phase := phaseFilter; phase != "" {
|
||||
args = append(args, phase)
|
||||
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "e.error_phase = $"+itoa(len(args)))
|
||||
}
|
||||
if filter != nil {
|
||||
if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
|
||||
args = append(args, owner)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
|
||||
args = append(args, source)
|
||||
clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args)))
|
||||
}
|
||||
}
|
||||
if resolvedFilter != nil {
|
||||
args = append(args, *resolvedFilter)
|
||||
clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
// View filter: errors vs excluded vs all.
|
||||
@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
|
||||
}
|
||||
switch view {
|
||||
case "", "errors":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
case "excluded":
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true")
|
||||
case "all":
|
||||
// no-op
|
||||
default:
|
||||
// treat unknown as default 'errors'
|
||||
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
|
||||
clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false")
|
||||
}
|
||||
if len(filter.StatusCodes) > 0 {
|
||||
args = append(args, pq.Array(filter.StatusCodes))
|
||||
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")")
|
||||
} else if filter.StatusCodesOther {
|
||||
// "Other" means: status codes not in the common list.
|
||||
known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
|
||||
args = append(args, pq.Array(known))
|
||||
clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))")
|
||||
}
|
||||
// Exact correlation keys (preferred for request↔upstream linkage).
|
||||
if rid := strings.TrimSpace(filter.RequestID); rid != "" {
|
||||
args = append(args, rid)
|
||||
clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
|
||||
args = append(args, crid)
|
||||
clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
|
||||
clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + q + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
|
||||
clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
|
||||
like := "%" + userQuery + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "u.email ILIKE $"+n)
|
||||
clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")")
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) {
|
||||
clauses := make([]string, 0, 10)
|
||||
args := make([]any, 0, 10)
|
||||
clauses = append(clauses, "1=1")
|
||||
hasConstraint := false
|
||||
|
||||
if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
args = append(args, filter.StartTime.UTC())
|
||||
clauses = append(clauses, "l.created_at >= $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||
args = append(args, filter.EndTime.UTC())
|
||||
clauses = append(clauses, "l.created_at < $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter != nil {
|
||||
if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Component); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.RequestID); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.ClientRequestID); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter.UserID != nil && *filter.UserID > 0 {
|
||||
args = append(args, *filter.UserID)
|
||||
clauses = append(clauses, "l.user_id = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
args = append(args, *filter.AccountID)
|
||||
clauses = append(clauses, "l.account_id = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Platform); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Model); v != "" {
|
||||
args = append(args, v)
|
||||
clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args)))
|
||||
hasConstraint = true
|
||||
}
|
||||
if v := strings.TrimSpace(filter.Query); v != "" {
|
||||
like := "%" + v + "%"
|
||||
args = append(args, like)
|
||||
n := itoa(len(args))
|
||||
clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")")
|
||||
hasConstraint = true
|
||||
}
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint
|
||||
}
|
||||
|
||||
func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) {
|
||||
if filter == nil {
|
||||
filter = &service.OpsSystemLogCleanupFilter{}
|
||||
}
|
||||
listFilter := &service.OpsSystemLogFilter{
|
||||
StartTime: filter.StartTime,
|
||||
EndTime: filter.EndTime,
|
||||
Level: filter.Level,
|
||||
Component: filter.Component,
|
||||
RequestID: filter.RequestID,
|
||||
ClientRequestID: filter.ClientRequestID,
|
||||
UserID: filter.UserID,
|
||||
AccountID: filter.AccountID,
|
||||
Platform: filter.Platform,
|
||||
Model: filter.Model,
|
||||
Query: filter.Query,
|
||||
}
|
||||
return buildOpsSystemLogsWhere(listFilter)
|
||||
}
|
||||
|
||||
// Helpers for nullable args
|
||||
func opsNullString(v any) any {
|
||||
switch s := v.(type) {
|
||||
|
||||
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
48
backend/internal/repository/ops_repo_error_where_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
Query: "ACCESS_DENIED",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "e.request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.client_request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified client_request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.error_message ILIKE $") {
|
||||
t.Fatalf("where should include qualified error_message condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
UserQuery: "admin@",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
|
||||
t.Fatalf("where should include EXISTS user email condition: %s", where)
|
||||
}
|
||||
}
|
||||
145
backend/internal/repository/ops_repo_openai_token_stats.go
Normal file
145
backend/internal/repository/ops_repo_openai_token_stats.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
// 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
|
||||
dashboardFilter := &service.OpsDashboardFilter{
|
||||
StartTime: filter.StartTime.UTC(),
|
||||
EndTime: filter.EndTime.UTC(),
|
||||
Platform: strings.TrimSpace(strings.ToLower(filter.Platform)),
|
||||
GroupID: filter.GroupID,
|
||||
}
|
||||
|
||||
join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1)
|
||||
where += " AND ul.model LIKE 'gpt%'"
|
||||
|
||||
baseCTE := `
|
||||
WITH stats AS (
|
||||
SELECT
|
||||
ul.model AS model,
|
||||
COUNT(*)::bigint AS request_count,
|
||||
ROUND(
|
||||
AVG(
|
||||
CASE
|
||||
WHEN ul.duration_ms > 0 AND ul.output_tokens > 0
|
||||
THEN ul.output_tokens * 1000.0 / ul.duration_ms
|
||||
END
|
||||
)::numeric,
|
||||
2
|
||||
)::float8 AS avg_tokens_per_sec,
|
||||
ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms,
|
||||
COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms,
|
||||
COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token
|
||||
FROM usage_logs ul
|
||||
` + join + `
|
||||
` + where + `
|
||||
GROUP BY ul.model
|
||||
)
|
||||
`
|
||||
|
||||
countSQL := baseCTE + `SELECT COUNT(*) FROM stats`
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
querySQL := baseCTE + `
|
||||
SELECT
|
||||
model,
|
||||
request_count,
|
||||
avg_tokens_per_sec,
|
||||
avg_first_token_ms,
|
||||
total_output_tokens,
|
||||
avg_duration_ms,
|
||||
requests_with_first_token
|
||||
FROM stats
|
||||
ORDER BY request_count DESC, model ASC`
|
||||
|
||||
args := make([]any, 0, len(baseArgs)+2)
|
||||
args = append(args, baseArgs...)
|
||||
|
||||
if filter.IsTopNMode() {
|
||||
querySQL += fmt.Sprintf("\nLIMIT $%d", next)
|
||||
args = append(args, filter.TopN)
|
||||
} else {
|
||||
offset := (filter.Page - 1) * filter.PageSize
|
||||
querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1)
|
||||
args = append(args, filter.PageSize, offset)
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsOpenAITokenStatsItem, 0, 32)
|
||||
for rows.Next() {
|
||||
item := &service.OpsOpenAITokenStatsItem{}
|
||||
var avgTPS sql.NullFloat64
|
||||
var avgFirstToken sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&item.Model,
|
||||
&item.RequestCount,
|
||||
&avgTPS,
|
||||
&avgFirstToken,
|
||||
&item.TotalOutputTokens,
|
||||
&item.AvgDurationMs,
|
||||
&item.RequestsWithFirstToken,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if avgTPS.Valid {
|
||||
v := avgTPS.Float64
|
||||
item.AvgTokensPerSec = &v
|
||||
}
|
||||
if avgFirstToken.Valid {
|
||||
v := avgFirstToken.Float64
|
||||
item.AvgFirstTokenMs = &v
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &service.OpsOpenAITokenStatsResponse{
|
||||
TimeRange: strings.TrimSpace(filter.TimeRange),
|
||||
StartTime: dashboardFilter.StartTime,
|
||||
EndTime: dashboardFilter.EndTime,
|
||||
Platform: dashboardFilter.Platform,
|
||||
GroupID: dashboardFilter.GroupID,
|
||||
Items: items,
|
||||
Total: total,
|
||||
}
|
||||
if filter.IsTopNMode() {
|
||||
topN := filter.TopN
|
||||
resp.TopN = &topN
|
||||
} else {
|
||||
resp.Page = filter.Page
|
||||
resp.PageSize = filter.PageSize
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
156
backend/internal/repository/ops_repo_openai_token_stats_test.go
Normal file
156
backend/internal/repository/ops_repo_openai_token_stats_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
groupID := int64(9)
|
||||
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "1d",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: " OpenAI ",
|
||||
GroupID: &groupID,
|
||||
Page: 2,
|
||||
PageSize: 10,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end, groupID, "openai").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3)))
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}).
|
||||
AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)).
|
||||
AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`).
|
||||
WithArgs(start, end, groupID, "openai", 10, 10).
|
||||
WillReturnRows(rows)
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(3), resp.Total)
|
||||
require.Equal(t, 2, resp.Page)
|
||||
require.Equal(t, 10, resp.PageSize)
|
||||
require.Nil(t, resp.TopN)
|
||||
require.Equal(t, "openai", resp.Platform)
|
||||
require.NotNil(t, resp.GroupID)
|
||||
require.Equal(t, groupID, *resp.GroupID)
|
||||
require.Len(t, resp.Items, 2)
|
||||
require.Equal(t, "gpt-4o-mini", resp.Items[0].Model)
|
||||
require.NotNil(t, resp.Items[0].AvgTokensPerSec)
|
||||
require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001)
|
||||
require.NotNil(t, resp.Items[0].AvgFirstTokenMs)
|
||||
require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC)
|
||||
end := start.Add(time.Hour)
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "1h",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
TopN: 5,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}).
|
||||
AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`).
|
||||
WithArgs(start, end, 5).
|
||||
WillReturnRows(rows)
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.TopN)
|
||||
require.Equal(t, 5, *resp.TopN)
|
||||
require.Equal(t, 0, resp.Page)
|
||||
require.Equal(t, 0, resp.PageSize)
|
||||
require.Len(t, resp.Items, 1)
|
||||
require.Nil(t, resp.Items[0].AvgTokensPerSec)
|
||||
require.Nil(t, resp.Items[0].AvgFirstTokenMs)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(30 * time.Minute)
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "30m",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`).
|
||||
WithArgs(start, end, 20, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}))
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(0), resp.Total)
|
||||
require.Len(t, resp.Items, 0)
|
||||
require.Equal(t, 1, resp.Page)
|
||||
require.Equal(t, 20, resp.PageSize)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
86
backend/internal/repository/ops_repo_system_logs_test.go
Normal file
86
backend/internal/repository/ops_repo_system_logs_test.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
|
||||
userID := int64(12)
|
||||
accountID := int64(34)
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: "warn",
|
||||
Component: "http.access",
|
||||
RequestID: "req-1",
|
||||
ClientRequestID: "creq-1",
|
||||
UserID: &userID,
|
||||
AccountID: &accountID,
|
||||
Platform: "openai",
|
||||
Model: "gpt-5",
|
||||
Query: "timeout",
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 11 {
|
||||
t.Fatalf("args len = %d, want 11", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
|
||||
if hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=false")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Fatalf("args len = %d, want 0", len(args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
userID := int64(9)
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
ClientRequestID: "creq-9",
|
||||
UserID: &userID,
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if len(args) != 2 {
|
||||
t.Fatalf("args len = %d, want 2", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s string, sub string) bool {
|
||||
return strings.Contains(s, sub)
|
||||
}
|
||||
@@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
@@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
@@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||
}
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
@@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
@@ -52,6 +58,7 @@ type proxyProbeService struct {
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
maxResponseBytes int64
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
@@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxResponseBytes := s.maxResponseBytes
|
||||
if maxResponseBytes <= 0 {
|
||||
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBytes {
|
||||
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
|
||||
177
backend/internal/repository/security_secret_bootstrap.go
Normal file
177
backend/internal/repository/security_secret_bootstrap.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
const (
|
||||
securitySecretKeyJWT = "jwt_secret"
|
||||
securitySecretReadRetryMax = 5
|
||||
securitySecretReadRetryWait = 10 * time.Millisecond
|
||||
)
|
||||
|
||||
var readRandomBytes = rand.Read
|
||||
|
||||
func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error {
|
||||
if client == nil {
|
||||
return fmt.Errorf("nil ent client")
|
||||
}
|
||||
if cfg == nil {
|
||||
return fmt.Errorf("nil config")
|
||||
}
|
||||
|
||||
cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret)
|
||||
if cfg.JWT.Secret != "" {
|
||||
storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret)
|
||||
if err != nil {
|
||||
return fmt.Errorf("persist jwt secret: %w", err)
|
||||
}
|
||||
if storedSecret != cfg.JWT.Secret {
|
||||
log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.")
|
||||
}
|
||||
cfg.JWT.Secret = storedSecret
|
||||
return nil
|
||||
}
|
||||
|
||||
secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("ensure jwt secret: %w", err)
|
||||
}
|
||||
cfg.JWT.Secret = secret
|
||||
|
||||
if created {
|
||||
log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) {
|
||||
existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
if err == nil {
|
||||
value := strings.TrimSpace(existing.Value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return value, false, nil
|
||||
}
|
||||
if !ent.IsNotFound(err) {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
generated, err := generateHexSecret(byteLength)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
|
||||
if err := client.SecuritySecret.Create().
|
||||
SetKey(key).
|
||||
SetValue(generated).
|
||||
OnConflictColumns(securitysecret.FieldKey).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if !isSQLNoRowsError(err) {
|
||||
return "", false, err
|
||||
}
|
||||
}
|
||||
|
||||
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
value := strings.TrimSpace(stored.Value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return value, value == generated, nil
|
||||
}
|
||||
|
||||
func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) {
|
||||
value = strings.TrimSpace(value)
|
||||
if len([]byte(value)) < 32 {
|
||||
return "", fmt.Errorf("secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
|
||||
if err := client.SecuritySecret.Create().
|
||||
SetKey(key).
|
||||
SetValue(value).
|
||||
OnConflictColumns(securitysecret.FieldKey).
|
||||
DoNothing().
|
||||
Exec(ctx); err != nil {
|
||||
if !isSQLNoRowsError(err) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
stored, err := querySecuritySecretWithRetry(ctx, client, key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
storedValue := strings.TrimSpace(stored.Value)
|
||||
if len([]byte(storedValue)) < 32 {
|
||||
return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key)
|
||||
}
|
||||
return storedValue, nil
|
||||
}
|
||||
|
||||
func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) {
|
||||
var lastErr error
|
||||
for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ {
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx)
|
||||
if err == nil {
|
||||
return stored, nil
|
||||
}
|
||||
if !isSecretNotFoundError(err) {
|
||||
return nil, err
|
||||
}
|
||||
lastErr = err
|
||||
if attempt == securitySecretReadRetryMax {
|
||||
break
|
||||
}
|
||||
|
||||
timer := time.NewTimer(securitySecretReadRetryWait)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func isSecretNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return ent.IsNotFound(err) || isSQLNoRowsError(err)
|
||||
}
|
||||
|
||||
func isSQLNoRowsError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set")
|
||||
}
|
||||
|
||||
func generateHexSecret(byteLength int) (string, error) {
|
||||
if byteLength <= 0 {
|
||||
byteLength = 32
|
||||
}
|
||||
buf := make([]byte, byteLength)
|
||||
if _, err := readRandomBytes(buf); err != nil {
|
||||
return "", fmt.Errorf("generate random secret: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
337
backend/internal/repository/security_secret_bootstrap_test.go
Normal file
337
backend/internal/repository/security_secret_bootstrap_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/ent/securitysecret"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newSecuritySecretTestClient(t *testing.T) *dbent.Client {
|
||||
t.Helper()
|
||||
name := strings.ReplaceAll(t.Name(), "/", "_")
|
||||
dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name)
|
||||
|
||||
db, err := sql.Open("sqlite", dsn)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
return client
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsNilInputs(t *testing.T) {
|
||||
err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil ent client")
|
||||
|
||||
client := newSecuritySecretTestClient(t)
|
||||
err = ensureBootstrapSecrets(context.Background(), client, nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil config")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, cfg.JWT.Secret)
|
||||
require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cfg.JWT.Secret, stored.Value)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"},
|
||||
}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value)
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}}
|
||||
|
||||
err := ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
}
|
||||
|
||||
func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().
|
||||
SetKey(securitySecretKeyJWT).
|
||||
SetValue("existing-jwt-secret-32bytes-long!!!!").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}}
|
||||
err = ensureBootstrapSecrets(context.Background(), client, cfg)
|
||||
require.NoError(t, err)
|
||||
|
||||
stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value)
|
||||
require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := client.SecuritySecret.Create().
|
||||
SetKey("trimmed_key").
|
||||
SetValue(" existing-trimmed-secret-32bytes-long!! ").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32)
|
||||
require.NoError(t, err)
|
||||
require.False(t, created)
|
||||
require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
tooLongKey := strings.Repeat("k", 101)
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
const goroutines = 8
|
||||
key := "concurrent_bootstrap_key"
|
||||
|
||||
values := make([]string, goroutines)
|
||||
createdFlags := make([]bool, goroutines)
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i := range errs {
|
||||
require.NoError(t, errs[i])
|
||||
require.NotEmpty(t, values[i])
|
||||
}
|
||||
for i := 1; i < len(values); i++ {
|
||||
require.Equal(t, values[0], values[i])
|
||||
}
|
||||
|
||||
createdCount := 0
|
||||
for _, created := range createdFlags {
|
||||
if created {
|
||||
createdCount++
|
||||
}
|
||||
}
|
||||
require.GreaterOrEqual(t, createdCount, 1)
|
||||
require.LessOrEqual(t, createdCount, 1)
|
||||
|
||||
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
originalRead := readRandomBytes
|
||||
readRandomBytes = func([]byte) (int, error) {
|
||||
return 0, errors.New("boom")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
readRandomBytes = originalRead
|
||||
})
|
||||
|
||||
_, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "boom")
|
||||
}
|
||||
|
||||
func TestCreateSecuritySecretIfAbsent(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
|
||||
_, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "at least 32 bytes")
|
||||
|
||||
stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
|
||||
|
||||
stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored)
|
||||
|
||||
count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, count)
|
||||
}
|
||||
|
||||
func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
_, err := createSecuritySecretIfAbsent(
|
||||
context.Background(),
|
||||
client,
|
||||
strings.Repeat("k", 101),
|
||||
"valid-jwt-secret-value-32bytes-long",
|
||||
)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
_, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
created, err := client.SecuritySecret.Create().
|
||||
SetKey("retry_success_key").
|
||||
SetValue("retry-success-jwt-secret-value-32!!").
|
||||
Save(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, created.ID, got.ID)
|
||||
require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value)
|
||||
}
|
||||
|
||||
func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
|
||||
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key")
|
||||
require.Error(t, err)
|
||||
require.True(t, isSecretNotFoundError(err))
|
||||
}
|
||||
|
||||
func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2)
|
||||
defer cancel()
|
||||
|
||||
_, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key")
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.DeadlineExceeded)
|
||||
}
|
||||
|
||||
func TestQuerySecuritySecretWithRetryNonNotFoundError(t *testing.T) {
|
||||
client := newSecuritySecretTestClient(t)
|
||||
require.NoError(t, client.Close())
|
||||
|
||||
_, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key")
|
||||
require.Error(t, err)
|
||||
require.False(t, isSecretNotFoundError(err))
|
||||
}
|
||||
|
||||
func TestSecretNotFoundHelpers(t *testing.T) {
|
||||
require.False(t, isSecretNotFoundError(nil))
|
||||
require.False(t, isSQLNoRowsError(nil))
|
||||
|
||||
require.True(t, isSQLNoRowsError(sql.ErrNoRows))
|
||||
require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows)))
|
||||
require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set")))
|
||||
|
||||
require.True(t, isSecretNotFoundError(sql.ErrNoRows))
|
||||
require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set")))
|
||||
require.False(t, isSecretNotFoundError(errors.New("some other error")))
|
||||
}
|
||||
|
||||
func TestGenerateHexSecretReadError(t *testing.T) {
|
||||
originalRead := readRandomBytes
|
||||
readRandomBytes = func([]byte) (int, error) {
|
||||
return 0, errors.New("read random failed")
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
readRandomBytes = originalRead
|
||||
})
|
||||
|
||||
_, err := generateHexSecret(32)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "read random failed")
|
||||
}
|
||||
|
||||
func TestGenerateHexSecretLengths(t *testing.T) {
|
||||
v1, err := generateHexSecret(0)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v1, 64)
|
||||
_, err = hex.DecodeString(v1)
|
||||
require.NoError(t, err)
|
||||
|
||||
v2, err := generateHexSecret(16)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, v2, 32)
|
||||
_, err = hex.DecodeString(v2)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NotEqual(t, v1, v2)
|
||||
}
|
||||
98
backend/internal/repository/sora_account_repo.go
Normal file
98
backend/internal/repository/sora_account_repo.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// soraAccountRepository 实现 service.SoraAccountRepository 接口。
|
||||
// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。
|
||||
//
|
||||
// 设计说明:
|
||||
// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理
|
||||
// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义
|
||||
// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除
|
||||
type soraAccountRepository struct {
|
||||
sql *sql.DB
|
||||
}
|
||||
|
||||
// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例
|
||||
func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository {
|
||||
return &soraAccountRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// Upsert 创建或更新 Sora 账号扩展信息
|
||||
// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert
|
||||
func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error {
|
||||
accessToken, accessOK := updates["access_token"].(string)
|
||||
refreshToken, refreshOK := updates["refresh_token"].(string)
|
||||
sessionToken, sessionOK := updates["session_token"].(string)
|
||||
|
||||
if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" {
|
||||
if !sessionOK {
|
||||
return errors.New("缺少 access_token/refresh_token,且未提供可更新字段")
|
||||
}
|
||||
result, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE sora_accounts
|
||||
SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END,
|
||||
updated_at = NOW()
|
||||
WHERE account_id = $1
|
||||
`, accountID, sessionToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows == 0 {
|
||||
return errors.New("sora_accounts 记录不存在,无法仅更新 session_token")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, NOW(), NOW())
|
||||
ON CONFLICT (account_id) DO UPDATE SET
|
||||
access_token = EXCLUDED.access_token,
|
||||
refresh_token = EXCLUDED.refresh_token,
|
||||
session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END,
|
||||
updated_at = NOW()
|
||||
`, accountID, accessToken, refreshToken, sessionToken)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetByAccountID 根据账号 ID 获取 Sora 扩展信息
|
||||
func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT account_id, access_token, refresh_token, COALESCE(session_token, '')
|
||||
FROM sora_accounts
|
||||
WHERE account_id = $1
|
||||
`, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
if !rows.Next() {
|
||||
return nil, nil // 记录不存在
|
||||
}
|
||||
|
||||
var sa service.SoraAccount
|
||||
if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &sa, nil
|
||||
}
|
||||
|
||||
// Delete 删除 Sora 账号扩展信息
|
||||
func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
DELETE FROM sora_accounts WHERE account_id = $1
|
||||
`, accountID)
|
||||
return err
|
||||
}
|
||||
@@ -22,7 +22,23 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
|
||||
|
||||
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
|
||||
var dateFormatWhitelist = map[string]string{
|
||||
"hour": "YYYY-MM-DD HH24:00",
|
||||
"day": "YYYY-MM-DD",
|
||||
"week": "IYYY-IW",
|
||||
"month": "YYYY-MM",
|
||||
}
|
||||
|
||||
// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值
|
||||
func safeDateFormat(granularity string) string {
|
||||
if f, ok := dateFormatWhitelist[granularity]; ok {
|
||||
return f
|
||||
}
|
||||
return "YYYY-MM-DD"
|
||||
}
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -111,23 +127,24 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
media_type,
|
||||
reasoning_effort,
|
||||
cache_ttl_overridden,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
groupID := nullInt64(log.GroupID)
|
||||
subscriptionID := nullInt64(log.SubscriptionID)
|
||||
@@ -136,6 +153,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
mediaType := nullString(log.MediaType)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -173,6 +191,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
mediaType,
|
||||
reasoningEffort,
|
||||
log.CacheTTLOverridden,
|
||||
createdAt,
|
||||
@@ -566,7 +585,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64,
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -812,19 +831,19 @@ func resolveUsageStatsTimezone() string {
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, accountID, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC"
|
||||
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000"
|
||||
logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime)
|
||||
return logs, nil, err
|
||||
}
|
||||
@@ -896,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。
|
||||
// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。
|
||||
func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) {
|
||||
result := make(map[int64]*usagestats.AccountStats, len(accountIDs))
|
||||
if len(accountIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
account_id,
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = ANY($1) AND created_at >= $2
|
||||
GROUP BY account_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
stats := &usagestats.AccountStats{}
|
||||
if err := rows.Scan(
|
||||
&accountID,
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
&stats.StandardCost,
|
||||
&stats.UserCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[accountID] = stats
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, accountID := range accountIDs {
|
||||
if _, ok := result[accountID]; !ok {
|
||||
result[accountID] = &usagestats.AccountStats{}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// TrendDataPoint represents a single point in trend data
|
||||
type TrendDataPoint = usagestats.TrendDataPoint
|
||||
|
||||
@@ -910,10 +982,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint
|
||||
|
||||
// GetAPIKeyUsageTrend returns usage trend data grouped by API key and date
|
||||
func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_keys AS (
|
||||
@@ -968,10 +1037,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime,
|
||||
|
||||
// GetUserUsageTrend returns usage trend data grouped by user and date
|
||||
func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
WITH top_users AS (
|
||||
@@ -1230,10 +1296,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
@@ -1371,13 +1434,22 @@ type UsageStats = usagestats.UsageStats
|
||||
// BatchUserUsageStats represents usage stats for a single user
|
||||
type BatchUserUsageStats = usagestats.BatchUserUsageStats
|
||||
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) {
|
||||
// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) {
|
||||
result := make(map[int64]*BatchUserUsageStats)
|
||||
if len(userIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range userIDs {
|
||||
result[id] = &BatchUserUsageStats{UserID: id}
|
||||
}
|
||||
@@ -1385,10 +1457,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
query := `
|
||||
SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE user_id = ANY($1)
|
||||
WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY user_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1445,13 +1517,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs
|
||||
// BatchAPIKeyUsageStats represents usage stats for a single API key
|
||||
type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats
|
||||
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys within a time range.
|
||||
// If startTime is zero, defaults to 30 days ago.
|
||||
func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) {
|
||||
result := make(map[int64]*BatchAPIKeyUsageStats)
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// 默认最近 30 天
|
||||
if startTime.IsZero() {
|
||||
startTime = time.Now().AddDate(0, 0, -30)
|
||||
}
|
||||
if endTime.IsZero() {
|
||||
endTime = time.Now()
|
||||
}
|
||||
|
||||
for _, id := range apiKeyIDs {
|
||||
result[id] = &BatchAPIKeyUsageStats{APIKeyID: id}
|
||||
}
|
||||
@@ -1459,10 +1540,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
query := `
|
||||
SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = ANY($1)
|
||||
WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY api_key_id
|
||||
`
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs))
|
||||
rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1518,10 +1599,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
}
|
||||
dateFormat := safeDateFormat(granularity)
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
@@ -2196,6 +2274,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
mediaType sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
cacheTTLOverridden bool
|
||||
createdAt time.Time
|
||||
@@ -2232,6 +2311,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&mediaType,
|
||||
&reasoningEffort,
|
||||
&cacheTTLOverridden,
|
||||
&createdAt,
|
||||
@@ -2294,6 +2374,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
if mediaType.Valid {
|
||||
log.MediaType = &mediaType.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
@@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchUserUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
s.Require().NotNil(stats[user1.ID])
|
||||
@@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() {
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
@@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() {
|
||||
s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now())
|
||||
s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now())
|
||||
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err, "GetBatchAPIKeyUsageStats")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() {
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{})
|
||||
stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(stats)
|
||||
}
|
||||
|
||||
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
41
backend/internal/repository/usage_log_repo_unit_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSafeDateFormat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
granularity string
|
||||
expected string
|
||||
}{
|
||||
// 合法值
|
||||
{"hour", "hour", "YYYY-MM-DD HH24:00"},
|
||||
{"day", "day", "YYYY-MM-DD"},
|
||||
{"week", "week", "IYYY-IW"},
|
||||
{"month", "month", "YYYY-MM"},
|
||||
|
||||
// 非法值回退到默认
|
||||
{"空字符串", "", "YYYY-MM-DD"},
|
||||
{"未知粒度 year", "year", "YYYY-MM-DD"},
|
||||
{"未知粒度 minute", "minute", "YYYY-MM-DD"},
|
||||
|
||||
// 恶意字符串
|
||||
{"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"},
|
||||
{"带引号", "day'", "YYYY-MM-DD"},
|
||||
{"带括号", "day)", "YYYY-MM-DD"},
|
||||
{"Unicode", "日", "YYYY-MM-DD"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := safeDateFormat(tc.granularity)
|
||||
require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc
|
||||
// ProvideGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub
|
||||
func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient {
|
||||
return NewGitHubReleaseClient(cfg.Update.ProxyURL)
|
||||
return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
|
||||
}
|
||||
|
||||
// ProvidePricingRemoteClient 创建定价数据远程客户端
|
||||
@@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet(
|
||||
NewAPIKeyRepository,
|
||||
NewGroupRepository,
|
||||
NewAccountRepository,
|
||||
NewSoraAccountRepository, // Sora 账号扩展表仓储
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewIdempotencyRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
NewSettingRepository,
|
||||
|
||||
Reference in New Issue
Block a user