feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -50,11 +50,6 @@ type accountRepository struct {
schedulerCache service.SchedulerCache
}
// tempUnschedSnapshot holds the temporarily-unschedulable state read for one
// account: the expiry instant (nil when the account is not flagged) and the
// recorded reason string (empty when absent).
type tempUnschedSnapshot struct {
	until  *time.Time // expiry of the temporary-unschedulable window; nil = not set
	reason string     // machine-readable reason payload; "" when not set
}
// NewAccountRepository 创建账户仓储实例。
// 这是对外暴露的构造函数,返回接口类型以便于依赖注入。
func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository {
@@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
accountIDs = append(accountIDs, acc.ID)
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi
if ags, ok := accountGroupsByAccount[entAcc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[entAcc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outByID[entAcc.ID] = out
}
@@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac
}
}
// syncSchedulerAccountSnapshots refreshes the scheduler cache snapshot for a
// batch of accounts in one repository read. It is best-effort: read or write
// failures are logged and swallowed, never propagated to the caller.
func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) {
	if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 {
		return
	}

	// Drop non-positive IDs and duplicates while preserving first-seen order.
	seen := make(map[int64]struct{}, len(accountIDs))
	ids := make([]int64, 0, len(accountIDs))
	for _, id := range accountIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		ids = append(ids, id)
	}
	if len(ids) == 0 {
		return
	}

	accounts, err := r.GetByIDs(ctx, ids)
	if err != nil {
		logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(ids), err)
		return
	}
	for _, acc := range accounts {
		if acc == nil {
			continue
		}
		// Per-account write failures are logged individually so one bad
		// account does not abort the rest of the batch.
		if err := r.schedulerCache.SetAccount(ctx, acc); err != nil {
			logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", acc.ID, err)
		}
	}
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
shouldSync = true
}
if shouldSync {
for _, id := range ids {
r.syncSchedulerAccountSnapshot(ctx, id)
}
r.syncSchedulerAccountSnapshots(ctx, ids)
}
}
return rows, nil
@@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if err != nil {
return nil, err
}
tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs)
if err != nil {
return nil, err
}
groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs)
if err != nil {
return nil, err
@@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d
if ags, ok := accountGroupsByAccount[acc.ID]; ok {
out.AccountGroups = ags
}
if snap, ok := tempUnschedMap[acc.ID]; ok {
out.TempUnschedulableUntil = snap.until
out.TempUnschedulableReason = snap.reason
}
outAccounts = append(outAccounts, *out)
}
@@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account {
)
}
// loadTempUnschedStates reads the temp-unschedulable columns for the given
// account IDs in one query and returns them keyed by account ID. Accounts
// without a matching row are simply absent from the result map.
func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) {
	snapshots := make(map[int64]tempUnschedSnapshot)
	if len(accountIDs) == 0 {
		return snapshots, nil
	}

	rows, err := r.sql.QueryContext(ctx, `
		SELECT id, temp_unschedulable_until, temp_unschedulable_reason
		FROM accounts
		WHERE id = ANY($1)
	`, pq.Array(accountIDs))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var (
			id     int64
			until  sql.NullTime
			reason sql.NullString
		)
		if err := rows.Scan(&id, &until, &reason); err != nil {
			return nil, err
		}
		var untilPtr *time.Time
		if until.Valid {
			t := until.Time
			untilPtr = &t
		}
		// sql.NullString.String is already "" when the column is NULL, so a
		// single assignment covers both the valid and NULL cases.
		snapshots[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String}
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return snapshots, nil
}
func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) {
proxyMap := make(map[int64]*service.Proxy)
if len(proxyIDs) == 0 {
@@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account {
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
ID: m.ID,
Name: m.Name,
Notes: m.Notes,
Platform: m.Platform,
Type: m.Type,
Credentials: copyJSONMap(m.Credentials),
Extra: copyJSONMap(m.Extra),
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
ExpiresAt: m.ExpiresAt,
AutoPauseOnExpired: m.AutoPauseOnExpired,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
Schedulable: m.Schedulable,
RateLimitedAt: m.RateLimitedAt,
RateLimitResetAt: m.RateLimitResetAt,
OverloadUntil: m.OverloadUntil,
TempUnschedulableUntil: m.TempUnschedulableUntil,
TempUnschedulableReason: derefString(m.TempUnschedulableReason),
SessionWindowStart: m.SessionWindowStart,
SessionWindowEnd: m.SessionWindowEnd,
SessionWindowStatus: derefString(m.SessionWindowStatus),
}
}

View File

@@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() {
s.Require().Nil(got.OverloadUntil)
}
// TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs verifies that the
// temp-unschedulable columns round-trip through both the single-account and
// batch read paths, and that ClearTempUnschedulable resets them.
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
	acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
	acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
	// Truncate to whole seconds so the DB round-trip comparison is stable.
	until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
	reason := `{"rule":"429","matched_keyword":"too many requests"}`
	s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
	// Single read: both fields populated for the flagged account.
	gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
	s.Require().NoError(err)
	s.Require().NotNil(gotByID.TempUnschedulableUntil)
	s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
	s.Require().Equal(reason, gotByID.TempUnschedulableReason)
	// Batch read: result order follows the requested IDs; only acc1 carries state.
	gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
	s.Require().NoError(err)
	s.Require().Len(gotByIDs, 2)
	s.Require().Equal(acc2.ID, gotByIDs[0].ID)
	s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
	s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
	s.Require().Equal(acc1.ID, gotByIDs[1].ID)
	s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
	s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
	s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
	// Clearing must null both fields again.
	s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
	cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
	s.Require().NoError(err)
	s.Require().Nil(cleared.TempUnschedulableUntil)
	s.Require().Equal("", cleared.TempUnschedulableReason)
}
// --- UpdateLastUsed ---
func (s *AccountRepoSuite) TestUpdateLastUsed() {

View File

@@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User {
return nil
}
return &service.User{
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
ID: u.ID,
Email: u.Email,
Username: u.Username,
Notes: u.Notes,
PasswordHash: u.PasswordHash,
Role: u.Role,
Balance: u.Balance,
Concurrency: u.Concurrency,
Status: u.Status,
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
TotpSecretEncrypted: u.TotpSecretEncrypted,
TotpEnabled: u.TotpEnabled,
TotpEnabledAt: u.TotpEnabledAt,
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
}
}
@@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
SoraImagePrice540: g.SoraImagePrice540,
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,

View File

@@ -227,6 +227,43 @@ func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID
return result, nil
}
// GetAccountConcurrencyBatch returns the current live slot count for each of
// the given accounts using a single Redis pipeline round trip. Expired slot
// entries are purged before counting so the result reflects live holders only.
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
	if len(accountIDs) == 0 {
		return map[int64]int{}, nil
	}

	// Use the Redis server clock (not the local one) for expiry math.
	serverNow, err := c.rdb.Time(ctx).Result()
	if err != nil {
		return nil, fmt.Errorf("redis TIME: %w", err)
	}
	cutoff := serverNow.Unix() - int64(c.slotTTLSeconds)

	type pendingCount struct {
		id    int64
		count *redis.IntCmd
	}
	pipe := c.rdb.Pipeline()
	pending := make([]pendingCount, 0, len(accountIDs))
	for _, id := range accountIDs {
		key := accountSlotKeyPrefix + strconv.FormatInt(id, 10)
		// Drop entries older than the TTL cutoff, then count what remains.
		pipe.ZRemRangeByScore(ctx, key, "-inf", strconv.FormatInt(cutoff, 10))
		pending = append(pending, pendingCount{id: id, count: pipe.ZCard(ctx, key)})
	}
	// redis.Nil from individual commands is not a pipeline failure.
	if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
		return nil, fmt.Errorf("pipeline exec: %w", err)
	}

	counts := make(map[int64]int, len(accountIDs))
	for _, p := range pending {
		counts[p.id] = int(p.count.Val())
	}
	return counts, nil
}
// User slot operations
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {

View File

@@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
// TestGatewayCacheSuite wires the testify suite into the standard `go test` runner.
func TestGatewayCacheSuite(t *testing.T) {
	suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -4,6 +4,8 @@ import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
@@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
SetMcpXMLInject(groupIn.MCPXMLInject)
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool,
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
}
// ExistsByIDs reports, for each requested group ID, whether a non-soft-deleted
// group with that ID exists. The result map has one entry per valid requested
// ID (duplicates collapsed, non-positive IDs skipped), shaped as
// map[groupID]exists.
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
	result := make(map[int64]bool, len(ids))
	if len(ids) == 0 {
		return result, nil
	}

	// Normalize the input: dedupe, drop non-positive IDs, and pre-mark every
	// candidate as absent so callers always get an entry per requested ID.
	seen := make(map[int64]struct{}, len(ids))
	unique := make([]int64, 0, len(ids))
	for _, id := range ids {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		unique = append(unique, id)
		result[id] = false
	}
	if len(unique) == 0 {
		return result, nil
	}

	rows, err := r.sql.QueryContext(ctx, `
		SELECT id
		FROM groups
		WHERE id = ANY($1) AND deleted_at IS NULL
	`, pq.Array(unique))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var id int64
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		result[id] = true
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
@@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic
return nil
}
// 使用事务批量更新
tx, err := r.client.Tx(ctx)
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
sortOrderByID := make(map[int64]int, len(updates))
groupIDs := make([]int64, 0, len(updates))
for _, u := range updates {
if u.ID <= 0 {
continue
}
if _, exists := sortOrderByID[u.ID]; !exists {
groupIDs = append(groupIDs, u.ID)
}
sortOrderByID[u.ID] = u.SortOrder
}
if len(groupIDs) == 0 {
return nil
}
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found且不执行更新。
var existingCount int
if err := scanSingleRow(
ctx,
r.sql,
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
[]any{pq.Array(groupIDs)},
&existingCount,
); err != nil {
return err
}
if existingCount != len(groupIDs) {
return service.ErrGroupNotFound
}
args := make([]any, 0, len(groupIDs)*2+1)
caseClauses := make([]string, 0, len(groupIDs))
placeholder := 1
for _, id := range groupIDs {
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
args = append(args, id, sortOrderByID[id])
placeholder += 2
}
args = append(args, pq.Array(groupIDs))
query := fmt.Sprintf(`
UPDATE groups
SET sort_order = CASE id
%s
ELSE sort_order
END
WHERE deleted_at IS NULL AND id = ANY($%d)
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
result, err := r.sql.ExecContext(ctx, query, args...)
if err != nil {
return err
}
defer func() { _ = tx.Rollback() }()
for _, u := range updates {
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
}
if err := tx.Commit(); err != nil {
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected != int64(len(groupIDs)) {
return service.ErrGroupNotFound
}
for _, id := range groupIDs {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
}
}
return nil
}

View File

@@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() {
})
}
// TestUpdateSortOrders_BatchCaseWhen exercises the batched sort-order update:
// several groups updated in one call, with a duplicate ID whose last value
// must win.
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
	g1 := &service.Group{
		Name:             "sort-g1",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	g2 := &service.Group{
		Name:             "sort-g2",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	g3 := &service.Group{
		Name:             "sort-g3",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, g1))
	s.Require().NoError(s.repo.Create(s.ctx, g2))
	s.Require().NoError(s.repo.Create(s.ctx, g3))
	err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
		{ID: g1.ID, SortOrder: 30},
		{ID: g2.ID, SortOrder: 10},
		{ID: g3.ID, SortOrder: 20},
		{ID: g2.ID, SortOrder: 15}, // duplicate ID: the last value must win
	})
	s.Require().NoError(err)
	got1, err := s.repo.GetByID(s.ctx, g1.ID)
	s.Require().NoError(err)
	got2, err := s.repo.GetByID(s.ctx, g2.ID)
	s.Require().NoError(err)
	got3, err := s.repo.GetByID(s.ctx, g3.ID)
	s.Require().NoError(err)
	s.Require().Equal(30, got1.SortOrder)
	s.Require().Equal(15, got2.SortOrder) // last duplicate value, not 10
	s.Require().Equal(20, got3.SortOrder)
}
// TestUpdateSortOrders_MissingGroupNoPartialUpdate ensures that when any
// requested group does not exist the call fails with ErrGroupNotFound and no
// partial sort-order change is applied to the groups that do exist.
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
	g1 := &service.Group{
		Name:             "sort-no-partial",
		Platform:         service.PlatformAnthropic,
		RateMultiplier:   1.0,
		IsExclusive:      false,
		Status:           service.StatusActive,
		SubscriptionType: service.SubscriptionTypeStandard,
	}
	s.Require().NoError(s.repo.Create(s.ctx, g1))
	before, err := s.repo.GetByID(s.ctx, g1.ID)
	s.Require().NoError(err)
	beforeSort := before.SortOrder
	// Mix a valid ID with a clearly nonexistent one.
	err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
		{ID: g1.ID, SortOrder: 99},
		{ID: 99999999, SortOrder: 1},
	})
	s.Require().Error(err)
	s.Require().ErrorIs(err, service.ErrGroupNotFound)
	// The existing group's sort order must be untouched.
	after, err := s.repo.GetByID(s.ctx, g1.ID)
	s.Require().NoError(err)
	s.Require().Equal(beforeSort, after.SortOrder)
}
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
g1 := &service.Group{
Name: "g1",

View File

@@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
require.Nil(t, got.LockedUntil)
}

View File

@@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
const migrationsAdvisoryLockID int64 = 694208311321144027
const migrationsLockRetryInterval = 500 * time.Millisecond
// nonTransactionalMigrationSuffix marks migration files that must run outside
// a transaction (e.g. CREATE/DROP INDEX CONCURRENTLY).
const nonTransactionalMigrationSuffix = "_notx.sql"

// migrationChecksumCompatibilityRule whitelists one historical checksum drift:
// the current on-disk checksum plus the set of legacy DB-recorded checksums
// that are accepted as equivalent.
type migrationChecksumCompatibilityRule struct {
	fileChecksum       string              // checksum of the migration file as it exists today
	acceptedDBChecksum map[string]struct{} // legacy checksums recorded in the DB that are tolerated
}

// migrationChecksumCompatibilityRules exists only to tolerate migration files
// whose checksum was mistakenly modified in the past. A rule must match all of
// (migration name + current file checksum + historical DB checksum) before it
// takes effect, so the global immutability check is never loosened.
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
	"054_drop_legacy_cache_columns.sql": {
		fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		acceptedDBChecksum: map[string]struct{}{
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
		},
	},
}
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
//
@@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if rowErr == nil {
// 迁移已应用,验证校验和是否匹配
if existing != checksum {
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
if isMigrationChecksumCompatible(name, existing, checksum) {
continue
}
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf(
@@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("check migration %s: %w", name, rowErr)
}
// 迁移未应用,在事务中执行。
// 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。
nonTx, err := validateMigrationExecutionMode(name, content)
if err != nil {
return fmt.Errorf("validate migration %s: %w", name, err)
}
if nonTx {
// *_notx.sql用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
statements := splitSQLStatements(content)
for i, stmt := range statements {
trimmed := strings.TrimSpace(stmt)
if trimmed == "" {
continue
}
if stripSQLLineComment(trimmed) == "" {
continue
}
if _, err := db.ExecContext(ctx, trimmed); err != nil {
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
}
}
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
}
continue
}
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin migration %s: %w", name, err)
@@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
return version, version, hash, nil
}
// isMigrationChecksumCompatible reports whether a checksum mismatch for the
// named migration is covered by the historical compatibility whitelist: the
// current file checksum must match the rule exactly and the DB-recorded
// checksum must be one of the accepted legacy values.
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
	rule, found := migrationChecksumCompatibilityRules[name]
	if !found || rule.fileChecksum != fileChecksum {
		return false
	}
	_, accepted := rule.acceptedDBChecksum[dbChecksum]
	return accepted
}
// validateMigrationExecutionMode decides whether a migration must run outside
// a transaction (true for valid *_notx.sql files) and enforces the invariants
// of each mode: regular migrations may not contain CONCURRENTLY, while
// *_notx.sql files may contain only idempotent CREATE/DROP INDEX CONCURRENTLY
// statements and no transaction control.
//
// NOTE(review): the BEGIN/COMMIT/ROLLBACK scan is a plain substring match on
// the upper-cased content, so words like "COMMITTED" would also trip it —
// harmless here since *_notx.sql files only ever hold index statements.
func validateMigrationExecutionMode(name, content string) (bool, error) {
	upper := strings.ToUpper(content)
	isNonTx := strings.HasSuffix(strings.ToLower(strings.TrimSpace(name)), nonTransactionalMigrationSuffix)

	if !isNonTx {
		// Transactional migrations cannot hold CONCURRENTLY statements.
		if strings.Contains(upper, "CONCURRENTLY") {
			return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
		}
		return false, nil
	}

	if strings.Contains(upper, "BEGIN") || strings.Contains(upper, "COMMIT") || strings.Contains(upper, "ROLLBACK") {
		return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
	}

	for _, stmt := range splitSQLStatements(content) {
		normalized := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
		if normalized == "" {
			continue // comment-only or blank fragment
		}
		if !strings.Contains(normalized, "CONCURRENTLY") {
			return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
		}
		isCreateIndex := strings.Contains(normalized, "CREATE") && strings.Contains(normalized, "INDEX")
		isDropIndex := strings.Contains(normalized, "DROP") && strings.Contains(normalized, "INDEX")
		switch {
		case !isCreateIndex && !isDropIndex:
			return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
		case isCreateIndex && !strings.Contains(normalized, "IF NOT EXISTS"):
			return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
		case isDropIndex && !strings.Contains(normalized, "IF EXISTS"):
			return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
		}
	}
	return true, nil
}
// splitSQLStatements splits migration SQL on ";" and discards fragments that
// are whitespace-only; the kept fragments retain their original text (no
// trimming). NOTE(review): this is a naive split — a ";" inside a string
// literal or dollar-quoted body would be treated as a statement boundary. The
// callers only feed it index-migration files, where that cannot occur.
func splitSQLStatements(content string) []string {
	pieces := strings.Split(content, ";")
	statements := make([]string, 0, len(pieces))
	for _, piece := range pieces {
		if len(strings.TrimSpace(piece)) == 0 {
			continue
		}
		statements = append(statements, piece)
	}
	return statements
}
// stripSQLLineComment removes `--` line comments from each line of s and
// returns the remainder with surrounding whitespace trimmed. NOTE(review): a
// `--` inside a string literal is also stripped; acceptable for the index-only
// migration statements this validates.
func stripSQLLineComment(s string) string {
	var b strings.Builder
	for i, line := range strings.Split(s, "\n") {
		if i > 0 {
			b.WriteByte('\n')
		}
		code, _, _ := strings.Cut(line, "--")
		b.WriteString(code)
	}
	return strings.TrimSpace(b.String())
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。

View File

@@ -0,0 +1,36 @@
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestIsMigrationChecksumCompatible pins the 054 compatibility whitelist:
// only the exact (migration name, DB checksum, file checksum) triple from the
// whitelist is accepted; everything else must stay incompatible.
func TestIsMigrationChecksumCompatible(t *testing.T) {
	t.Run("054历史checksum可兼容", func(t *testing.T) {
		// Whitelisted triple: legacy DB checksum + current file checksum.
		ok := isMigrationChecksumCompatible(
			"054_drop_legacy_cache_columns.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		)
		require.True(t, ok)
	})
	t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
		// Same DB checksum but a different file checksum must be rejected.
		ok := isMigrationChecksumCompatible(
			"054_drop_legacy_cache_columns.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"0000000000000000000000000000000000000000000000000000000000000000",
		)
		require.False(t, ok)
	})
	t.Run("非白名单迁移不兼容", func(t *testing.T) {
		// A migration that has no whitelist entry is never compatible.
		ok := isMigrationChecksumCompatible(
			"001_init.sql",
			"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
			"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
		)
		require.False(t, ok)
	})
}

View File

@@ -0,0 +1,368 @@
package repository
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"io/fs"
"strings"
"testing"
"testing/fstest"
"time"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestApplyMigrations_NilDB ensures a nil *sql.DB is rejected up front with a
// descriptive error instead of panicking later.
func TestApplyMigrations_NilDB(t *testing.T) {
	err := ApplyMigrations(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "nil sql db")
}
// TestApplyMigrations_DelegatesToApplyMigrationsFS confirms ApplyMigrations
// reaches the shared migration path: failing the advisory-lock query (the
// first thing that path does) surfaces the expected lock-acquisition error.
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnError(errors.New("lock failed"))
	err = ApplyMigrations(context.Background(), db)
	require.Error(t, err)
	require.Contains(t, err.Error(), "acquire migrations lock")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestLatestMigrationBaseline exercises baseline derivation from the
// migrations filesystem: the empty-FS fallback, picking the lexically last
// .sql file, and surfacing read errors.
func TestLatestMigrationBaseline(t *testing.T) {
	t.Run("empty_fs_returns_baseline", func(t *testing.T) {
		// No migration files: the synthetic "baseline" marker with empty hash.
		version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
		require.NoError(t, err)
		require.Equal(t, "baseline", version)
		require.Equal(t, "baseline", description)
		require.Equal(t, "", hash)
	})
	t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"010_final.sql": &fstest.MapFile{
				Data: []byte("CREATE TABLE t2(id int);"),
			},
		}
		version, description, hash, err := latestMigrationBaseline(fsys)
		require.NoError(t, err)
		require.Equal(t, "010_final", version)
		require.Equal(t, "010_final", description)
		// 64 hex chars — sha256-sized digest.
		require.Len(t, hash, 64)
	})
	t.Run("read_file_error", func(t *testing.T) {
		// A directory entry makes the file read fail, surfacing the error.
		fsys := fstest.MapFS{
			"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
		}
		_, _, _, err := latestMigrationBaseline(fsys)
		require.Error(t, err)
	})
}
// TestIsMigrationChecksumCompatible_AdditionalCases walks the whitelist map
// itself so the assertions stay valid even if the rule entries change.
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
	require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
	var (
		name string
		rule migrationChecksumCompatibilityRule
	)
	// Grab an arbitrary rule from the whitelist (map iteration order is random).
	for n, r := range migrationChecksumCompatibilityRules {
		name = n
		rule = r
		break
	}
	require.NotEmpty(t, name)
	// Wrong file checksum, or right file checksum with an unaccepted DB
	// checksum, must both be rejected.
	require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
	require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
	var accepted string
	for checksum := range rule.acceptedDBChecksum {
		accepted = checksum
		break
	}
	require.NotEmpty(t, accepted)
	// The full whitelisted triple is the only accepted combination.
	require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
}
// TestEnsureAtlasBaselineAligned drives ensureAtlasBaselineAligned with a
// mocked DB through its happy path and each error branch. The sqlmock
// expectations are order-sensitive: they mirror the exact query sequence the
// function issues.
func TestEnsureAtlasBaselineAligned(t *testing.T) {
	t.Run("skip_when_no_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// No legacy schema_migrations table: nothing to align, early return.
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// Legacy table exists but atlas table doesn't: it gets created, found
		// empty, and seeded with a baseline row for the latest migration.
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnResult(sqlmock.NewResult(0, 0))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		// Baseline row uses the lexically last migration (002_next).
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
			WillReturnResult(sqlmock.NewResult(1, 1))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
			"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.NoError(t, err)
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_checking_legacy_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnError(errors.New("exists failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "check schema_migrations")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnError(errors.New("count failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "count atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_creating_atlas_table", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
		mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
			WillReturnError(errors.New("create failed"))
		err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
		require.Error(t, err)
		require.Contains(t, err.Error(), "create atlas_schema_revisions")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("error_when_inserting_baseline", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("schema_migrations").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT EXISTS \\(").
			WithArgs("atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
		mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
			WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
		mock.ExpectExec("INSERT INTO atlas_schema_revisions").
			WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
			WillReturnError(errors.New("insert failed"))
		fsys := fstest.MapFS{
			"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
		}
		err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
		require.Error(t, err)
		require.Contains(t, err.Error(), "insert atlas baseline")
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// TestApplyMigrationsFS_ChecksumMismatchRejected: an already-applied migration
// whose stored checksum differs from the file's must abort with a checksum
// error, and the advisory lock must still be released on that failure path.
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_init.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
	// Unlock is still expected even though the run fails.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.Error(t, err)
	require.Contains(t, err.Error(), "checksum mismatch")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_CheckMigrationQueryError verifies that when the
// applied-checksum lookup itself fails, the error surfaces wrapped with the
// migration filename and the advisory lock is released.
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
	migrations := fstest.MapFS{
		"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
	}
	conn, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_err.sql").
		WillReturnError(errors.New("query failed"))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	err = applyMigrationsFS(context.Background(), conn, migrations)
	require.Error(t, err)
	require.Contains(t, err.Error(), "check migration 001_err.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied covers two skip paths:
// a whitespace-only migration file is ignored entirely (no DB interaction
// is expected for it), and a migration whose recorded checksum matches the
// current file content is treated as already applied and not re-executed.
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	alreadySQL := "CREATE TABLE t(id int);"
	// Stored checksum matches the file content => already applied.
	checksum := migrationChecksum(alreadySQL)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_already.sql").
		WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
	// No checksum lookup is registered for 000_empty.sql: blank files must be
	// skipped before any database call. Only the lock release follows.
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"000_empty.sql":   &fstest.MapFile{Data: []byte(" \n\t ")},
		"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_ReadMigrationError verifies that an unreadable
// migration entry (a directory posing as a .sql file) produces a
// "read migration" error after bootstrap, with the lock still released.
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
	migrations := fstest.MapFS{
		"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
	}
	conn, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = conn.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	err = applyMigrationsFS(context.Background(), conn, migrations)
	require.Error(t, err)
	require.Contains(t, err.Error(), "read migration 001_bad.sql")
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestPgAdvisoryLockAndUnlock_ErrorBranches exercises the failure and retry
// behavior of the advisory-lock helpers: cancellation while contended,
// unlock query failure, and a successful acquisition on the second attempt.
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
	t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// The lock is never granted; the short context deadline forces the
		// helper to give up while still waiting.
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
		defer cancel()
		err = pgAdvisoryLock(ctx, db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "acquire migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("unlock_exec_error", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// The unlock statement errors; the helper must wrap it.
		mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnError(errors.New("unlock failed"))
		err = pgAdvisoryUnlock(context.Background(), db)
		require.Error(t, err)
		require.Contains(t, err.Error(), "release migrations lock")
		require.NoError(t, mock.ExpectationsWereMet())
	})
	t.Run("acquire_lock_after_retry", func(t *testing.T) {
		db, mock, err := sqlmock.New()
		require.NoError(t, err)
		defer func() { _ = db.Close() }()
		// First attempt contended, second succeeds.
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
		mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
			WithArgs(migrationsAdvisoryLockID).
			WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
		ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
		defer cancel()
		start := time.Now()
		err = pgAdvisoryLock(ctx, db)
		require.NoError(t, err)
		// At least one retry interval must have elapsed between attempts.
		require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
		require.NoError(t, mock.ExpectationsWereMet())
	})
}
// migrationChecksum returns the hex-encoded SHA-256 digest of the migration
// content with surrounding whitespace stripped, matching how the runner
// fingerprints applied migrations in schema_migrations.
func migrationChecksum(content string) string {
	trimmed := strings.TrimSpace(content)
	digest := sha256.Sum256([]byte(trimmed))
	return hex.EncodeToString(digest[:])
}

View File

@@ -0,0 +1,164 @@
package repository
import (
"context"
"database/sql"
"testing"
"testing/fstest"
sqlmock "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/require"
)
// TestValidateMigrationExecutionMode feeds one migration file per case
// through the execution-mode validator and checks both the returned
// non-transactional flag and whether validation errored.
func TestValidateMigrationExecutionMode(t *testing.T) {
	cases := []struct {
		name      string
		filename  string
		content   string
		wantNonTx bool
		wantErr   bool
	}{
		{
			name:     "事务迁移包含CONCURRENTLY会被拒绝",
			filename: "001_add_idx.sql",
			content:  "CREATE INDEX CONCURRENTLY idx_a ON t(a);",
			wantErr:  true,
		},
		{
			name:     "notx迁移要求CREATE使用IF NOT EXISTS",
			filename: "001_add_idx_notx.sql",
			content:  "CREATE INDEX CONCURRENTLY idx_a ON t(a);",
			wantErr:  true,
		},
		{
			name:     "notx迁移要求DROP使用IF EXISTS",
			filename: "001_drop_idx_notx.sql",
			content:  "DROP INDEX CONCURRENTLY idx_a;",
			wantErr:  true,
		},
		{
			name:     "notx迁移禁止事务控制语句",
			filename: "001_add_idx_notx.sql",
			content:  "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;",
			wantErr:  true,
		},
		{
			name:     "notx迁移禁止混用非CONCURRENTLY语句",
			filename: "001_add_idx_notx.sql",
			content:  "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;",
			wantErr:  true,
		},
		{
			name:     "notx迁移允许幂等并发索引语句",
			filename: "001_add_idx_notx.sql",
			content: `
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
`,
			wantNonTx: true,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			nonTx, err := validateMigrationExecutionMode(tc.filename, tc.content)
			require.Equal(t, tc.wantNonTx, nonTx)
			if tc.wantErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
		})
	}
}
// TestApplyMigrationsFS_NonTransactionalMigration verifies that a _notx
// migration runs its statement directly on the connection (no Begin/Commit
// is expected) and still records the applied row in schema_migrations.
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	// ErrNoRows => migration not yet applied.
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Executed outside a transaction, as required by CONCURRENTLY.
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_idx_notx.sql": &fstest.MapFile{
			Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements verifies
// that a _notx migration containing several statements (and comments) is
// executed as separate Execs, one per statement, outside a transaction.
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_multi_idx_notx.sql").
		WillReturnError(sql.ErrNoRows)
	// Each statement is issued independently, in file order.
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_multi_idx_notx.sql": &fstest.MapFile{
			Data: []byte(`
-- first
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
-- second
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
`),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestApplyMigrationsFS_TransactionalMigration verifies the default path:
// a regular migration runs inside a transaction (Begin ... Commit), with
// the schema_migrations bookkeeping insert included in that transaction.
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer func() { _ = db.Close() }()
	prepareMigrationsBootstrapExpectations(mock)
	mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
		WithArgs("001_add_col.sql").
		WillReturnError(sql.ErrNoRows)
	// The migration body and the bookkeeping row share one transaction.
	mock.ExpectBegin()
	mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
		WithArgs("001_add_col.sql", sqlmock.AnyArg()).
		WillReturnResult(sqlmock.NewResult(1, 1))
	mock.ExpectCommit()
	mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnResult(sqlmock.NewResult(0, 1))
	fsys := fstest.MapFS{
		"001_add_col.sql": &fstest.MapFile{
			Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
		},
	}
	err = applyMigrationsFS(context.Background(), db, fsys)
	require.NoError(t, err)
	require.NoError(t, mock.ExpectationsWereMet())
}
// prepareMigrationsBootstrapExpectations registers the ordered mock
// expectations common to every applyMigrationsFS test: advisory lock
// acquired on the first try, schema_migrations ensured, and the Atlas
// baseline already aligned (both tables exist, revisions table non-empty).
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
	mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
		WithArgs(migrationsAdvisoryLockID).
		WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
	mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
		WillReturnResult(sqlmock.NewResult(0, 0))
	mock.ExpectQuery("SELECT EXISTS \\(").
		WithArgs("schema_migrations").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	mock.ExpectQuery("SELECT EXISTS \\(").
		WithArgs("atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
	// A non-zero revision count means no baseline insert is needed.
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
}

View File

@@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
// usage_logs: billing_type used by filters/stats
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
// settings table should exist
var settingsRegclass sql.NullString

View File

@@ -22,16 +22,20 @@ type openaiOAuthService struct {
tokenURL string
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) {
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
}
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", openai.ClientID)
formData.Set("client_id", clientID)
formData.Set("code", code)
formData.Set("redirect_uri", redirectURI)
formData.Set("code_verifier", codeVerifier)
@@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
}
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
if strings.TrimSpace(clientID) != "" {
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID))
// 调用方应始终传入正确的 client_id为兼容旧数据未指定时默认使用 OpenAI ClientID
clientID = strings.TrimSpace(clientID)
if clientID == "" {
clientID = openai.ClientID
}
clientIDs := []string{
openai.ClientID,
openai.SoraClientID,
}
seen := make(map[string]struct{}, len(clientIDs))
var lastErr error
for _, clientID := range clientIDs {
clientID = strings.TrimSpace(clientID)
if clientID == "" {
continue
}
if _, ok := seen[clientID]; ok {
continue
}
seen[clientID] = struct{}{}
tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
if err == nil {
return tokenResp, nil
}
lastErr = err
}
if lastErr != nil {
return nil, lastErr
}
return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed")
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {

View File

@@ -81,7 +81,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "")
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
require.Equal(s.T(), "rt2", resp.RefreshToken)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID
// 且只发送一次请求(不再盲猜多个 client_id
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
@@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.ClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
require.Equal(s.T(), "at", resp.AccessToken)
// 只发送了一次请求,使用默认的 OpenAI ClientID
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
}
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
var seenClientIDs []string
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := r.ParseForm(); err != nil {
w.WriteHeader(http.StatusBadRequest)
_, _ = io.WriteString(w, "invalid_grant")
return
}
clientID := r.PostForm.Get("client_id")
seenClientIDs = append(seenClientIDs, clientID)
if clientID == openai.SoraClientID {
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
@@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() {
w.WriteHeader(http.StatusBadRequest)
}))
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
require.NoError(s.T(), err, "RefreshToken")
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
require.NoError(s.T(), err, "RefreshTokenWithClientID")
require.Equal(s.T(), "at-sora", resp.AccessToken)
require.Equal(s.T(), "rt-sora", resp.RefreshToken)
require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs)
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
}
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
@@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
_, _ = io.WriteString(w, "bad")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "status 400")
require.ErrorContains(s.T(), err, "bad")
@@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
s.srv.Close()
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err)
require.ErrorContains(s.T(), err, "request failed")
}
@@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
done := make(chan error, 1)
go func() {
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
done <- err
}()
@@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
require.Fail(s.T(), msg)
default:
}
}
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
wantClientID := openai.SoraClientID
errCh := make(chan string, 1)
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_ = r.ParseForm()
if got := r.PostForm.Get("client_id"); got != wantClientID {
errCh <- "client_id mismatch"
w.WriteHeader(http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
require.NoError(s.T(), err, "ExchangeCode")
select {
case msg := <-errCh:
@@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
}))
s.svc.tokenURL = s.srv.URL + "?x=1"
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.NoError(s.T(), err, "ExchangeCode")
select {
case <-s.received:
@@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
_, _ = io.WriteString(w, "not-valid-json")
}))
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "")
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
require.Error(s.T(), err, "expected error for invalid JSON response")
}

View File

@@ -12,6 +12,11 @@ import (
"github.com/Wei-Shaw/sub2api/internal/service"
)
const (
opsRawLatencyQueryTimeout = 2 * time.Second
opsRawPeakQueryTimeout = 1500 * time.Millisecond
)
func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
@@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic
func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) {
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
degraded := false
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
@@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f
sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA))
errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA))
upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA))
degraded := false
// Keep "current" rates as raw, to preserve realtime semantics.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
// NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont
// and keeps semantics consistent across modes.
qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end)
peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout)
qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end)
cancelPeak()
if err != nil {
return nil, err
}
tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end)
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
degraded = true
} else {
return nil, err
}
}
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
if degraded {
if qpsCurrent <= 0 {
qpsCurrent = qpsAvg
}
if tpsCurrent <= 0 {
tpsCurrent = tpsAvg
}
if qpsPeak <= 0 {
qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg))
}
if tpsPeak <= 0 {
tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg))
}
}
return &service.OpsDashboardOverview{
StartTime: start,
@@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops
return nil, err
}
duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end)
latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout)
duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end)
cancelLatency()
if err != nil {
return nil, err
if isQueryTimeoutErr(err) {
duration = service.OpsPercentiles{}
ttft = service.OpsPercentiles{}
} else {
return nil, err
}
}
errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end)
@@ -735,68 +795,56 @@ FROM usage_logs ul
}
func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) {
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99,
AVG(duration_ms) AS avg_ms,
MAX(duration_ms) AS max_ms
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg,
MAX(duration_ms) AS duration_max,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg,
MAX(first_token_ms) AS ttft_max
FROM usage_logs ul
` + join + `
` + where + `
AND duration_ms IS NOT NULL`
` + where
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
duration.P50 = floatToIntPtr(p50)
duration.P90 = floatToIntPtr(p90)
duration.P95 = floatToIntPtr(p95)
duration.P99 = floatToIntPtr(p99)
duration.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
duration.Max = &v
}
var dP50, dP90, dP95, dP99 sql.NullFloat64
var dAvg sql.NullFloat64
var dMax sql.NullInt64
var tP50, tP90, tP95, tP99 sql.NullFloat64
var tAvg sql.NullFloat64
var tMax sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&dP50, &dP90, &dP95, &dP99, &dAvg, &dMax,
&tP50, &tP90, &tP95, &tP99, &tAvg, &tMax,
); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
{
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99,
AVG(first_token_ms) AS avg_ms,
MAX(first_token_ms) AS max_ms
FROM usage_logs ul
` + join + `
` + where + `
AND first_token_ms IS NOT NULL`
duration.P50 = floatToIntPtr(dP50)
duration.P90 = floatToIntPtr(dP90)
duration.P95 = floatToIntPtr(dP95)
duration.P99 = floatToIntPtr(dP99)
duration.Avg = floatToIntPtr(dAvg)
if dMax.Valid {
v := int(dMax.Int64)
duration.Max = &v
}
var p50, p90, p95, p99 sql.NullFloat64
var avg sql.NullFloat64
var max sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil {
return service.OpsPercentiles{}, service.OpsPercentiles{}, err
}
ttft.P50 = floatToIntPtr(p50)
ttft.P90 = floatToIntPtr(p90)
ttft.P95 = floatToIntPtr(p95)
ttft.P99 = floatToIntPtr(p99)
ttft.Avg = floatToIntPtr(avg)
if max.Valid {
v := int(max.Int64)
ttft.Max = &v
}
ttft.P50 = floatToIntPtr(tP50)
ttft.P90 = floatToIntPtr(tP90)
ttft.P95 = floatToIntPtr(tP95)
ttft.P99 = floatToIntPtr(tP99)
ttft.Avg = floatToIntPtr(tAvg)
if tMax.Valid {
v := int(tMax.Int64)
ttft.Max = &v
}
return duration, ttft, nil
@@ -854,20 +902,23 @@ func (r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O
return qpsCurrent, tpsCurrent, nil
}
func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) {
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COUNT(*) AS req_cnt,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt
SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
@@ -875,47 +926,33 @@ error_buckets AS (
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total
COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req,
COALESCE(u.token_cnt, 0) AS total_tokens
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT COALESCE(MAX(total), 0) FROM combined`
SELECT
COALESCE(MAX(total_req), 0) AS max_req_per_min,
COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...)
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
var maxReqPerMinute, maxTokensPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil {
return 0, 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 {
qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0)
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 {
tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0)
}
return qpsPeak, tpsPeak, nil
}
func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) {
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
q := `
SELECT COALESCE(MAX(tokens_per_min), 0)
FROM (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min
FROM usage_logs ul
` + join + `
` + where + `
GROUP BY 1
) t`
var maxPerMinute sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil {
return 0, err
}
if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 {
return 0, nil
}
return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil
func isQueryTimeoutErr(err error) bool {
return errors.Is(err, context.DeadlineExceeded)
}
func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) {

View File

@@ -0,0 +1,22 @@
package repository
import (
"context"
"fmt"
"testing"
)
// TestIsQueryTimeoutErr covers both direct and %w-wrapped variants of the
// two context sentinel errors: only DeadlineExceeded counts as a timeout.
func TestIsQueryTimeoutErr(t *testing.T) {
	wrap := func(err error) error { return fmt.Errorf("wrapped: %w", err) }
	if got := isQueryTimeoutErr(context.DeadlineExceeded); !got {
		t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
	}
	if got := isQueryTimeoutErr(wrap(context.DeadlineExceeded)); !got {
		t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
	}
	if got := isQueryTimeoutErr(context.Canceled); got {
		t.Fatalf("context.Canceled should not be treated as query timeout")
	}
	if got := isQueryTimeoutErr(wrap(context.Canceled)); got {
		t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
	}
}

View File

@@ -0,0 +1,419 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
// soraGenerationRepository implements the service.SoraGenerationRepository
// interface using raw SQL against the sora_generations table.
type soraGenerationRepository struct {
	sql *sql.DB
}
// NewSoraGenerationRepository creates a Sora generation-record repository
// backed by the given database handle.
func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository {
	return &soraGenerationRepository{sql: sqlDB}
}
// Create inserts a new generation row and populates gen.ID and
// gen.CreatedAt from the database in a single round trip via RETURNING.
//
// Marshal errors are deliberately ignored: MediaURLs and S3ObjectKeys
// appear to be plain slices that always marshal cleanly — TODO confirm
// against the service.SoraGeneration definition.
func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error {
	mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
	s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
	err := r.sql.QueryRowContext(ctx, `
		INSERT INTO sora_generations (
			user_id, api_key_id, model, prompt, media_type,
			status, media_url, media_urls, file_size_bytes,
			storage_type, s3_object_keys, upstream_task_id, error_message
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
		RETURNING id, created_at
	`,
		gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
		gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
		gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
	).Scan(&gen.ID, &gen.CreatedAt)
	return err
}
// CreatePendingWithLimit performs the "active-count check + insert" inside
// a single transaction, so the per-user concurrency cap cannot be raced by
// parallel count-then-create sequences.
//
// Behavior:
//   - maxActive <= 0 disables the cap and falls back to a plain Create.
//   - activeStatuses defaults to pending+generating when empty.
//   - Returns service.ErrSoraGenerationConcurrencyLimit when the user
//     already has maxActive active generations.
func (r *soraGenerationRepository) CreatePendingWithLimit(
	ctx context.Context,
	gen *service.SoraGeneration,
	activeStatuses []string,
	maxActive int64,
) error {
	if gen == nil {
		return fmt.Errorf("generation is nil")
	}
	if maxActive <= 0 {
		return r.Create(ctx, gen)
	}
	if len(activeStatuses) == 0 {
		activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating}
	}
	tx, err := r.sql.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// Rollback is a harmless no-op after a successful Commit.
	defer func() { _ = tx.Rollback() }()
	// Serialize concurrent creates per user with a transaction-scoped
	// advisory lock so the cap check below cannot be raced.
	// NOTE(review): the lock key is the raw user id — confirm no other
	// advisory-lock user shares this key space.
	if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil {
		return err
	}
	// $1 is the user id; statuses occupy $2..$N.
	placeholders := make([]string, len(activeStatuses))
	args := make([]any, 0, 1+len(activeStatuses))
	args = append(args, gen.UserID)
	for i, s := range activeStatuses {
		placeholders[i] = fmt.Sprintf("$%d", i+2)
		args = append(args, s)
	}
	countQuery := fmt.Sprintf(
		`SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`,
		strings.Join(placeholders, ","),
	)
	var activeCount int64
	if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil {
		return err
	}
	if activeCount >= maxActive {
		return service.ErrSoraGenerationConcurrencyLimit
	}
	// Same insert as Create, but executed on the transaction.
	mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
	s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
	if err := tx.QueryRowContext(ctx, `
		INSERT INTO sora_generations (
			user_id, api_key_id, model, prompt, media_type,
			status, media_url, media_urls, file_size_bytes,
			storage_type, s3_object_keys, upstream_task_id, error_message
		) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
		RETURNING id, created_at
	`,
		gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType,
		gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
		gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage,
	).Scan(&gen.ID, &gen.CreatedAt); err != nil {
		return err
	}
	return tx.Commit()
}
// GetByID loads one generation row by primary key.
// A missing id yields an error carrying a user-facing (Chinese) message;
// callers appear to surface it directly rather than matching a sentinel.
func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) {
	gen := &service.SoraGeneration{}
	var mediaURLsJSON, s3KeysJSON []byte
	// Nullable columns are scanned into sql.Null* and mapped to pointers below.
	var completedAt sql.NullTime
	var apiKeyID sql.NullInt64
	err := r.sql.QueryRowContext(ctx, `
		SELECT id, user_id, api_key_id, model, prompt, media_type,
			status, media_url, media_urls, file_size_bytes,
			storage_type, s3_object_keys, upstream_task_id, error_message,
			created_at, completed_at
		FROM sora_generations WHERE id = $1
	`, id).Scan(
		&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
		&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
		&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
		&gen.CreatedAt, &completedAt,
	)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, fmt.Errorf("生成记录不存在")
		}
		return nil, err
	}
	if apiKeyID.Valid {
		gen.APIKeyID = &apiKeyID.Int64
	}
	if completedAt.Valid {
		gen.CompletedAt = &completedAt.Time
	}
	// Unmarshal errors ignored: presumably legacy rows may hold empty or
	// invalid JSON, in which case the slices stay zero-valued.
	_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
	_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
	return gen, nil
}
// Update persists all mutable fields of gen, keyed by gen.ID. A nil
// CompletedAt writes SQL NULL. Updating a missing id is not an error
// (zero rows affected).
//
// Fix: the original copied gen.CompletedAt into a local *time.Time behind
// an unnecessary nil guard; the pointer is now passed straight through —
// identical behavior, less indirection.
func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error {
	// Marshal errors ignored, matching Create: these fields are expected
	// to be JSON-marshalable slices.
	mediaURLsJSON, _ := json.Marshal(gen.MediaURLs)
	s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys)
	_, err := r.sql.ExecContext(ctx, `
		UPDATE sora_generations SET
			status = $2, media_url = $3, media_urls = $4, file_size_bytes = $5,
			storage_type = $6, s3_object_keys = $7, upstream_task_id = $8,
			error_message = $9, completed_at = $10
		WHERE id = $1
	`,
		gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes,
		gen.StorageType, s3KeysJSON, gen.UpstreamTaskID,
		gen.ErrorMessage, gen.CompletedAt,
	)
	return err
}
// UpdateGeneratingIfPending transitions the record to the generating state
// and stores the upstream task id, but only while the record is still
// pending. It reports whether a row was actually updated, which lets the
// caller detect a lost race against another state transition.
func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) {
	res, execErr := r.sql.ExecContext(ctx, `
		UPDATE sora_generations
		SET status = $2, upstream_task_id = $3
		WHERE id = $1 AND status = $4
	`,
		id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending,
	)
	if execErr != nil {
		return false, execErr
	}
	n, raErr := res.RowsAffected()
	if raErr != nil {
		return false, raErr
	}
	return n > 0, nil
}
// UpdateCompletedIfActive transitions a record to completed, but only when
// its current status is pending or generating; it also records the media
// outputs, storage metadata, and completion time, and clears any previous
// error message. Returns true only when a row was actually updated, so the
// caller can detect a lost race against another terminal transition.
func (r *soraGenerationRepository) UpdateCompletedIfActive(
	ctx context.Context,
	id int64,
	mediaURL string,
	mediaURLs []string,
	storageType string,
	s3Keys []string,
	fileSizeBytes int64,
	completedAt time.Time,
) (bool, error) {
	// Marshal errors ignored, matching Create: plain string slices.
	mediaURLsJSON, _ := json.Marshal(mediaURLs)
	s3KeysJSON, _ := json.Marshal(s3Keys)
	result, err := r.sql.ExecContext(ctx, `
		UPDATE sora_generations
		SET status = $2,
			media_url = $3,
			media_urls = $4,
			file_size_bytes = $5,
			storage_type = $6,
			s3_object_keys = $7,
			error_message = '',
			completed_at = $8
		WHERE id = $1 AND status IN ($9, $10)
	`,
		id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes,
		storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
	)
	if err != nil {
		return false, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	return affected > 0, nil
}
// UpdateFailedIfActive transitions a record to failed with the given error
// message and completion time, but only when its current status is pending
// or generating. Returns true only when a row was actually updated.
func (r *soraGenerationRepository) UpdateFailedIfActive(
	ctx context.Context,
	id int64,
	errMsg string,
	completedAt time.Time,
) (bool, error) {
	result, err := r.sql.ExecContext(ctx, `
		UPDATE sora_generations
		SET status = $2,
			error_message = $3,
			completed_at = $4
		WHERE id = $1 AND status IN ($5, $6)
	`,
		id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
	)
	if err != nil {
		return false, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	return affected > 0, nil
}
// UpdateCancelledIfActive transitions a record to cancelled with the given
// completion time, but only while its status is still pending or
// generating. It reports whether a row was actually updated.
func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) {
	res, execErr := r.sql.ExecContext(ctx, `
		UPDATE sora_generations
		SET status = $2, completed_at = $3
		WHERE id = $1 AND status IN ($4, $5)
	`,
		id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating,
	)
	if execErr != nil {
		return false, execErr
	}
	n, raErr := res.RowsAffected()
	if raErr != nil {
		return false, raErr
	}
	return n > 0, nil
}
// UpdateStorageIfCompleted rewrites the storage metadata (URLs, keys, size,
// storage type) of a record that is already completed — used for manual
// saves — and deliberately leaves completed_at untouched. Returns true only
// when a row was actually updated (i.e. the record exists and is completed).
func (r *soraGenerationRepository) UpdateStorageIfCompleted(
	ctx context.Context,
	id int64,
	mediaURL string,
	mediaURLs []string,
	storageType string,
	s3Keys []string,
	fileSizeBytes int64,
) (bool, error) {
	// Marshal errors ignored, matching Create: plain string slices.
	mediaURLsJSON, _ := json.Marshal(mediaURLs)
	s3KeysJSON, _ := json.Marshal(s3Keys)
	result, err := r.sql.ExecContext(ctx, `
		UPDATE sora_generations
		SET media_url = $2,
			media_urls = $3,
			file_size_bytes = $4,
			storage_type = $5,
			s3_object_keys = $6
		WHERE id = $1 AND status = $7
	`,
		id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted,
	)
	if err != nil {
		return false, err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return false, err
	}
	return affected > 0, nil
}
// Delete removes the generation row with the given id. Deleting a missing
// id is not an error: the statement simply affects zero rows.
func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error {
	if _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id); err != nil {
		return err
	}
	return nil
}
// List returns one page of the user's generation records (newest first)
// together with the total row count for the same filter.
//
// Status and StorageType accept comma-separated value lists; MediaType is a
// single exact match. Page is assumed to be >= 1 — TODO confirm callers
// validate pagination before reaching the repository, otherwise OFFSET can
// go negative.
func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) {
	// Build WHERE conditions; $1 is always the user id, later placeholders
	// are numbered by argIdx as filters are appended.
	conditions := []string{"user_id = $1"}
	args := []any{params.UserID}
	argIdx := 2
	if params.Status != "" {
		// Comma-separated multi-status filter; values are trimmed.
		statuses := strings.Split(params.Status, ",")
		placeholders := make([]string, len(statuses))
		for i, s := range statuses {
			placeholders[i] = fmt.Sprintf("$%d", argIdx)
			args = append(args, strings.TrimSpace(s))
			argIdx++
		}
		conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ",")))
	}
	if params.StorageType != "" {
		// Same comma-separated treatment for storage types.
		storageTypes := strings.Split(params.StorageType, ",")
		placeholders := make([]string, len(storageTypes))
		for i, s := range storageTypes {
			placeholders[i] = fmt.Sprintf("$%d", argIdx)
			args = append(args, strings.TrimSpace(s))
			argIdx++
		}
		conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ",")))
	}
	if params.MediaType != "" {
		conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx))
		args = append(args, params.MediaType)
		argIdx++
	}
	whereClause := "WHERE " + strings.Join(conditions, " AND ")
	// Total count over the same filter (runs before LIMIT/OFFSET args are
	// appended, so args still matches the placeholders used so far).
	var total int64
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause)
	if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, err
	}
	// Paged query; LIMIT/OFFSET take the next two placeholder slots.
	offset := (params.Page - 1) * params.PageSize
	listQuery := fmt.Sprintf(`
		SELECT id, user_id, api_key_id, model, prompt, media_type,
			status, media_url, media_urls, file_size_bytes,
			storage_type, s3_object_keys, upstream_task_id, error_message,
			created_at, completed_at
		FROM sora_generations %s
		ORDER BY created_at DESC
		LIMIT $%d OFFSET $%d
	`, whereClause, argIdx, argIdx+1)
	args = append(args, params.PageSize, offset)
	rows, err := r.sql.QueryContext(ctx, listQuery, args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() {
		_ = rows.Close()
	}()
	var results []*service.SoraGeneration
	for rows.Next() {
		gen := &service.SoraGeneration{}
		var mediaURLsJSON, s3KeysJSON []byte
		// Nullable columns scanned into sql.Null* and mapped to pointers.
		var completedAt sql.NullTime
		var apiKeyID sql.NullInt64
		if err := rows.Scan(
			&gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType,
			&gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes,
			&gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage,
			&gen.CreatedAt, &completedAt,
		); err != nil {
			return nil, 0, err
		}
		if apiKeyID.Valid {
			gen.APIKeyID = &apiKeyID.Int64
		}
		if completedAt.Valid {
			gen.CompletedAt = &completedAt.Time
		}
		// Unmarshal errors ignored, same lenient handling as GetByID.
		_ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs)
		_ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys)
		results = append(results, gen)
	}
	return results, total, rows.Err()
}
// CountByUserAndStatus counts the user's generation rows whose status is in
// statuses. An empty status list short-circuits to zero without querying.
func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) {
	if len(statuses) == 0 {
		return 0, nil
	}
	// $1 is the user id; each status claims the next placeholder slot.
	args := make([]any, 0, len(statuses)+1)
	args = append(args, userID)
	holders := make([]string, 0, len(statuses))
	for _, status := range statuses {
		holders = append(holders, fmt.Sprintf("$%d", len(args)+1))
		args = append(args, status)
	}
	query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(holders, ","))
	var total int64
	if err := r.sql.QueryRowContext(ctx, query, args...).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}

View File

@@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
idx++
}
}
if filters.Stream != nil {
if filters.RequestType != nil {
condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType)
conditions = append(conditions, condition)
args = append(args, conditionArgs...)
idx += len(conditionArgs)
} else if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
args = append(args, *filters.Stream)
idx++

View File

@@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) {
require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
}
// TestBuildUsageCleanupWhereRequestTypePriority verifies that when both
// RequestType and Stream filters are supplied, RequestType takes priority
// and Stream is dropped from the generated WHERE clause. For WS v2 the
// clause also falls back to the legacy openai_ws_mode flag so rows written
// before request_type existed are still matched.
func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) {
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	requestType := int16(service.RequestTypeWSV2)
	stream := false
	where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
		StartTime:   start,
		EndTime:     end,
		RequestType: &requestType,
		Stream:      &stream,
	})
	require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where)
	require.Equal(t, []any{start, end, requestType}, args)
}
// TestBuildUsageCleanupWhereRequestTypeLegacyFallback verifies the legacy
// fallback for the stream request type: the WHERE clause matches rows with
// the new request_type value OR old rows (request_type = 0) whose legacy
// stream/openai_ws_mode flags identify them as streaming requests.
func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) {
	start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	requestType := int16(service.RequestTypeStream)
	where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
		StartTime:   start,
		EndTime:     end,
		RequestType: &requestType,
	})
	require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where)
	require.Equal(t, []any{start, end, requestType}, args)
}
func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
end := start.Add(24 * time.Hour)

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at"
// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL
var dateFormatWhitelist = map[string]string{
@@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.RequestID = requestID
rateMultiplier := log.RateMultiplier
log.SyncRequestTypeAndLegacyFields()
requestType := int16(log.RequestType)
query := `
INSERT INTO usage_logs (
@@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rate_multiplier,
account_rate_multiplier,
billing_type,
request_type,
stream,
openai_ws_mode,
duration_ms,
first_token_ms,
user_agent,
@@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
requestType,
log.Stream,
log.OpenAIWSMode,
duration,
firstToken,
userAgent,
@@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
}
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
totalStatsQuery := `
todayEnd := todayUTC.Add(24 * time.Hour)
combinedStatsQuery := `
WITH scoped AS (
SELECT
created_at,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
COALESCE(duration_ms, 0) AS duration_ms
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens,
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens,
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost,
COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms,
COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests,
COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens,
COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens,
COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens,
COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost,
COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost
FROM scoped
`
var totalDurationMs int64
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{startUTC, endUTC},
combinedStatsQuery,
[]any{startUTC, endUTC, todayUTC, todayEnd},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
@@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
&stats.TotalCost,
&stats.TotalActualCost,
&totalDurationMs,
); err != nil {
return err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
todayEnd := todayUTC.Add(24 * time.Hour)
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{todayUTC, todayEnd},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
@@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co
); err != nil {
return err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
activeUsersQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
return err
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
hourStart := now.UTC().Truncate(time.Hour)
hourEnd := hourStart.Add(time.Hour)
hourlyActiveQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
activeUsersQuery := `
WITH scoped AS (
SELECT user_id, created_at
FROM usage_logs
WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz)
AND created_at < GREATEST($2::timestamptz, $4::timestamptz)
)
SELECT
COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users,
COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users
FROM scoped
`
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil {
return err
}
@@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc
return result, nil
}
// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。
// 模型分类规则与 service.geminiModelClassFromName 一致model 包含 flash/lite 视为 flash其余视为 pro。
func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]service.GeminiUsageTotals, error) {
result := make(map[int64]service.GeminiUsageTotals, len(accountIDs))
if len(accountIDs) == 0 {
return result, nil
}
query := `
SELECT
account_id,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost,
COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost
FROM usage_logs
WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3
GROUP BY account_id
`
rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var accountID int64
var totals service.GeminiUsageTotals
if err := rows.Scan(
&accountID,
&totals.FlashRequests,
&totals.ProRequests,
&totals.FlashTokens,
&totals.ProTokens,
&totals.FlashCost,
&totals.ProCost,
); err != nil {
return nil, err
}
result[accountID] = totals
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, accountID := range accountIDs {
if _, ok := result[accountID]; !ok {
result[accountID] = service.GeminiUsageTotals{}
}
}
return result, nil
}
// TrendDataPoint represents a single point in trend data
type TrendDataPoint = usagestats.TrendDataPoint
@@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := safeDateFormat(granularity)
query := fmt.Sprintf(`
@@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
if billingType != nil {
query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
args = append(args, int16(*billingType))
@@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1))
args = append(args, filters.Model)
}
if filters.Stream != nil {
conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
args = append(args, *filters.Stream)
}
conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1))
args = append(args, int16(*filters.BillingType))
@@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil)
if err != nil {
models = []ModelStat{}
}
@@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
requestTypeRaw int16
stream bool
openaiWSMode bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
@@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&requestTypeRaw,
&stream,
&openaiWSMode,
&durationMs,
&firstTokenMs,
&userAgent,
@@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
RequestType: service.RequestTypeFromInt16(requestTypeRaw),
ImageCount: imageCount,
CacheTTLOverridden: cacheTTLOverridden,
CreatedAt: createdAt,
}
// 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。
log.Stream = stream
log.OpenAIWSMode = openaiWSMode
log.RequestType = log.EffectiveRequestType()
log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode)
if requestID.Valid {
log.RequestID = requestID.String
@@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string {
return "WHERE " + strings.Join(conditions, " AND ")
}
// appendRequestTypeOrStreamWhereCondition appends at most one filter to the
// WHERE condition list: a request_type filter (preferred, with legacy-field
// fallback via buildRequestTypeFilterCondition) when requestType is set,
// otherwise a plain stream equality filter when stream is set.
func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) ([]string, []any) {
	switch {
	case requestType != nil:
		cond, condArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
		return append(conditions, cond), append(args, condArgs...)
	case stream != nil:
		conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1))
		args = append(args, *stream)
	}
	return conditions, args
}
// appendRequestTypeOrStreamQueryFilter is the string-concatenation twin of
// appendRequestTypeOrStreamWhereCondition: it extends a query already
// containing a WHERE clause with " AND <filter>", preferring the
// request_type filter (with legacy fallback) over the stream filter.
func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) {
	if requestType == nil {
		if stream != nil {
			query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
			args = append(args, *stream)
		}
		return query, args
	}
	cond, condArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType)
	return query + " AND " + cond, append(args, condArgs...)
}
// buildRequestTypeFilterCondition builds a request_type filter that also
// matches legacy rows (request_type = unknown/0) via their old
// stream/openai_ws_mode flags, so historical data is not silently excluded.
// startArgIndex is the placeholder number to use for the request_type value.
func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) {
	normalized := service.RequestTypeFromInt16(requestType)
	filterArgs := []any{int16(normalized)}
	// Legacy-flag condition matching rows written before request_type existed.
	var legacyCond string
	switch normalized {
	case service.RequestTypeSync:
		legacyCond = "stream = FALSE AND openai_ws_mode = FALSE"
	case service.RequestTypeStream:
		legacyCond = "stream = TRUE AND openai_ws_mode = FALSE"
	case service.RequestTypeWSV2:
		legacyCond = "openai_ws_mode = TRUE"
	default:
		// No legacy representation: match on request_type alone.
		return fmt.Sprintf("request_type = $%d", startArgIndex), filterArgs
	}
	cond := fmt.Sprintf("(request_type = $%d OR (request_type = %d AND %s))", startArgIndex, int16(service.RequestTypeUnknown), legacyCond)
	return cond, filterArgs
}
func nullInt64(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}

View File

@@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// TestGetByID_ReturnsOpenAIWSMode verifies that a usage log created with
// OpenAIWSMode=true is returned by GetByID with the flag preserved.
func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() {
	user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"})
	apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"})
	account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"})
	log := &service.UsageLog{
		UserID:       user.ID,
		APIKeyID:     apiKey.ID,
		AccountID:    account.ID,
		RequestID:    uuid.New().String(),
		Model:        "gpt-5.3-codex",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    1.0,
		ActualCost:   1.0,
		OpenAIWSMode: true,
		CreatedAt:    timezone.Today().Add(3 * time.Hour),
	}
	_, err := s.repo.Create(s.ctx, log)
	s.Require().NoError(err)
	got, err := s.repo.GetByID(s.ctx, log.ID)
	s.Require().NoError(err)
	s.Require().True(got.OpenAIWSMode)
}
// TestGetByID_ReturnsRequestTypeAndLegacyFallback verifies that RequestType is
// persisted, and that the legacy stream/openai_ws_mode flags read back are
// reconciled with the stored request type: a WSV2 log comes back with
// OpenAIWSMode=true even though it was created with the flag set to false.
func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() {
	user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"})
	apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"})
	account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"})
	log := &service.UsageLog{
		UserID:       user.ID,
		APIKeyID:     apiKey.ID,
		AccountID:    account.ID,
		RequestID:    uuid.New().String(),
		Model:        "gpt-5.3-codex",
		RequestType:  service.RequestTypeWSV2,
		Stream:       true,
		OpenAIWSMode: false, // deliberately inconsistent with WSV2; the repo should normalize it
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    1.0,
		ActualCost:   1.0,
		CreatedAt:    timezone.Today().Add(4 * time.Hour),
	}
	_, err := s.repo.Create(s.ctx, log)
	s.Require().NoError(err)
	got, err := s.repo.GetByID(s.ctx, log.ID)
	s.Require().NoError(err)
	s.Require().Equal(service.RequestTypeWSV2, got.RequestType)
	s.Require().True(got.Stream)
	s.Require().True(got.OpenAIWSMode)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}

View File

@@ -0,0 +1,327 @@
package repository
import (
"context"
"database/sql"
"fmt"
"reflect"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields verifies that
// Create writes the normalized request type and keeps the legacy columns in
// sync with it: for RequestTypeWSV2 the INSERT is expected to carry
// stream=TRUE and openai_ws_mode=TRUE even though the input struct has both
// set to false, and the struct itself is normalized in place after insert.
func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageLogRepository{sql: db}
	createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
	log := &service.UsageLog{
		UserID:       1,
		APIKeyID:     2,
		AccountID:    3,
		RequestID:    "req-1",
		Model:        "gpt-5",
		InputTokens:  10,
		OutputTokens: 20,
		TotalCost:    1,
		ActualCost:   1,
		BillingType:  service.BillingTypeBalance,
		RequestType:  service.RequestTypeWSV2,
		Stream:       false, // inconsistent with WSV2 on purpose
		OpenAIWSMode: false, // inconsistent with WSV2 on purpose
		CreatedAt:    createdAt,
	}
	// Argument order below mirrors the INSERT column order in Create.
	mock.ExpectQuery("INSERT INTO usage_logs").
		WithArgs(
			log.UserID,
			log.APIKeyID,
			log.AccountID,
			log.RequestID,
			log.Model,
			sqlmock.AnyArg(), // group_id
			sqlmock.AnyArg(), // subscription_id
			log.InputTokens,
			log.OutputTokens,
			log.CacheCreationTokens,
			log.CacheReadTokens,
			log.CacheCreation5mTokens,
			log.CacheCreation1hTokens,
			log.InputCost,
			log.OutputCost,
			log.CacheCreationCost,
			log.CacheReadCost,
			log.TotalCost,
			log.ActualCost,
			log.RateMultiplier,
			log.AccountRateMultiplier,
			log.BillingType,
			int16(service.RequestTypeWSV2), // request_type (normalized)
			true,                           // stream, derived from WSV2 rather than the input false
			true,                           // openai_ws_mode, derived from WSV2 rather than the input false
			sqlmock.AnyArg(),               // duration_ms
			sqlmock.AnyArg(),               // first_token_ms
			sqlmock.AnyArg(),               // user_agent
			sqlmock.AnyArg(),               // ip_address
			log.ImageCount,
			sqlmock.AnyArg(), // image_size
			sqlmock.AnyArg(), // media_type
			sqlmock.AnyArg(), // reasoning_effort
			log.CacheTTLOverridden,
			createdAt,
		).
		WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt))
	inserted, err := repo.Create(context.Background(), log)
	require.NoError(t, err)
	require.True(t, inserted)
	require.Equal(t, int64(99), log.ID)
	// The input struct is normalized in place: legacy flags now reflect WSV2.
	require.Equal(t, service.RequestTypeWSV2, log.RequestType)
	require.True(t, log.Stream)
	require.True(t, log.OpenAIWSMode)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageLogRepositoryListWithFiltersRequestTypePriority verifies that when
// both RequestType and Stream filters are set, ListWithFilters applies only
// the request-type condition (including its legacy openai_ws_mode fallback)
// and ignores the stream filter.
func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageLogRepository{sql: db}
	requestType := int16(service.RequestTypeWSV2)
	stream := false
	filters := usagestats.UsageLogFilters{
		RequestType: &requestType,
		Stream:      &stream,
	}
	// Both the COUNT query and the page query must carry the same
	// request-type clause with exactly one bound filter argument
	// (no stream placeholder).
	mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
		WithArgs(requestType).
		WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
	mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3").
		WithArgs(requestType, 20, 0).
		WillReturnRows(sqlmock.NewRows([]string{"id"}))
	logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters)
	require.NoError(t, err)
	require.Empty(t, logs)
	require.NotNil(t, page)
	require.Equal(t, int64(0), page.Total)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority verifies
// that the trend query prefers the request-type filter (with its legacy
// stream fallback) over the stream filter when both are supplied. The clause
// binds $3 because $1/$2 are taken by the time range.
func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageLogRepository{sql: db}
	start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	requestType := int16(service.RequestTypeStream)
	stream := true
	mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)").
		WithArgs(start, end, requestType).
		WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"}))
	trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil)
	require.NoError(t, err)
	require.Empty(t, trend)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority verifies
// that the model-stats query applies only the request-type condition (WSV2
// with its legacy openai_ws_mode fallback) and ignores the stream filter when
// both are supplied. $3 is the first filter placeholder after the time range.
func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageLogRepository{sql: db}
	start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(24 * time.Hour)
	requestType := int16(service.RequestTypeWSV2)
	stream := false
	mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)").
		WithArgs(start, end, requestType).
		WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"}))
	stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil)
	require.NoError(t, err)
	require.Empty(t, stats)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority verifies that
// GetStatsWithFilters applies only the request-type condition (sync with its
// legacy-row fallback) when both RequestType and Stream filters are present,
// and that token totals are aggregated from the returned row.
func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) {
	db, mock := newSQLMock(t)
	repo := &usageLogRepository{sql: db}
	requestType := int16(service.RequestTypeSync)
	stream := true
	filters := usagestats.UsageLogFilters{
		RequestType: &requestType,
		Stream:      &stream,
	}
	mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)").
		WithArgs(requestType).
		WillReturnRows(sqlmock.NewRows([]string{
			"total_requests",
			"total_input_tokens",
			"total_output_tokens",
			"total_cache_tokens",
			"total_cost",
			"total_actual_cost",
			"total_account_cost",
			"avg_duration_ms",
		}).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0))
	stats, err := repo.GetStatsWithFilters(context.Background(), filters)
	require.NoError(t, err)
	require.Equal(t, int64(1), stats.TotalRequests)
	require.Equal(t, int64(9), stats.TotalTokens) // input(2) + output(3) + cache(4)
	require.NoError(t, mock.ExpectationsWereMet())
}
// TestBuildRequestTypeFilterConditionLegacyFallback checks the generated
// WHERE fragment for each known request type, including the legacy-row
// fallback branch, and verifies that an unrecognized value is normalized to
// Unknown with a plain column comparison.
func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) {
	cases := []struct {
		name      string
		input     int16
		wantWhere string
		wantArg   int16
	}{
		{
			name:      "sync_with_legacy_fallback",
			input:     int16(service.RequestTypeSync),
			wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))",
			wantArg:   int16(service.RequestTypeSync),
		},
		{
			name:      "stream_with_legacy_fallback",
			input:     int16(service.RequestTypeStream),
			wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))",
			wantArg:   int16(service.RequestTypeStream),
		},
		{
			name:      "ws_v2_with_legacy_fallback",
			input:     int16(service.RequestTypeWSV2),
			wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))",
			wantArg:   int16(service.RequestTypeWSV2),
		},
		{
			name:      "invalid_request_type_normalized_to_unknown",
			input:     int16(99),
			wantWhere: "request_type = $3",
			wantArg:   int16(service.RequestTypeUnknown),
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotWhere, gotArgs := buildRequestTypeFilterCondition(3, tc.input)
			require.Equal(t, tc.wantWhere, gotWhere)
			require.Equal(t, []any{tc.wantArg}, gotArgs)
		})
	}
}
// usageLogScannerStub satisfies the row-scanner contract over a fixed slice
// of values, letting scanUsageLog be exercised without a real database row.
type usageLogScannerStub struct {
	values []any
}

// Scan copies the stub's stored values into the provided destinations.
// Every destination must be a pointer and the counts must match exactly.
func (s usageLogScannerStub) Scan(dest ...any) error {
	if got, want := len(dest), len(s.values); got != want {
		return fmt.Errorf("scan arg count mismatch: got %d want %d", got, want)
	}
	for i, d := range dest {
		target := reflect.ValueOf(d)
		if target.Kind() != reflect.Ptr {
			return fmt.Errorf("dest[%d] is not pointer", i)
		}
		target.Elem().Set(reflect.ValueOf(s.values[i]))
	}
	return nil
}
// TestScanUsageLogRequestTypeAndLegacyFallback covers both directions of the
// request-type / legacy-flag reconciliation performed by scanUsageLog.
func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) {
	// An explicit request_type (WSV2) wins: the legacy stream/openai_ws_mode
	// columns stored as false are overridden to match the request type.
	t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) {
		now := time.Now().UTC()
		log, err := scanUsageLog(usageLogScannerStub{values: []any{
			int64(1),  // id
			int64(10), // user_id
			int64(20), // api_key_id
			int64(30), // account_id
			sql.NullString{Valid: true, String: "req-1"},
			"gpt-5",           // model
			sql.NullInt64{},   // group_id
			sql.NullInt64{},   // subscription_id
			1,                 // input_tokens
			2,                 // output_tokens
			3,                 // cache_creation_tokens
			4,                 // cache_read_tokens
			5,                 // cache_creation_5m_tokens
			6,                 // cache_creation_1h_tokens
			0.1,               // input_cost
			0.2,               // output_cost
			0.3,               // cache_creation_cost
			0.4,               // cache_read_cost
			1.0,               // total_cost
			0.9,               // actual_cost
			1.0,               // rate_multiplier
			sql.NullFloat64{}, // account_rate_multiplier
			int16(service.BillingTypeBalance),
			int16(service.RequestTypeWSV2),
			false, // legacy stream
			false, // legacy openai ws
			sql.NullInt64{},
			sql.NullInt64{},
			sql.NullString{},
			sql.NullString{},
			0,
			sql.NullString{},
			sql.NullString{},
			sql.NullString{},
			false,
			now,
		}})
		require.NoError(t, err)
		require.Equal(t, service.RequestTypeWSV2, log.RequestType)
		require.True(t, log.Stream)
		require.True(t, log.OpenAIWSMode)
	})
	// An Unknown request_type falls back to the legacy columns: stream=true
	// with openai_ws_mode=false is interpreted as a streaming request.
	t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) {
		now := time.Now().UTC()
		log, err := scanUsageLog(usageLogScannerStub{values: []any{
			int64(2),
			int64(11),
			int64(21),
			int64(31),
			sql.NullString{Valid: true, String: "req-2"},
			"gpt-5",
			sql.NullInt64{},
			sql.NullInt64{},
			1, 2, 3, 4, 5, 6,
			0.1, 0.2, 0.3, 0.4, 1.0, 0.9,
			1.0,
			sql.NullFloat64{},
			int16(service.BillingTypeBalance),
			int16(service.RequestTypeUnknown),
			true,
			false,
			sql.NullInt64{},
			sql.NullInt64{},
			sql.NullString{},
			sql.NullString{},
			0,
			sql.NullString{},
			sql.NullString{},
			sql.NullString{},
			false,
			now,
		}})
		require.NoError(t, err)
		require.Equal(t, service.RequestTypeStream, log.RequestType)
		require.True(t, log.Stream)
		require.False(t, log.OpenAIWSMode)
	})
}

View File

@@ -6,6 +6,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type userGroupRateRepository struct {
@@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64)
return result, nil
}
// GetByUserIDs loads per-group rate multipliers for a batch of users in a
// single query. The result maps userID -> groupID -> multiplier; every valid
// (>0) requested user ID gets an entry, even when it has no configured rates.
func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) {
	result := make(map[int64]map[int64]float64, len(userIDs))
	if len(userIDs) == 0 {
		return result, nil
	}

	// Deduplicate and drop non-positive IDs before hitting the database;
	// pre-seed each valid user with an empty map so callers always find a key.
	seen := make(map[int64]struct{}, len(userIDs))
	queryIDs := make([]int64, 0, len(userIDs))
	for _, id := range userIDs {
		if id <= 0 {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		queryIDs = append(queryIDs, id)
		result[id] = make(map[int64]float64)
	}
	if len(queryIDs) == 0 {
		return result, nil
	}

	rows, err := r.sql.QueryContext(ctx, `
		SELECT user_id, group_id, rate_multiplier
		FROM user_group_rate_multipliers
		WHERE user_id = ANY($1)
	`, pq.Array(queryIDs))
	if err != nil {
		return nil, err
	}
	defer func() { _ = rows.Close() }()

	for rows.Next() {
		var (
			userID  int64
			groupID int64
			rate    float64
		)
		if err := rows.Scan(&userID, &groupID, &rate); err != nil {
			return nil, err
		}
		bucket, ok := result[userID]
		if !ok {
			bucket = make(map[int64]float64)
			result[userID] = bucket
		}
		bucket[groupID] = rate
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// GetByUserAndGroup 获取用户在特定分组的专属倍率
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
@@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID
// 分离需要删除和需要 upsert 的记录
var toDelete []int64
toUpsert := make(map[int64]float64)
upsertGroupIDs := make([]int64, 0, len(rates))
upsertRates := make([]float64, 0, len(rates))
for groupID, rate := range rates {
if rate == nil {
toDelete = append(toDelete, groupID)
} else {
toUpsert[groupID] = *rate
upsertGroupIDs = append(upsertGroupIDs, groupID)
upsertRates = append(upsertRates, *rate)
}
}
// 删除指定的记录
for _, groupID := range toDelete {
_, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
userID, groupID)
if err != nil {
if len(toDelete) > 0 {
if _, err := r.sql.ExecContext(ctx,
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`,
userID, pq.Array(toDelete)); err != nil {
return err
}
}
// Upsert 记录
now := time.Now()
for groupID, rate := range toUpsert {
if len(upsertGroupIDs) > 0 {
_, err := r.sql.ExecContext(ctx, `
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
VALUES ($1, $2, $3, $4, $4)
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
`, userID, groupID, rate, now)
SELECT
$1::bigint,
data.group_id,
data.rate_multiplier,
$2::timestamptz,
$2::timestamptz
FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier)
ON CONFLICT (user_id, group_id)
DO UPDATE SET
rate_multiplier = EXCLUDED.rate_multiplier,
updated_at = EXCLUDED.updated_at
`, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates))
if err != nil {
return err
}

View File

@@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, nil, service.ErrEmailExists)
@@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error
SetBalance(userIn.Balance).
SetConcurrency(userIn.Concurrency).
SetStatus(userIn.Status).
SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes).
SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes).
Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists)
@@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount
return nil
}
// AddSoraStorageUsageWithQuota atomically adds deltaBytes to the user's Sora
// storage usage. When effectiveQuota > 0 the increment only succeeds if the
// new total stays within quota; effectiveQuota == 0 means no limit is
// enforced. It returns the post-update usage, service.ErrUserNotFound when
// the user row is missing, or service.ErrSoraStorageQuotaExceeded when the
// quota guard rejects the increment.
func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) {
	// Non-positive deltas are a no-op; just report the current usage.
	if deltaBytes <= 0 {
		u, err := r.GetByID(ctx, userID)
		if err != nil {
			return 0, err
		}
		return u.SoraStorageUsedBytes, nil
	}

	var updated int64
	err := scanSingleRow(ctx, r.sql, `
		UPDATE users
		SET sora_storage_used_bytes = sora_storage_used_bytes + $2
		WHERE id = $1
		AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3)
		RETURNING sora_storage_used_bytes
	`, []any{userID, deltaBytes, effectiveQuota}, &updated)
	switch {
	case err == nil:
		return updated, nil
	case errors.Is(err, sql.ErrNoRows):
		// No row was updated: either the user does not exist or the quota
		// guard rejected the increment. Disambiguate with an existence check.
		exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx)
		if existsErr != nil {
			return 0, existsErr
		}
		if !exists {
			return 0, service.ErrUserNotFound
		}
		return 0, service.ErrSoraStorageQuotaExceeded
	default:
		return 0, err
	}
}
// ReleaseSoraStorageUsageAtomic atomically subtracts deltaBytes from the
// user's Sora storage usage, clamping the result at zero. It returns the
// post-update usage, or service.ErrUserNotFound when the user row is missing.
func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) {
	// Nothing to release for non-positive deltas; report current usage as-is.
	if deltaBytes <= 0 {
		u, err := r.GetByID(ctx, userID)
		if err != nil {
			return 0, err
		}
		return u.SoraStorageUsedBytes, nil
	}

	var remaining int64
	query := `
		UPDATE users
		SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0)
		WHERE id = $1
		RETURNING sora_storage_used_bytes
	`
	if err := scanSingleRow(ctx, r.sql, query, []any{userID, deltaBytes}, &remaining); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return 0, service.ErrUserNotFound
		}
		return 0, err
	}
	return remaining, nil
}
// ExistsByEmail reports whether a user row with the given email exists.
func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	query := r.client.User.Query().Where(dbuser.EmailEQ(email))
	return query.Exist(ctx)
}