Merge upstream/main

This commit is contained in:
song
2026-01-17 18:00:07 +08:00
394 changed files with 76872 additions and 1877 deletions

View File

@@ -15,6 +15,7 @@ import (
"database/sql"
"encoding/json"
"errors"
"log"
"strconv"
"time"
@@ -79,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
@@ -115,6 +120,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
account.ID = created.ID
account.CreatedAt = created.CreatedAt
account.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
}
return nil
}
@@ -287,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
if account.RateMultiplier != nil {
builder.SetRateMultiplier(*account.RateMultiplier)
}
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
@@ -341,10 +353,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
}
account.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
}
return nil
}
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
if err != nil {
return err
}
// 使用事务保证账号与关联分组的删除原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -368,7 +387,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
}
if tx != nil {
return tx.Commit()
if err := tx.Commit(); err != nil {
return err
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
}
return nil
}
@@ -455,7 +479,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
Where(dbaccount.IDEQ(id)).
SetLastUsedAt(now).
Save(ctx)
return err
if err != nil {
return err
}
payload := map[string]any{
"last_used": map[string]int64{
strconv.FormatInt(id, 10): now.Unix(),
},
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
@@ -479,7 +514,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
args = append(args, pq.Array(ids))
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
return err
if err != nil {
return err
}
lastUsedPayload := make(map[string]int64, len(updates))
for id, ts := range updates {
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
}
payload := map[string]any{"last_used": lastUsedPayload}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
}
return nil
}
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
@@ -488,7 +534,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
SetStatus(service.StatusError).
SetErrorMessage(errorMsg).
Save(ctx)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
@@ -506,7 +558,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
SetGroupID(groupID).
SetPriority(priority).
Save(ctx)
return err
if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
@@ -516,7 +575,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
dbaccountgroup.GroupIDEQ(groupID),
).
Exec(ctx)
return err
if err != nil {
return err
}
payload := buildSchedulerGroupPayload([]int64{groupID})
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
}
return nil
}
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
@@ -537,6 +603,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
}
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
if err != nil {
return err
}
// 使用事务保证删除旧绑定与创建新绑定的原子性
tx, err := r.client.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
@@ -577,7 +647,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
}
if tx != nil {
return tx.Commit()
if err := tx.Commit(); err != nil {
return err
}
}
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
}
return nil
}
@@ -681,7 +757,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
SetRateLimitedAt(now).
SetRateLimitResetAt(resetAt).
Save(ctx)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
@@ -715,6 +797,49 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
if scope == "" {
return nil
}
now := time.Now().UTC()
payload := map[string]string{
"rate_limited_at": now.Format(time.RFC3339),
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
}
raw, err := json.Marshal(payload)
if err != nil {
return err
}
path := "{model_rate_limits," + scope + "}"
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
path,
raw,
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -723,7 +848,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
Where(dbaccount.IDEQ(id)).
SetOverloadUntil(until).
Save(ctx)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
@@ -736,7 +867,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
AND deleted_at IS NULL
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
`, until, reason, id)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
@@ -748,7 +885,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
WHERE id = $1
AND deleted_at IS NULL
`, id)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
@@ -758,7 +901,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
ClearRateLimitResetAt().
ClearOverloadUntil().
Save(ctx)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
@@ -779,6 +928,33 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
result, err := client.ExecContext(
ctx,
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
id,
)
if err != nil {
return err
}
affected, err := result.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
}
return nil
}
@@ -801,7 +977,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
Where(dbaccount.IDEQ(id)).
SetSchedulable(schedulable).
Save(ctx)
return err
if err != nil {
return err
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
}
return nil
}
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
@@ -822,6 +1004,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
if err != nil {
return 0, err
}
if rows > 0 {
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
}
}
return rows, nil
}
@@ -853,6 +1040,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
if affected == 0 {
return service.ErrAccountNotFound
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
}
return nil
}
@@ -890,6 +1080,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority)
idx++
}
if updates.RateMultiplier != nil {
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
args = append(args, *updates.RateMultiplier)
idx++
}
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -937,6 +1132,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
if err != nil {
return 0, err
}
if rows > 0 {
payload := map[string]any{"account_ids": ids}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
}
}
return rows, nil
}
@@ -1179,11 +1380,61 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
}
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
entries, err := r.client.AccountGroup.
Query().
Where(dbaccountgroup.AccountIDEQ(accountID)).
All(ctx)
if err != nil {
return nil, err
}
ids := make([]int64, 0, len(entries))
for _, entry := range entries {
ids = append(ids, entry.GroupID)
}
return ids, nil
}
func mergeGroupIDs(a []int64, b []int64) []int64 {
seen := make(map[int64]struct{}, len(a)+len(b))
out := make([]int64, 0, len(a)+len(b))
for _, id := range a {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
for _, id := range b {
if id <= 0 {
continue
}
if _, ok := seen[id]; ok {
continue
}
seen[id] = struct{}{}
out = append(out, id)
}
return out
}
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
if len(groupIDs) == 0 {
return nil
}
return map[string]any{"group_ids": groupIDs}
}
func accountEntityToService(m *dbent.Account) *service.Account {
if m == nil {
return nil
}
rateMultiplier := m.RateMultiplier
return &service.Account{
ID: m.ID,
Name: m.Name,
@@ -1195,6 +1446,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
@@ -13,6 +14,7 @@ import (
const (
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
apiKeyRateLimitDuration = 24 * time.Hour
apiKeyAuthCachePrefix = "apikey:auth:"
)
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
}
func apiKeyAuthCacheKey(key string) string {
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
}
type apiKeyCache struct {
rdb *redis.Client
}
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return c.rdb.Expire(ctx, apiKey, ttl).Err()
}
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
if err != nil {
return nil, err
}
var entry service.APIKeyAuthCacheEntry
if err := json.Unmarshal(val, &entry); err != nil {
return nil, err
}
return &entry, nil
}
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
if entry == nil {
return nil
}
payload, err := json.Marshal(entry)
if err != nil {
return err
}
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
}
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
}

View File

@@ -6,7 +6,9 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -26,13 +28,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
}
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
created, err := r.client.APIKey.Create().
builder := r.client.APIKey.Create().
SetUserID(key.UserID).
SetKey(key.Key).
SetName(key.Name).
SetStatus(key.Status).
SetNillableGroupID(key.GroupID).
Save(ctx)
SetNillableGroupID(key.GroupID)
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
}
created, err := builder.Save(ctx)
if err == nil {
key.ID = created.ID
key.CreatedAt = created.CreatedAt
@@ -56,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
return apiKeyEntityToService(m), nil
}
// GetOwnerID 根据 API Key ID 获取其所有者(用户)ID。
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者用户ID。
// 相比 GetByID此方法性能更优因为
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
// - 使用 Select() 只查询必要字段,减少数据传输量
// - 不加载完整的 API Key 实体及其关联数据User、Group 等)
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
// - 适用于删除等只需 key 与用户 ID 的场景
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
m, err := r.activeQuery().
Where(apikey.IDEQ(id)).
Select(apikey.FieldUserID).
Select(apikey.FieldKey, apikey.FieldUserID).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return 0, service.ErrAPIKeyNotFound
return "", 0, service.ErrAPIKeyNotFound
}
return 0, err
return "", 0, err
}
return m.UserID, nil
return m.Key, m.UserID, nil
}
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
@@ -90,6 +100,56 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
m, err := r.activeQuery().
Where(apikey.KeyEQ(key)).
Select(
apikey.FieldID,
apikey.FieldUserID,
apikey.FieldGroupID,
apikey.FieldStatus,
apikey.FieldIPWhitelist,
apikey.FieldIPBlacklist,
).
WithUser(func(q *dbent.UserQuery) {
q.Select(
user.FieldID,
user.FieldStatus,
user.FieldRole,
user.FieldBalance,
user.FieldConcurrency,
)
}).
WithGroup(func(q *dbent.GroupQuery) {
q.Select(
group.FieldID,
group.FieldName,
group.FieldPlatform,
group.FieldStatus,
group.FieldSubscriptionType,
group.FieldRateMultiplier,
group.FieldDailyLimitUsd,
group.FieldWeeklyLimitUsd,
group.FieldMonthlyLimitUsd,
group.FieldImagePrice1k,
group.FieldImagePrice2k,
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
group.FieldModelRoutingEnabled,
group.FieldModelRouting,
)
}).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrAPIKeyNotFound
}
return nil, err
}
return apiKeyEntityToService(m), nil
}
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
// 之前的实现先检查 Exist 再 UpdateOneID若在两步之间发生软删除
@@ -108,6 +168,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
builder.ClearGroupID()
}
// IP 限制字段
if len(key.IPWhitelist) > 0 {
builder.SetIPWhitelist(key.IPWhitelist)
} else {
builder.ClearIPWhitelist()
}
if len(key.IPBlacklist) > 0 {
builder.SetIPBlacklist(key.IPBlacklist)
} else {
builder.ClearIPBlacklist()
}
affected, err := builder.Save(ctx)
if err != nil {
return err
@@ -263,19 +335,43 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
return int64(count), err
}
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.UserIDEQ(userID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
keys, err := r.activeQuery().
Where(apikey.GroupIDEQ(groupID)).
Select(apikey.FieldKey).
Strings(ctx)
if err != nil {
return nil, err
}
return keys, nil
}
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
if m == nil {
return nil
}
out := &service.APIKey{
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
ID: m.ID,
UserID: m.UserID,
Key: m.Key,
Name: m.Name,
Status: m.Status,
IPWhitelist: m.IPWhitelist,
IPBlacklist: m.IPBlacklist,
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
GroupID: m.GroupID,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
@@ -317,6 +413,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
RateMultiplier: g.RateMultiplier,
IsExclusive: g.IsExclusive,
Status: g.Status,
Hydrated: true,
SubscriptionType: g.SubscriptionType,
DailyLimitUSD: g.DailyLimitUsd,
WeeklyLimitUSD: g.WeeklyLimitUsd,
@@ -327,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
ModelRouting: g.ModelRouting,
ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}

View File

@@ -93,7 +93,7 @@ var (
return redis.call('ZCARD', key)
`)
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
// KEYS[1] = wait queue key
// ARGV[1] = maxWait
// ARGV[2] = TTL in seconds
@@ -111,15 +111,13 @@ var (
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)
// incrementAccountWaitScript - account-level wait queue count
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
incrementAccountWaitScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
@@ -134,10 +132,8 @@ var (
local newVal = redis.call('INCR', KEYS[1])
-- Only set TTL on first creation to avoid refreshing zombie data
if newVal == 1 then
redis.call('EXPIRE', KEYS[1], ARGV[2])
end
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
`)

View File

@@ -0,0 +1,392 @@
package repository
import (
"context"
"database/sql"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type dashboardAggregationRepository struct {
sql sqlExecutor
}
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
if sqlDB == nil {
return nil
}
if !isPostgresDriver(sqlDB) {
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
return nil
}
return newDashboardAggregationRepositoryWithSQL(sqlDB)
}
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
return &dashboardAggregationRepository{sql: sqlq}
}
func isPostgresDriver(db *sql.DB) bool {
if db == nil {
return false
}
_, ok := db.Driver().(*pq.Driver)
return ok
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
loc := timezone.Location()
startLocal := start.In(loc)
endLocal := end.In(loc)
if !endLocal.After(startLocal) {
return nil
}
hourStart := startLocal.Truncate(time.Hour)
hourEnd := endLocal.Truncate(time.Hour)
if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
dayStart := truncateToDay(startLocal)
dayEnd := truncateToDay(endLocal)
if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
return err
}
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
if err == sql.ErrNoRows {
return time.Unix(0, 0).UTC(), nil
}
return time.Time{}, err
}
return ts.UTC(), nil
}
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
query := `
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
VALUES (1, $1, NOW())
ON CONFLICT (id)
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
`
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
return err
}
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
hourlyCutoffUTC := hourlyCutoff.UTC()
dailyCutoffUTC := dailyCutoff.UTC()
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
return err
}
return nil
}
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil {
return err
}
if isPartitioned {
return r.dropUsageLogsPartitions(ctx, cutoff)
}
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
return err
}
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
if err != nil || !isPartitioned {
return err
}
monthStart := truncateToMonthUTC(now)
prevMonth := monthStart.AddDate(0, -1, 0)
nextMonth := monthStart.AddDate(0, 1, 0)
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
if err := r.createUsageLogsPartition(ctx, m); err != nil {
return err
}
}
return nil
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
GROUP BY 1
),
user_counts AS (
SELECT bucket_start, COUNT(*) AS active_users
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY bucket_start
)
INSERT INTO usage_dashboard_hourly (
bucket_start,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
hourly.bucket_start,
hourly.total_requests,
hourly.input_tokens,
hourly.output_tokens,
hourly.cache_creation_tokens,
hourly.cache_read_tokens,
hourly.total_cost,
hourly.actual_cost,
hourly.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM hourly
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
ON CONFLICT (bucket_start)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
COALESCE(SUM(total_cost), 0) AS total_cost,
COALESCE(SUM(actual_cost), 0) AS actual_cost,
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
FROM usage_dashboard_daily_users
WHERE bucket_date >= $3::date AND bucket_date < $4::date
GROUP BY bucket_date
)
INSERT INTO usage_dashboard_daily (
bucket_date,
total_requests,
input_tokens,
output_tokens,
cache_creation_tokens,
cache_read_tokens,
total_cost,
actual_cost,
total_duration_ms,
active_users,
computed_at
)
SELECT
daily.bucket_date,
daily.total_requests,
daily.input_tokens,
daily.output_tokens,
daily.cache_creation_tokens,
daily.cache_read_tokens,
daily.total_cost,
daily.actual_cost,
daily.total_duration_ms,
COALESCE(user_counts.active_users, 0) AS active_users,
NOW()
FROM daily
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
ON CONFLICT (bucket_date)
DO UPDATE SET
total_requests = EXCLUDED.total_requests,
input_tokens = EXCLUDED.input_tokens,
output_tokens = EXCLUDED.output_tokens,
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
cache_read_tokens = EXCLUDED.cache_read_tokens,
total_cost = EXCLUDED.total_cost,
actual_cost = EXCLUDED.actual_cost,
total_duration_ms = EXCLUDED.total_duration_ms,
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1
FROM pg_partitioned_table pt
JOIN pg_class c ON c.oid = pt.partrelid
WHERE c.relname = 'usage_logs'
)
`
var partitioned bool
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
return false, err
}
return partitioned, nil
}
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
rows, err := r.sql.QueryContext(ctx, `
SELECT c.relname
FROM pg_inherits
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
JOIN pg_class p ON p.oid = pg_inherits.inhparent
WHERE p.relname = 'usage_logs'
`)
if err != nil {
return err
}
defer func() {
_ = rows.Close()
}()
cutoffMonth := truncateToMonthUTC(cutoff)
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return err
}
if !strings.HasPrefix(name, "usage_logs_") {
continue
}
suffix := strings.TrimPrefix(name, "usage_logs_")
month, err := time.Parse("200601", suffix)
if err != nil {
continue
}
month = month.UTC()
if month.Before(cutoffMonth) {
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
return err
}
}
}
return rows.Err()
}
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
monthStart := truncateToMonthUTC(month)
nextMonth := monthStart.AddDate(0, 1, 0)
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
query := fmt.Sprintf(
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
pq.QuoteIdentifier(name),
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
)
_, err := r.sql.ExecContext(ctx, query)
return err
}
func truncateToDay(t time.Time) time.Time {
return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,58 @@
package repository
import (
"context"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const dashboardStatsCacheKey = "dashboard:stats:v1"
type dashboardCache struct {
rdb *redis.Client
keyPrefix string
}
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
prefix := "sub2api:"
if cfg != nil {
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
}
if prefix != "" && !strings.HasSuffix(prefix, ":") {
prefix += ":"
}
return &dashboardCache{
rdb: rdb,
keyPrefix: prefix,
}
}
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
if err != nil {
if err == redis.Nil {
return "", service.ErrDashboardStatsCacheMiss
}
return "", err
}
return val, nil
}
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
}
func (c *dashboardCache) buildKey() string {
if c.keyPrefix == "" {
return dashboardStatsCacheKey
}
return c.keyPrefix + dashboardStatsCacheKey
}
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
return c.rdb.Del(ctx, c.buildKey()).Err()
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
cache := NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "prod",
},
})
impl, ok := cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "prod:", impl.keyPrefix)
cache = NewDashboardCache(nil, &config.Config{
Dashboard: config.DashboardCacheConfig{
KeyPrefix: "staging:",
},
})
impl, ok = cache.(*dashboardCache)
require.True(t, ok)
require.Equal(t, "staging:", impl.keyPrefix)
}

View File

@@ -11,8 +11,8 @@ import (
)
const (
geminiTokenKeyPrefix = "gemini:token:"
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
oauthTokenKeyPrefix = "oauth:token:"
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
)
type geminiTokenCache struct {
@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"log"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/apikey"
@@ -48,18 +49,38 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 设置模型路由配置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
}
created, err := builder.Save(ctx)
if err == nil {
groupIn.ID = created.ID
groupIn.CreatedAt = created.CreatedAt
groupIn.UpdatedAt = created.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
}
}
return translatePersistenceError(err, nil, service.ErrGroupExists)
}
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
out, err := r.GetByIDLite(ctx, id)
if err != nil {
return nil, err
}
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
}
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
// AccountCount is intentionally not loaded here; use GetByID when needed.
m, err := r.client.Group.Query().
Where(group.IDEQ(id)).
Only(ctx)
@@ -67,10 +88,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
out := groupEntityToService(m)
count, _ := r.GetAccountCount(ctx, out.ID)
out.AccountCount = count
return out, nil
return groupEntityToService(m), nil
}
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
@@ -89,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 处理 FallbackGroupIDnil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -98,17 +117,33 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearFallbackGroupID()
}
// 处理 ModelRoutingnil 时清除,否则设置
if groupIn.ModelRouting != nil {
builder = builder.SetModelRouting(groupIn.ModelRouting)
} else {
builder = builder.ClearModelRouting()
}
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
}
groupIn.UpdatedAt = updated.UpdatedAt
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
}
return nil
}
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
}
return nil
}
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
@@ -238,6 +273,9 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
return 0, err
}
affected, _ := res.RowsAffected()
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
}
return affected, nil
}
@@ -345,6 +383,9 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
return nil, err
}
}
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
}
return affectedUserIDs, nil
}

View File

@@ -4,6 +4,8 @@ package repository
import (
"context"
"database/sql"
"errors"
"testing"
dbent "github.com/Wei-Shaw/sub2api/ent"
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
repo *groupRepository
}
type forbidSQLExecutor struct {
called bool
}
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
s.called = true
return nil, errors.New("unexpected sql exec")
}
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
s.called = true
return nil, errors.New("unexpected sql query")
}
func (s *GroupRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
s.Require().ErrorIs(err, service.ErrGroupNotFound)
}
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
group := &service.Group{
Name: "lite-group",
Platform: service.PlatformAnthropic,
RateMultiplier: 1.0,
IsExclusive: false,
Status: service.StatusActive,
SubscriptionType: service.SubscriptionTypeStandard,
}
s.Require().NoError(s.repo.Create(s.ctx, group))
spy := &forbidSQLExecutor{}
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
got, err := repo.GetByIDLite(s.ctx, group.ID)
s.Require().NoError(err)
s.Require().Equal(group.ID, got.ID)
s.Require().False(spy.called, "expected no direct sql executor usage")
}
func (s *GroupRepoSuite) TestUpdate() {
group := &service.Group{
Name: "original",

View File

@@ -28,6 +28,23 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
);
`
const atlasSchemaRevisionsTableDDL = `
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
version TEXT PRIMARY KEY,
description TEXT NOT NULL,
type INTEGER NOT NULL,
applied INTEGER NOT NULL DEFAULT 0,
total INTEGER NOT NULL DEFAULT 0,
executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
execution_time BIGINT NOT NULL DEFAULT 0,
error TEXT NULL,
error_stmt TEXT NULL,
hash TEXT NOT NULL DEFAULT '',
partial_hashes TEXT[] NULL,
operator_version TEXT NULL
);
`
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
@@ -94,6 +111,11 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return fmt.Errorf("create schema_migrations: %w", err)
}
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
return err
}
// 获取所有 .sql 迁移文件并按文件名排序。
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql
files, err := fs.Glob(fsys, "*.sql")
@@ -172,6 +194,80 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
return nil
}
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
if err != nil {
return fmt.Errorf("check schema_migrations: %w", err)
}
if !hasLegacy {
return nil
}
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
if err != nil {
return fmt.Errorf("check atlas_schema_revisions: %w", err)
}
if !hasAtlas {
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
return fmt.Errorf("create atlas_schema_revisions: %w", err)
}
}
var count int
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
return fmt.Errorf("count atlas_schema_revisions: %w", err)
}
if count > 0 {
return nil
}
version, description, hash, err := latestMigrationBaseline(fsys)
if err != nil {
return fmt.Errorf("atlas baseline version: %w", err)
}
if _, err := db.ExecContext(ctx, `
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
`, version, description, 1, hash); err != nil {
return fmt.Errorf("insert atlas baseline: %w", err)
}
return nil
}
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
var exists bool
err := db.QueryRowContext(ctx, `
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = $1
)
`, tableName).Scan(&exists)
return exists, err
}
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
files, err := fs.Glob(fsys, "*.sql")
if err != nil {
return "", "", "", err
}
if len(files) == 0 {
return "baseline", "baseline", "", nil
}
sort.Strings(files)
name := files[len(files)-1]
contentBytes, err := fs.ReadFile(fsys, name)
if err != nil {
return "", "", "", err
}
content := strings.TrimSpace(string(contentBytes))
sum := sha256.Sum256([]byte(content))
hash := hex.EncodeToString(sum[:])
version := strings.TrimSuffix(name, ".sql")
return version, version, hash, nil
}
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,853 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
q := `
SELECT
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at
FROM ops_alert_rules
ORDER BY id DESC`
rows, err := r.db.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := []*service.OpsAlertRule{}
for rows.Next() {
var rule service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := rows.Scan(
&rule.ID,
&rule.Name,
&rule.Description,
&rule.Enabled,
&rule.Severity,
&rule.MetricType,
&rule.Operator,
&rule.Threshold,
&rule.WindowMinutes,
&rule.SustainedMinutes,
&rule.CooldownMinutes,
&rule.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&rule.CreatedAt,
&rule.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
rule.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
rule.Filters = decoded
}
}
out = append(out, &rule)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
filtersArg, err := opsNullJSONMap(input.Filters)
if err != nil {
return nil, err
}
q := `
INSERT INTO ops_alert_rules (
name,
description,
enabled,
severity,
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
notify_email,
filters,
created_at,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
)
RETURNING
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at`
var out service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := r.db.QueryRowContext(
ctx,
q,
strings.TrimSpace(input.Name),
strings.TrimSpace(input.Description),
input.Enabled,
strings.TrimSpace(input.Severity),
strings.TrimSpace(input.MetricType),
strings.TrimSpace(input.Operator),
input.Threshold,
input.WindowMinutes,
input.SustainedMinutes,
input.CooldownMinutes,
input.NotifyEmail,
filtersArg,
).Scan(
&out.ID,
&out.Name,
&out.Description,
&out.Enabled,
&out.Severity,
&out.MetricType,
&out.Operator,
&out.Threshold,
&out.WindowMinutes,
&out.SustainedMinutes,
&out.CooldownMinutes,
&out.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
out.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
out.Filters = decoded
}
}
return &out, nil
}
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.ID <= 0 {
return nil, fmt.Errorf("invalid id")
}
filtersArg, err := opsNullJSONMap(input.Filters)
if err != nil {
return nil, err
}
q := `
UPDATE ops_alert_rules
SET
name = $2,
description = $3,
enabled = $4,
severity = $5,
metric_type = $6,
operator = $7,
threshold = $8,
window_minutes = $9,
sustained_minutes = $10,
cooldown_minutes = $11,
notify_email = $12,
filters = $13,
updated_at = NOW()
WHERE id = $1
RETURNING
id,
name,
COALESCE(description, ''),
enabled,
COALESCE(severity, ''),
metric_type,
operator,
threshold,
window_minutes,
sustained_minutes,
cooldown_minutes,
COALESCE(notify_email, true),
filters,
last_triggered_at,
created_at,
updated_at`
var out service.OpsAlertRule
var filtersRaw []byte
var lastTriggeredAt sql.NullTime
if err := r.db.QueryRowContext(
ctx,
q,
input.ID,
strings.TrimSpace(input.Name),
strings.TrimSpace(input.Description),
input.Enabled,
strings.TrimSpace(input.Severity),
strings.TrimSpace(input.MetricType),
strings.TrimSpace(input.Operator),
input.Threshold,
input.WindowMinutes,
input.SustainedMinutes,
input.CooldownMinutes,
input.NotifyEmail,
filtersArg,
).Scan(
&out.ID,
&out.Name,
&out.Description,
&out.Enabled,
&out.Severity,
&out.MetricType,
&out.Operator,
&out.Threshold,
&out.WindowMinutes,
&out.SustainedMinutes,
&out.CooldownMinutes,
&out.NotifyEmail,
&filtersRaw,
&lastTriggeredAt,
&out.CreatedAt,
&out.UpdatedAt,
); err != nil {
return nil, err
}
if lastTriggeredAt.Valid {
v := lastTriggeredAt.Time
out.LastTriggeredAt = &v
}
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
out.Filters = decoded
}
}
return &out, nil
}
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if id <= 0 {
return fmt.Errorf("invalid id")
}
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
if err != nil {
return err
}
affected, err := res.RowsAffected()
if err != nil {
return err
}
if affected == 0 {
return sql.ErrNoRows
}
return nil
}
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
filter = &service.OpsAlertEventFilter{}
}
limit := filter.Limit
if limit <= 0 {
limit = 100
}
if limit > 500 {
limit = 500
}
where, args := buildOpsAlertEventsWhere(filter)
args = append(args, limit)
limitArg := "$" + itoa(len(args))
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
` + where + `
ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := []*service.OpsAlertEvent{}
for rows.Next() {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
var thresholdValue sql.NullFloat64
var dimensionsRaw []byte
var resolvedAt sql.NullTime
if err := rows.Scan(
&ev.ID,
&ev.RuleID,
&ev.Severity,
&ev.Status,
&ev.Title,
&ev.Description,
&metricValue,
&thresholdValue,
&dimensionsRaw,
&ev.FiredAt,
&resolvedAt,
&ev.EmailSent,
&ev.CreatedAt,
); err != nil {
return nil, err
}
if metricValue.Valid {
v := metricValue.Float64
ev.MetricValue = &v
}
if thresholdValue.Valid {
v := thresholdValue.Float64
ev.ThresholdValue = &v
}
if resolvedAt.Valid {
v := resolvedAt.Time
ev.ResolvedAt = &v
}
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
ev.Dimensions = decoded
}
}
out = append(out, &ev)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return nil, fmt.Errorf("invalid event id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE id = $1`
row := r.db.QueryRowContext(ctx, q, eventID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return nil, fmt.Errorf("invalid rule id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE rule_id = $1 AND status = $2
ORDER BY fired_at DESC
LIMIT 1`
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return nil, fmt.Errorf("invalid rule id")
}
q := `
SELECT
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
FROM ops_alert_events
WHERE rule_id = $1
ORDER BY fired_at DESC
LIMIT 1`
row := r.db.QueryRowContext(ctx, q, ruleID)
ev, err := scanOpsAlertEvent(row)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
return ev, nil
}
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if event == nil {
return nil, fmt.Errorf("nil event")
}
dimensionsArg, err := opsNullJSONMap(event.Dimensions)
if err != nil {
return nil, err
}
q := `
INSERT INTO ops_alert_events (
rule_id,
severity,
status,
title,
description,
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
)
RETURNING
id,
COALESCE(rule_id, 0),
COALESCE(severity, ''),
COALESCE(status, ''),
COALESCE(title, ''),
COALESCE(description, ''),
metric_value,
threshold_value,
dimensions,
fired_at,
resolved_at,
email_sent,
created_at`
row := r.db.QueryRowContext(
ctx,
q,
opsNullInt64(&event.RuleID),
opsNullString(event.Severity),
opsNullString(event.Status),
opsNullString(event.Title),
opsNullString(event.Description),
opsNullFloat64(event.MetricValue),
opsNullFloat64(event.ThresholdValue),
dimensionsArg,
event.FiredAt,
opsNullTime(event.ResolvedAt),
event.EmailSent,
)
return scanOpsAlertEvent(row)
}
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return fmt.Errorf("invalid event id")
}
if strings.TrimSpace(status) == "" {
return fmt.Errorf("invalid status")
}
q := `
UPDATE ops_alert_events
SET status = $2,
resolved_at = $3
WHERE id = $1`
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
return err
}
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if eventID <= 0 {
return fmt.Errorf("invalid event id")
}
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
return err
}
type opsAlertEventRow interface {
Scan(dest ...any) error
}
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if input == nil {
return nil, fmt.Errorf("nil input")
}
if input.RuleID <= 0 {
return nil, fmt.Errorf("invalid rule_id")
}
platform := strings.TrimSpace(input.Platform)
if platform == "" {
return nil, fmt.Errorf("invalid platform")
}
if input.Until.IsZero() {
return nil, fmt.Errorf("invalid until")
}
q := `
INSERT INTO ops_alert_silences (
rule_id,
platform,
group_id,
region,
until,
reason,
created_by,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
row := r.db.QueryRowContext(
ctx,
q,
input.RuleID,
platform,
opsNullInt64(input.GroupID),
opsNullString(input.Region),
input.Until,
opsNullString(input.Reason),
opsNullInt64(input.CreatedBy),
)
var out service.OpsAlertSilence
var groupID sql.NullInt64
var region sql.NullString
var createdBy sql.NullInt64
if err := row.Scan(
&out.ID,
&out.RuleID,
&out.Platform,
&groupID,
&region,
&out.Until,
&out.Reason,
&createdBy,
&out.CreatedAt,
); err != nil {
return nil, err
}
if groupID.Valid {
v := groupID.Int64
out.GroupID = &v
}
if region.Valid {
v := strings.TrimSpace(region.String)
if v != "" {
out.Region = &v
}
}
if createdBy.Valid {
v := createdBy.Int64
out.CreatedBy = &v
}
return &out, nil
}
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
if r == nil || r.db == nil {
return false, fmt.Errorf("nil ops repository")
}
if ruleID <= 0 {
return false, fmt.Errorf("invalid rule id")
}
platform = strings.TrimSpace(platform)
if platform == "" {
return false, nil
}
if now.IsZero() {
now = time.Now().UTC()
}
q := `
SELECT 1
FROM ops_alert_silences
WHERE rule_id = $1
AND platform = $2
AND (group_id IS NOT DISTINCT FROM $3)
AND (region IS NOT DISTINCT FROM $4)
AND until > $5
LIMIT 1`
var dummy int
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
}
return false, err
}
return true, nil
}
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
var thresholdValue sql.NullFloat64
var dimensionsRaw []byte
var resolvedAt sql.NullTime
if err := row.Scan(
&ev.ID,
&ev.RuleID,
&ev.Severity,
&ev.Status,
&ev.Title,
&ev.Description,
&metricValue,
&thresholdValue,
&dimensionsRaw,
&ev.FiredAt,
&resolvedAt,
&ev.EmailSent,
&ev.CreatedAt,
); err != nil {
return nil, err
}
if metricValue.Valid {
v := metricValue.Float64
ev.MetricValue = &v
}
if thresholdValue.Valid {
v := thresholdValue.Float64
ev.ThresholdValue = &v
}
if resolvedAt.Valid {
v := resolvedAt.Time
ev.ResolvedAt = &v
}
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
var decoded map[string]any
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
ev.Dimensions = decoded
}
}
return &ev, nil
}
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
clauses := []string{"1=1"}
args := []any{}
if filter == nil {
return "WHERE " + strings.Join(clauses, " AND "), args
}
if status := strings.TrimSpace(filter.Status); status != "" {
args = append(args, status)
clauses = append(clauses, "status = $"+itoa(len(args)))
}
if severity := strings.TrimSpace(filter.Severity); severity != "" {
args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args)))
}
if filter.EmailSent != nil {
args = append(args, *filter.EmailSent)
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
}
if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, *filter.EndTime)
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
}
// Cursor pagination (descending by fired_at, then id)
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
args = append(args, *filter.BeforeFiredAt)
tsArg := "$" + itoa(len(args))
args = append(args, *filter.BeforeID)
idArg := "$" + itoa(len(args))
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
}
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform)
clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
}
if filter.GroupID != nil && *filter.GroupID > 0 {
args = append(args, fmt.Sprintf("%d", *filter.GroupID))
clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
}
return "WHERE " + strings.Join(clauses, " AND "), args
}
func opsNullJSONMap(v map[string]any) (any, error) {
if v == nil {
return sql.NullString{}, nil
}
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
if len(b) == 0 {
return sql.NullString{}, nil
}
return sql.NullString{String: string(b), Valid: true}, nil
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,79 @@
package repository
import (
"context"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")
q := `
SELECT
` + rangeExpr + ` AS range,
COALESCE(COUNT(*), 0) AS count,
` + orderExpr + ` AS ord
FROM usage_logs ul
` + join + `
` + where + `
AND ul.duration_ms IS NOT NULL
GROUP BY 1, 3
ORDER BY 3 ASC`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
var total int64
for rows.Next() {
var label string
var count int64
var _ord int
if err := rows.Scan(&label, &count, &_ord); err != nil {
return nil, err
}
counts[label] = count
total += count
}
if err := rows.Err(); err != nil {
return nil, err
}
buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
for _, label := range latencyHistogramOrderedRanges {
buckets = append(buckets, &service.OpsLatencyHistogramBucket{
Range: label,
Count: counts[label],
})
}
return &service.OpsLatencyHistogramResponse{
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(filter.Platform),
GroupID: filter.GroupID,
TotalRequests: total,
Buckets: buckets,
}, nil
}

View File

@@ -0,0 +1,64 @@
package repository
import (
"fmt"
"strings"
)
type latencyHistogramBucket struct {
upperMs int
label string
}
var latencyHistogramBuckets = []latencyHistogramBucket{
{upperMs: 100, label: "0-100ms"},
{upperMs: 200, label: "100-200ms"},
{upperMs: 500, label: "200-500ms"},
{upperMs: 1000, label: "500-1000ms"},
{upperMs: 2000, label: "1000-2000ms"},
{upperMs: 0, label: "2000ms+"}, // default bucket
}
var latencyHistogramOrderedRanges = func() []string {
out := make([]string, 0, len(latencyHistogramBuckets))
for _, b := range latencyHistogramBuckets {
out = append(out, b.label)
}
return out
}()
func latencyHistogramRangeCaseExpr(column string) string {
var sb strings.Builder
_, _ = sb.WriteString("CASE\n")
for _, b := range latencyHistogramBuckets {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
}
// Default bucket.
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
_, _ = sb.WriteString("END")
return sb.String()
}
func latencyHistogramRangeOrderCaseExpr(column string) string {
var sb strings.Builder
_, _ = sb.WriteString("CASE\n")
order := 1
for _, b := range latencyHistogramBuckets {
if b.upperMs <= 0 {
continue
}
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
order++
}
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
_, _ = sb.WriteString("END")
return sb.String()
}

View File

@@ -0,0 +1,14 @@
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
for i, b := range latencyHistogramBuckets {
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
}
}

View File

@@ -0,0 +1,436 @@
package repository
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
window := input.WindowMinutes
if window <= 0 {
window = 1
}
createdAt := input.CreatedAt
if createdAt.IsZero() {
createdAt = time.Now().UTC()
}
q := `
INSERT INTO ops_system_metrics (
created_at,
window_minutes,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
qps,
tps,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
cpu_usage_percent,
memory_used_mb,
memory_total_mb,
memory_usage_percent,
db_ok,
redis_ok,
redis_conn_total,
redis_conn_idle,
db_conn_active,
db_conn_idle,
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
) VALUES (
$1,$2,$3,$4,
$5,$6,$7,$8,
$9,$10,$11,
$12,$13,$14,
$15,$16,$17,$18,$19,$20,
$21,$22,$23,$24,$25,$26,
$27,$28,$29,$30,
$31,$32,
$33,$34,
$35,$36,$37,
$38,$39
)`
_, err := r.db.ExecContext(
ctx,
q,
createdAt,
window,
opsNullString(input.Platform),
opsNullInt64(input.GroupID),
input.SuccessCount,
input.ErrorCountTotal,
input.BusinessLimitedCount,
input.ErrorCountSLA,
input.UpstreamErrorCountExcl429529,
input.Upstream429Count,
input.Upstream529Count,
input.TokenConsumed,
opsNullFloat64(input.QPS),
opsNullFloat64(input.TPS),
opsNullInt(input.DurationP50Ms),
opsNullInt(input.DurationP90Ms),
opsNullInt(input.DurationP95Ms),
opsNullInt(input.DurationP99Ms),
opsNullFloat64(input.DurationAvgMs),
opsNullInt(input.DurationMaxMs),
opsNullInt(input.TTFTP50Ms),
opsNullInt(input.TTFTP90Ms),
opsNullInt(input.TTFTP95Ms),
opsNullInt(input.TTFTP99Ms),
opsNullFloat64(input.TTFTAvgMs),
opsNullInt(input.TTFTMaxMs),
opsNullFloat64(input.CPUUsagePercent),
opsNullInt(input.MemoryUsedMB),
opsNullInt(input.MemoryTotalMB),
opsNullFloat64(input.MemoryUsagePercent),
opsNullBool(input.DBOK),
opsNullBool(input.RedisOK),
opsNullInt(input.RedisConnTotal),
opsNullInt(input.RedisConnIdle),
opsNullInt(input.DBConnActive),
opsNullInt(input.DBConnIdle),
opsNullInt(input.DBConnWaiting),
opsNullInt(input.GoroutineCount),
opsNullInt(input.ConcurrencyQueueDepth),
)
return err
}
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if windowMinutes <= 0 {
windowMinutes = 1
}
q := `
SELECT
id,
created_at,
window_minutes,
cpu_usage_percent,
memory_used_mb,
memory_total_mb,
memory_usage_percent,
db_ok,
redis_ok,
redis_conn_total,
redis_conn_idle,
db_conn_active,
db_conn_idle,
db_conn_waiting,
goroutine_count,
concurrency_queue_depth
FROM ops_system_metrics
WHERE window_minutes = $1
AND platform IS NULL
AND group_id IS NULL
ORDER BY created_at DESC
LIMIT 1`
var out service.OpsSystemMetricsSnapshot
var cpu sql.NullFloat64
var memUsed sql.NullInt64
var memTotal sql.NullInt64
var memPct sql.NullFloat64
var dbOK sql.NullBool
var redisOK sql.NullBool
var redisTotal sql.NullInt64
var redisIdle sql.NullInt64
var dbActive sql.NullInt64
var dbIdle sql.NullInt64
var dbWaiting sql.NullInt64
var goroutines sql.NullInt64
var queueDepth sql.NullInt64
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
&out.ID,
&out.CreatedAt,
&out.WindowMinutes,
&cpu,
&memUsed,
&memTotal,
&memPct,
&dbOK,
&redisOK,
&redisTotal,
&redisIdle,
&dbActive,
&dbIdle,
&dbWaiting,
&goroutines,
&queueDepth,
); err != nil {
return nil, err
}
if cpu.Valid {
v := cpu.Float64
out.CPUUsagePercent = &v
}
if memUsed.Valid {
v := memUsed.Int64
out.MemoryUsedMB = &v
}
if memTotal.Valid {
v := memTotal.Int64
out.MemoryTotalMB = &v
}
if memPct.Valid {
v := memPct.Float64
out.MemoryUsagePercent = &v
}
if dbOK.Valid {
v := dbOK.Bool
out.DBOK = &v
}
if redisOK.Valid {
v := redisOK.Bool
out.RedisOK = &v
}
if redisTotal.Valid {
v := int(redisTotal.Int64)
out.RedisConnTotal = &v
}
if redisIdle.Valid {
v := int(redisIdle.Int64)
out.RedisConnIdle = &v
}
if dbActive.Valid {
v := int(dbActive.Int64)
out.DBConnActive = &v
}
if dbIdle.Valid {
v := int(dbIdle.Int64)
out.DBConnIdle = &v
}
if dbWaiting.Valid {
v := int(dbWaiting.Int64)
out.DBConnWaiting = &v
}
if goroutines.Valid {
v := int(goroutines.Int64)
out.GoroutineCount = &v
}
if queueDepth.Valid {
v := int(queueDepth.Int64)
out.ConcurrencyQueueDepth = &v
}
return &out, nil
}
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if input == nil {
return fmt.Errorf("nil input")
}
if input.JobName == "" {
return fmt.Errorf("job_name required")
}
q := `
INSERT INTO ops_job_heartbeats (
job_name,
last_run_at,
last_success_at,
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
last_error_at = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
END,
last_error = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
last_result = CASE
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
ELSE ops_job_heartbeats.last_result
END,
updated_at = NOW()`
_, err := r.db.ExecContext(
ctx,
q,
input.JobName,
opsNullTime(input.LastRunAt),
opsNullTime(input.LastSuccessAt),
opsNullTime(input.LastErrorAt),
opsNullString(input.LastError),
opsNullInt(input.LastDurationMs),
opsNullString(input.LastResult),
)
return err
}
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
q := `
SELECT
job_name,
last_run_at,
last_success_at,
last_error_at,
last_error,
last_duration_ms,
last_result,
updated_at
FROM ops_job_heartbeats
ORDER BY job_name ASC`
rows, err := r.db.QueryContext(ctx, q)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]*service.OpsJobHeartbeat, 0, 8)
for rows.Next() {
var item service.OpsJobHeartbeat
var lastRun sql.NullTime
var lastSuccess sql.NullTime
var lastErrorAt sql.NullTime
var lastError sql.NullString
var lastDuration sql.NullInt64
var lastResult sql.NullString
if err := rows.Scan(
&item.JobName,
&lastRun,
&lastSuccess,
&lastErrorAt,
&lastError,
&lastDuration,
&lastResult,
&item.UpdatedAt,
); err != nil {
return nil, err
}
if lastRun.Valid {
v := lastRun.Time
item.LastRunAt = &v
}
if lastSuccess.Valid {
v := lastSuccess.Time
item.LastSuccessAt = &v
}
if lastErrorAt.Valid {
v := lastErrorAt.Time
item.LastErrorAt = &v
}
if lastError.Valid {
v := lastError.String
item.LastError = &v
}
if lastDuration.Valid {
v := lastDuration.Int64
item.LastDurationMs = &v
}
if lastResult.Valid {
v := lastResult.String
item.LastResult = &v
}
out = append(out, &item)
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
func opsNullBool(v *bool) any {
if v == nil {
return sql.NullBool{}
}
return sql.NullBool{Bool: *v, Valid: true}
}
func opsNullFloat64(v *float64) any {
if v == nil {
return sql.NullFloat64{}
}
return sql.NullFloat64{Float64: *v, Valid: true}
}
func opsNullTime(v *time.Time) any {
if v == nil || v.IsZero() {
return sql.NullTime{}
}
return sql.NullTime{Time: *v, Valid: true}
}

View File

@@ -0,0 +1,363 @@
package repository
import (
"context"
"database/sql"
"fmt"
"time"
)
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
return nil
}
start := startTime.UTC()
end := endTime.UTC()
// NOTE:
// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
// - We emit three dimension granularities via GROUPING SETS:
// 1) overall: (bucket_start)
// 2) platform: (bucket_start, platform)
// 3) group: (bucket_start, platform, group_id)
//
// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
// unique index; our ON CONFLICT target must match that expression set.
q := `
WITH usage_base AS (
SELECT
date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
g.platform AS platform,
ul.group_id AS group_id,
ul.duration_ms AS duration_ms,
ul.first_token_ms AS first_token_ms,
(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
FROM usage_logs ul
JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
),
usage_agg AS (
SELECT
bucket_start,
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
COUNT(*) AS success_count,
COALESCE(SUM(tokens), 0) AS token_consumed,
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
MAX(duration_ms) AS duration_max_ms,
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
MAX(first_token_ms) AS ttft_max_ms
FROM usage_base
GROUP BY GROUPING SETS (
(bucket_start),
(bucket_start, platform),
(bucket_start, platform, group_id)
)
),
error_base AS (
SELECT
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
COALESCE(platform, 'unknown') AS platform,
group_id AS group_id,
is_business_limited AS is_business_limited,
error_owner AS error_owner,
status_code AS client_status_code,
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
FROM ops_error_logs
-- Exclude count_tokens requests from error metrics as they are informational probes
WHERE created_at >= $1 AND created_at < $2
AND is_count_tokens = FALSE
),
error_agg AS (
SELECT
bucket_start,
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
FROM error_base
GROUP BY GROUPING SETS (
(bucket_start),
(bucket_start, platform),
(bucket_start, platform, group_id)
)
HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
),
combined AS (
SELECT
COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
COALESCE(u.platform, e.platform) AS platform,
COALESCE(u.group_id, e.group_id) AS group_id,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count_total, 0) AS error_count_total,
COALESCE(e.business_limited_count, 0) AS business_limited_count,
COALESCE(e.error_count_sla, 0) AS error_count_sla,
COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
COALESCE(e.upstream_529_count, 0) AS upstream_529_count,
COALESCE(u.token_consumed, 0) AS token_consumed,
u.duration_p50_ms,
u.duration_p90_ms,
u.duration_p95_ms,
u.duration_p99_ms,
u.duration_avg_ms,
u.duration_max_ms,
u.ttft_p50_ms,
u.ttft_p90_ms,
u.ttft_p95_ms,
u.ttft_p99_ms,
u.ttft_avg_ms,
u.ttft_max_ms
FROM usage_agg u
FULL OUTER JOIN error_agg e
ON u.bucket_start = e.bucket_start
AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
)
INSERT INTO ops_metrics_hourly (
bucket_start,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
computed_at
)
SELECT
bucket_start,
NULLIF(platform, '') AS platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms::int,
duration_p90_ms::int,
duration_p95_ms::int,
duration_p99_ms::int,
duration_avg_ms,
duration_max_ms::int,
ttft_p50_ms::int,
ttft_p90_ms::int,
ttft_p95_ms::int,
ttft_p99_ms::int,
ttft_avg_ms,
ttft_max_ms::int,
NOW()
FROM combined
WHERE bucket_start IS NOT NULL
AND (platform IS NULL OR platform <> '')
ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
success_count = EXCLUDED.success_count,
error_count_total = EXCLUDED.error_count_total,
business_limited_count = EXCLUDED.business_limited_count,
error_count_sla = EXCLUDED.error_count_sla,
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
upstream_429_count = EXCLUDED.upstream_429_count,
upstream_529_count = EXCLUDED.upstream_529_count,
token_consumed = EXCLUDED.token_consumed,
duration_p50_ms = EXCLUDED.duration_p50_ms,
duration_p90_ms = EXCLUDED.duration_p90_ms,
duration_p95_ms = EXCLUDED.duration_p95_ms,
duration_p99_ms = EXCLUDED.duration_p99_ms,
duration_avg_ms = EXCLUDED.duration_avg_ms,
duration_max_ms = EXCLUDED.duration_max_ms,
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
ttft_max_ms = EXCLUDED.ttft_max_ms,
computed_at = NOW()
`
_, err := r.db.ExecContext(ctx, q, start, end)
return err
}
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
if r == nil || r.db == nil {
return fmt.Errorf("nil ops repository")
}
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
return nil
}
start := startTime.UTC()
end := endTime.UTC()
q := `
INSERT INTO ops_metrics_daily (
bucket_date,
platform,
group_id,
success_count,
error_count_total,
business_limited_count,
error_count_sla,
upstream_error_count_excl_429_529,
upstream_429_count,
upstream_529_count,
token_consumed,
duration_p50_ms,
duration_p90_ms,
duration_p95_ms,
duration_p99_ms,
duration_avg_ms,
duration_max_ms,
ttft_p50_ms,
ttft_p90_ms,
ttft_p95_ms,
ttft_p99_ms,
ttft_avg_ms,
ttft_max_ms,
computed_at
)
SELECT
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
platform,
group_id,
COALESCE(SUM(success_count), 0) AS success_count,
COALESCE(SUM(error_count_total), 0) AS error_count_total,
COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
COALESCE(SUM(token_consumed), 0) AS token_consumed,
-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
MAX(duration_p95_ms) AS duration_p95_ms,
MAX(duration_p99_ms) AS duration_p99_ms,
SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
MAX(duration_max_ms) AS duration_max_ms,
ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
MAX(ttft_p95_ms) AS ttft_p95_ms,
MAX(ttft_p99_ms) AS ttft_p99_ms,
SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
MAX(ttft_max_ms) AS ttft_max_ms,
NOW()
FROM ops_metrics_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
GROUP BY 1, 2, 3
ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
success_count = EXCLUDED.success_count,
error_count_total = EXCLUDED.error_count_total,
business_limited_count = EXCLUDED.business_limited_count,
error_count_sla = EXCLUDED.error_count_sla,
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
upstream_429_count = EXCLUDED.upstream_429_count,
upstream_529_count = EXCLUDED.upstream_529_count,
token_consumed = EXCLUDED.token_consumed,
duration_p50_ms = EXCLUDED.duration_p50_ms,
duration_p90_ms = EXCLUDED.duration_p90_ms,
duration_p95_ms = EXCLUDED.duration_p95_ms,
duration_p99_ms = EXCLUDED.duration_p99_ms,
duration_avg_ms = EXCLUDED.duration_avg_ms,
duration_max_ms = EXCLUDED.duration_max_ms,
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
ttft_max_ms = EXCLUDED.ttft_max_ms,
computed_at = NOW()
`
_, err := r.db.ExecContext(ctx, q, start, end)
return err
}
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
if r == nil || r.db == nil {
return time.Time{}, false, fmt.Errorf("nil ops repository")
}
var value sql.NullTime
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
return time.Time{}, false, err
}
if !value.Valid {
return time.Time{}, false, nil
}
return value.Time.UTC(), true, nil
}
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
if r == nil || r.db == nil {
return time.Time{}, false, fmt.Errorf("nil ops repository")
}
var value sql.NullTime
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
return time.Time{}, false, err
}
if !value.Valid {
return time.Time{}, false, nil
}
t := value.Time.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
}

View File

@@ -0,0 +1,129 @@
package repository
import (
"context"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
if start.After(end) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
window := end.Sub(start)
if window <= 0 {
return nil, fmt.Errorf("invalid time window")
}
if window > time.Hour {
return nil, fmt.Errorf("window too large")
}
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
q := `
WITH usage_buckets AS (
SELECT
date_trunc('minute', ul.created_at) AS bucket,
COALESCE(COUNT(*), 0) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT
date_trunc('minute', created_at) AS bucket,
COALESCE(COUNT(*), 0) AS error_count
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
combined AS (
SELECT
COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(u.token_sum, 0) AS token_sum,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
COALESCE(SUM(success_count), 0) AS success_total,
COALESCE(SUM(error_count), 0) AS error_total,
COALESCE(SUM(token_sum), 0) AS token_total,
COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
FROM combined`
args := append(usageArgs, errorArgs...)
var successCount int64
var errorTotal int64
var tokenConsumed int64
var peakRequestsPerMin int64
var peakTokensPerMin int64
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
&successCount,
&errorTotal,
&tokenConsumed,
&peakRequestsPerMin,
&peakTokensPerMin,
); err != nil {
return nil, err
}
windowSeconds := window.Seconds()
if windowSeconds <= 0 {
windowSeconds = 1
}
requestCountTotal := successCount + errorTotal
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
// This remains "within the selected window" since end=start+window.
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
if err != nil {
return nil, err
}
qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)
return &service.OpsRealtimeTrafficSummary{
StartTime: start,
EndTime: end,
Platform: strings.TrimSpace(filter.Platform),
GroupID: filter.GroupID,
QPS: service.OpsRateSummary{
Current: qpsCurrent,
Peak: qpsPeak,
Avg: qpsAvg,
},
TPS: service.OpsRateSummary{
Current: tpsCurrent,
Peak: tpsPeak,
Avg: tpsAvg,
},
}, nil
}

View File

@@ -0,0 +1,286 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
if r == nil || r.db == nil {
return nil, 0, fmt.Errorf("nil ops repository")
}
page, pageSize, startTime, endTime := filter.Normalize()
offset := (page - 1) * pageSize
conditions := make([]string, 0, 16)
args := make([]any, 0, 24)
// Placeholders $1/$2 reserved for time window inside the CTE.
args = append(args, startTime.UTC(), endTime.UTC())
addCondition := func(condition string, values ...any) {
conditions = append(conditions, condition)
args = append(args, values...)
}
if filter != nil {
if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
return nil, 0, fmt.Errorf("invalid kind")
}
addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
}
if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
}
if filter.GroupID != nil && *filter.GroupID > 0 {
addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
}
if filter.UserID != nil && *filter.UserID > 0 {
addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
}
if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
}
if filter.AccountID != nil && *filter.AccountID > 0 {
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
}
if model := strings.TrimSpace(filter.Model); model != "" {
addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
}
if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
}
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + strings.ToLower(q) + "%"
startIdx := len(args) + 1
addCondition(
fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
startIdx, startIdx+1, startIdx+2,
),
like, like, like,
)
}
if filter.MinDurationMs != nil {
addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
}
if filter.MaxDurationMs != nil {
addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
}
}
where := ""
if len(conditions) > 0 {
where = "WHERE " + strings.Join(conditions, " AND ")
}
cte := `
WITH combined AS (
SELECT
'success'::TEXT AS kind,
ul.created_at AS created_at,
ul.request_id AS request_id,
COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
ul.model AS model,
ul.duration_ms AS duration_ms,
NULL::INT AS status_code,
NULL::BIGINT AS error_id,
NULL::TEXT AS phase,
NULL::TEXT AS severity,
NULL::TEXT AS message,
ul.user_id AS user_id,
ul.api_key_id AS api_key_id,
ul.account_id AS account_id,
ul.group_id AS group_id,
ul.stream AS stream
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
UNION ALL
SELECT
'error'::TEXT AS kind,
o.created_at AS created_at,
COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
o.model AS model,
o.duration_ms AS duration_ms,
o.status_code AS status_code,
o.id AS error_id,
o.error_phase AS phase,
o.severity AS severity,
o.error_message AS message,
o.user_id AS user_id,
o.api_key_id AS api_key_id,
o.account_id AS account_id,
o.group_id AS group_id,
o.stream AS stream
FROM ops_error_logs o
LEFT JOIN groups g ON g.id = o.group_id
LEFT JOIN accounts a ON a.id = o.account_id
WHERE o.created_at >= $1 AND o.created_at < $2
AND COALESCE(o.status_code, 0) >= 400
)
`
countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
var total int64
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
if err == sql.ErrNoRows {
total = 0
} else {
return nil, 0, err
}
}
sort := "ORDER BY created_at DESC"
if filter != nil {
switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
case "", "created_at_desc":
// default
case "duration_desc":
sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
default:
return nil, 0, fmt.Errorf("invalid sort")
}
}
listQuery := fmt.Sprintf(`
%s
SELECT
kind,
created_at,
request_id,
platform,
model,
duration_ms,
status_code,
error_id,
phase,
severity,
message,
user_id,
api_key_id,
account_id,
group_id,
stream
FROM combined
%s
%s
LIMIT $%d OFFSET $%d
`, cte, where, sort, len(args)+1, len(args)+2)
listArgs := append(append([]any{}, args...), pageSize, offset)
rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
if err != nil {
return nil, 0, err
}
defer func() { _ = rows.Close() }()
toIntPtr := func(v sql.NullInt64) *int {
if !v.Valid {
return nil
}
i := int(v.Int64)
return &i
}
toInt64Ptr := func(v sql.NullInt64) *int64 {
if !v.Valid {
return nil
}
i := v.Int64
return &i
}
out := make([]*service.OpsRequestDetail, 0, pageSize)
for rows.Next() {
var (
kind string
createdAt time.Time
requestID sql.NullString
platform sql.NullString
model sql.NullString
durationMs sql.NullInt64
statusCode sql.NullInt64
errorID sql.NullInt64
phase sql.NullString
severity sql.NullString
message sql.NullString
userID sql.NullInt64
apiKeyID sql.NullInt64
accountID sql.NullInt64
groupID sql.NullInt64
stream bool
)
if err := rows.Scan(
&kind,
&createdAt,
&requestID,
&platform,
&model,
&durationMs,
&statusCode,
&errorID,
&phase,
&severity,
&message,
&userID,
&apiKeyID,
&accountID,
&groupID,
&stream,
); err != nil {
return nil, 0, err
}
item := &service.OpsRequestDetail{
Kind: service.OpsRequestKind(kind),
CreatedAt: createdAt,
RequestID: strings.TrimSpace(requestID.String),
Platform: strings.TrimSpace(platform.String),
Model: strings.TrimSpace(model.String),
DurationMs: toIntPtr(durationMs),
StatusCode: toIntPtr(statusCode),
ErrorID: toInt64Ptr(errorID),
Phase: phase.String,
Severity: severity.String,
Message: message.String,
UserID: toInt64Ptr(userID),
APIKeyID: toInt64Ptr(apiKeyID),
AccountID: toInt64Ptr(accountID),
GroupID: toInt64Ptr(groupID),
Stream: stream,
}
if item.Platform == "" {
item.Platform = "unknown"
}
out = append(out, item)
}
if err := rows.Err(); err != nil {
return nil, 0, err
}
return out, total, nil
}

View File

@@ -0,0 +1,573 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
// Keep a small, predictable set of supported buckets for now.
bucketSeconds = 60
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
errorBucketExpr := opsBucketExprForError(bucketSeconds)
q := `
WITH usage_buckets AS (
SELECT ` + usageBucketExpr + ` AS bucket,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
` + usageJoin + `
` + usageWhere + `
GROUP BY 1
),
error_buckets AS (
SELECT ` + errorBucketExpr + ` AS bucket,
COUNT(*) AS error_count
FROM ops_error_logs
` + errorWhere + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_buckets u
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
)
SELECT
bucket,
(success_count + error_count) AS request_count,
token_consumed
FROM combined
ORDER BY bucket ASC`
args := append(usageArgs, errorArgs...)
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
points := make([]*service.OpsThroughputTrendPoint, 0, 256)
for rows.Next() {
var bucket time.Time
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
denom := float64(bucketSeconds)
if denom <= 0 {
denom = 60
}
qps := roundTo1DP(float64(requests) / denom)
tps := roundTo1DP(float64(tokenConsumed) / denom)
points = append(points, &service.OpsThroughputTrendPoint{
BucketStart: bucket.UTC(),
RequestCount: requests,
TokenConsumed: tokenConsumed,
QPS: qps,
TPS: tps,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
// Fill missing buckets with zeros so charts render continuous timelines.
points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)
var byPlatform []*service.OpsThroughputPlatformBreakdownItem
var topGroups []*service.OpsThroughputGroupBreakdownItem
platform := ""
if filter != nil {
platform = strings.TrimSpace(strings.ToLower(filter.Platform))
}
groupID := (*int64)(nil)
if filter != nil {
groupID = filter.GroupID
}
// Drilldown helpers:
// - No platform/group: totals by platform
// - Platform selected but no group: top groups in that platform
if platform == "" && (groupID == nil || *groupID <= 0) {
items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
if err != nil {
return nil, err
}
byPlatform = items
} else if platform != "" && (groupID == nil || *groupID <= 0) {
items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
if err != nil {
return nil, err
}
topGroups = items
}
return &service.OpsThroughputTrendResponse{
Bucket: opsBucketLabel(bucketSeconds),
Points: points,
ByPlatform: byPlatform,
TopGroups: topGroups,
}, nil
}
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
q := `
WITH usage_totals AS (
SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
LEFT JOIN groups g ON g.id = ul.group_id
LEFT JOIN accounts a ON a.id = ul.account_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
GROUP BY 1
),
error_totals AS (
SELECT platform,
COUNT(*) AS error_count
FROM ops_error_logs
WHERE created_at >= $1 AND created_at < $2
AND COALESCE(status_code, 0) >= 400
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.platform, e.platform) AS platform,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_totals u
FULL OUTER JOIN error_totals e ON u.platform = e.platform
)
SELECT platform, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE platform IS NOT NULL AND platform <> ''
ORDER BY request_count DESC`
rows, err := r.db.QueryContext(ctx, q, start, end)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
for rows.Next() {
var platform string
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&platform, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
items = append(items, &service.OpsThroughputPlatformBreakdownItem{
Platform: platform,
RequestCount: requests,
TokenConsumed: tokenConsumed,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
if strings.TrimSpace(platform) == "" {
return nil, nil
}
if limit <= 0 || limit > 100 {
limit = 10
}
q := `
WITH usage_totals AS (
SELECT ul.group_id AS group_id,
g.name AS group_name,
COUNT(*) AS success_count,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
FROM usage_logs ul
JOIN groups g ON g.id = ul.group_id
WHERE ul.created_at >= $1 AND ul.created_at < $2
AND g.platform = $3
GROUP BY 1, 2
),
error_totals AS (
SELECT group_id,
COUNT(*) AS error_count
FROM ops_error_logs
WHERE created_at >= $1 AND created_at < $2
AND platform = $3
AND group_id IS NOT NULL
AND COALESCE(status_code, 0) >= 400
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
GROUP BY 1
),
combined AS (
SELECT COALESCE(u.group_id, e.group_id) AS group_id,
COALESCE(u.group_name, g2.name, '') AS group_name,
COALESCE(u.success_count, 0) AS success_count,
COALESCE(e.error_count, 0) AS error_count,
COALESCE(u.token_consumed, 0) AS token_consumed
FROM usage_totals u
FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
)
SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
FROM combined
WHERE group_id IS NOT NULL
ORDER BY request_count DESC
LIMIT $4`
rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
for rows.Next() {
var groupID int64
var groupName sql.NullString
var requests int64
var tokens sql.NullInt64
if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
return nil, err
}
tokenConsumed := int64(0)
if tokens.Valid {
tokenConsumed = tokens.Int64
}
name := ""
if groupName.Valid {
name = groupName.String
}
items = append(items, &service.OpsThroughputGroupBreakdownItem{
GroupID: groupID,
GroupName: name,
RequestCount: requests,
TokenConsumed: tokenConsumed,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
func opsBucketExprForUsage(bucketSeconds int) string {
switch bucketSeconds {
case 3600:
return "date_trunc('hour', ul.created_at)"
case 300:
// 5-minute buckets in UTC.
return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
default:
return "date_trunc('minute', ul.created_at)"
}
}
func opsBucketExprForError(bucketSeconds int) string {
switch bucketSeconds {
case 3600:
return "date_trunc('hour', created_at)"
case 300:
return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
default:
return "date_trunc('minute', created_at)"
}
}
func opsBucketLabel(bucketSeconds int) string {
if bucketSeconds <= 0 {
return "1m"
}
if bucketSeconds%3600 == 0 {
h := bucketSeconds / 3600
if h <= 0 {
h = 1
}
return fmt.Sprintf("%dh", h)
}
m := bucketSeconds / 60
if m <= 0 {
m = 1
}
return fmt.Sprintf("%dm", m)
}
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
t = t.UTC()
if bucketSeconds <= 0 {
bucketSeconds = 60
}
secs := t.Unix()
floored := secs - (secs % int64(bucketSeconds))
return time.Unix(floored, 0).UTC()
}
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if !start.Before(end) {
return points
}
endMinus := end.Add(-time.Nanosecond)
if endMinus.Before(start) {
return points
}
first := opsFloorToBucketStart(start, bucketSeconds)
last := opsFloorToBucketStart(endMinus, bucketSeconds)
step := time.Duration(bucketSeconds) * time.Second
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
for _, p := range points {
if p == nil {
continue
}
existing[p.BucketStart.UTC().Unix()] = p
}
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
if p, ok := existing[cursor.Unix()]; ok && p != nil {
out = append(out, p)
continue
}
out = append(out, &service.OpsThroughputTrendPoint{
BucketStart: cursor,
RequestCount: 0,
TokenConsumed: 0,
QPS: 0,
TPS: 0,
})
}
return out
}
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
bucketSeconds = 60
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
where, args, _ := buildErrorWhere(filter, start, end, 1)
bucketExpr := opsBucketExprForError(bucketSeconds)
q := `
SELECT
` + bucketExpr + ` AS bucket,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
FROM ops_error_logs
` + where + `
GROUP BY 1
ORDER BY 1 ASC`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
points := make([]*service.OpsErrorTrendPoint, 0, 256)
for rows.Next() {
var bucket time.Time
var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
return nil, err
}
points = append(points, &service.OpsErrorTrendPoint{
BucketStart: bucket.UTC(),
ErrorCountTotal: total,
BusinessLimitedCount: businessLimited,
ErrorCountSLA: sla,
UpstreamErrorCountExcl429529: upstreamExcl,
Upstream429Count: upstream429,
Upstream529Count: upstream529,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)
return &service.OpsErrorTrendResponse{
Bucket: opsBucketLabel(bucketSeconds),
Points: points,
}, nil
}
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
if bucketSeconds <= 0 {
bucketSeconds = 60
}
if !start.Before(end) {
return points
}
endMinus := end.Add(-time.Nanosecond)
if endMinus.Before(start) {
return points
}
first := opsFloorToBucketStart(start, bucketSeconds)
last := opsFloorToBucketStart(endMinus, bucketSeconds)
step := time.Duration(bucketSeconds) * time.Second
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
for _, p := range points {
if p == nil {
continue
}
existing[p.BucketStart.UTC().Unix()] = p
}
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
if p, ok := existing[cursor.Unix()]; ok && p != nil {
out = append(out, p)
continue
}
out = append(out, &service.OpsErrorTrendPoint{
BucketStart: cursor,
ErrorCountTotal: 0,
BusinessLimitedCount: 0,
ErrorCountSLA: 0,
UpstreamErrorCountExcl429529: 0,
Upstream429Count: 0,
Upstream529Count: 0,
})
}
return out
}
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
where, args, _ := buildErrorWhere(filter, start, end, 1)
q := `
SELECT
COALESCE(upstream_status_code, status_code, 0) AS status_code,
COUNT(*) AS total,
COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
FROM ops_error_logs
` + where + `
AND COALESCE(status_code, 0) >= 400
GROUP BY 1
ORDER BY total DESC
LIMIT 20`
rows, err := r.db.QueryContext(ctx, q, args...)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
items := make([]*service.OpsErrorDistributionItem, 0, 16)
var total int64
for rows.Next() {
var statusCode int
var cntTotal, cntSLA, cntBiz int64
if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
return nil, err
}
total += cntTotal
items = append(items, &service.OpsErrorDistributionItem{
StatusCode: statusCode,
Total: cntTotal,
SLA: cntSLA,
BusinessLimited: cntBiz,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return &service.OpsErrorDistributionResponse{
Total: total,
Items: items,
}, nil
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
}
if filter == nil {
return nil, fmt.Errorf("nil filter")
}
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
return nil, fmt.Errorf("start_time/end_time required")
}
start := filter.StartTime.UTC()
end := filter.EndTime.UTC()
if start.After(end) {
return nil, fmt.Errorf("start_time must be <= end_time")
}
// Bound excessively large windows to prevent accidental heavy queries.
if end.Sub(start) > 24*time.Hour {
return nil, fmt.Errorf("window too large")
}
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
if err != nil {
return nil, err
}
return &service.OpsWindowStats{
StartTime: start,
EndTime: end,
SuccessCount: successCount,
ErrorCountTotal: errorTotal,
TokenConsumed: tokenConsumed,
}, nil
}

View File

@@ -0,0 +1,273 @@
package repository
import (
"context"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/promocode"
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type promoCodeRepository struct {
client *dbent.Client
}
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
return &promoCodeRepository{client: client}
}
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.Create().
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
}
created, err := builder.Save(ctx)
if err != nil {
return err
}
code.ID = created.ID
code.CreatedAt = created.CreatedAt
code.UpdatedAt = created.UpdatedAt
return nil
}
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.IDEQ(id)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
m, err := r.client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
client := clientFromContext(ctx, r.client)
m, err := client.PromoCode.Query().
Where(promocode.CodeEqualFold(code)).
ForUpdate().
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, service.ErrPromoCodeNotFound
}
return nil, err
}
return promoCodeEntityToService(m), nil
}
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
client := clientFromContext(ctx, r.client)
builder := client.PromoCode.UpdateOneID(code.ID).
SetCode(code.Code).
SetBonusAmount(code.BonusAmount).
SetMaxUses(code.MaxUses).
SetUsedCount(code.UsedCount).
SetStatus(code.Status).
SetNotes(code.Notes)
if code.ExpiresAt != nil {
builder.SetExpiresAt(*code.ExpiresAt)
} else {
builder.ClearExpiresAt()
}
updated, err := builder.Save(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return service.ErrPromoCodeNotFound
}
return err
}
code.UpdatedAt = updated.UpdatedAt
return nil
}
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
return err
}
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
return r.ListWithFilters(ctx, params, "", "")
}
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
q := r.client.PromoCode.Query()
if status != "" {
q = q.Where(promocode.StatusEQ(status))
}
if search != "" {
q = q.Where(promocode.CodeContainsFold(search))
}
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
codes, err := q.
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocode.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outCodes := promoCodeEntitiesToService(codes)
return outCodes, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
client := clientFromContext(ctx, r.client)
created, err := client.PromoCodeUsage.Create().
SetPromoCodeID(usage.PromoCodeID).
SetUserID(usage.UserID).
SetBonusAmount(usage.BonusAmount).
SetUsedAt(usage.UsedAt).
Save(ctx)
if err != nil {
return err
}
usage.ID = created.ID
return nil
}
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
m, err := r.client.PromoCodeUsage.Query().
Where(
promocodeusage.PromoCodeIDEQ(promoCodeID),
promocodeusage.UserIDEQ(userID),
).
Only(ctx)
if err != nil {
if dbent.IsNotFound(err) {
return nil, nil
}
return nil, err
}
return promoCodeUsageEntityToService(m), nil
}
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
q := r.client.PromoCodeUsage.Query().
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
total, err := q.Count(ctx)
if err != nil {
return nil, nil, err
}
usages, err := q.
WithUser().
Offset(params.Offset()).
Limit(params.Limit()).
Order(dbent.Desc(promocodeusage.FieldID)).
All(ctx)
if err != nil {
return nil, nil, err
}
outUsages := promoCodeUsageEntitiesToService(usages)
return outUsages, paginationResultFromTotal(int64(total), params), nil
}
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
client := clientFromContext(ctx, r.client)
_, err := client.PromoCode.UpdateOneID(id).
AddUsedCount(1).
Save(ctx)
return err
}
// Entity to Service conversions
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
if m == nil {
return nil
}
return &service.PromoCode{
ID: m.ID,
Code: m.Code,
BonusAmount: m.BonusAmount,
MaxUses: m.MaxUses,
UsedCount: m.UsedCount,
Status: m.Status,
ExpiresAt: m.ExpiresAt,
Notes: derefString(m.Notes),
CreatedAt: m.CreatedAt,
UpdatedAt: m.UpdatedAt,
}
}
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
out := make([]service.PromoCode, 0, len(models))
for i := range models {
if s := promoCodeEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
if m == nil {
return nil
}
out := &service.PromoCodeUsage{
ID: m.ID,
PromoCodeID: m.PromoCodeID,
UserID: m.UserID,
BonusAmount: m.BonusAmount,
UsedAt: m.UsedAt,
}
if m.Edges.User != nil {
out.User = userEntityToService(m.Edges.User)
}
return out
}
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
out := make([]service.PromoCodeUsage, 0, len(models))
for i := range models {
if s := promoCodeUsageEntityToService(models[i]); s != nil {
out = append(out, *s)
}
}
return out
}

View File

@@ -0,0 +1,74 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const proxyLatencyKeyPrefix = "proxy:latency:"
func proxyLatencyKey(proxyID int64) string {
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
}
type proxyLatencyCache struct {
rdb *redis.Client
}
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
return &proxyLatencyCache{rdb: rdb}
}
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
results := make(map[int64]*service.ProxyLatencyInfo)
if len(proxyIDs) == 0 {
return results, nil
}
keys := make([]string, 0, len(proxyIDs))
for _, id := range proxyIDs {
keys = append(keys, proxyLatencyKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return results, err
}
for i, raw := range values {
if raw == nil {
continue
}
var payload []byte
switch v := raw.(type) {
case string:
payload = []byte(v)
case []byte:
payload = v
default:
continue
}
var info service.ProxyLatencyInfo
if err := json.Unmarshal(payload, &info); err != nil {
continue
}
results[proxyIDs[i]] = &info
}
return results, nil
}
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
if info == nil {
return nil
}
payload, err := json.Marshal(info)
if err != nil {
return err
}
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
}

View File

@@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
}
const defaultIPInfoURL = "https://ipinfo.io/json"
const (
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
defaultProxyProbeTimeout = 30 * time.Second
)
type proxyProbeService struct {
ipInfoURL string
@@ -46,7 +50,7 @@ type proxyProbeService struct {
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
Timeout: 15 * time.Second,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
var ipInfo struct {
IP string `json:"ip"`
City string `json:"city"`
Region string `json:"region"`
Country string `json:"country"`
Status string `json:"status"`
Message string `json:"message"`
Query string `json:"query"`
City string `json:"city"`
Region string `json:"region"`
RegionName string `json:"regionName"`
Country string `json:"country"`
CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
if strings.ToLower(ipInfo.Status) != "success" {
if ipInfo.Message == "" {
ipInfo.Message = "ip-api request failed"
}
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
}
region := ipInfo.RegionName
if region == "" {
region = ipInfo.Region
}
return &service.ProxyExitInfo{
IP: ipInfo.IP,
City: ipInfo.City,
Region: ipInfo.Region,
Country: ipInfo.Country,
IP: ipInfo.Query,
City: ipInfo.City,
Region: region,
Country: ipInfo.Country,
CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}

View File

@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
ipInfoURL: "http://ipinfo.test/json",
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}

View File

@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil
}
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
rows, err := r.sql.QueryContext(ctx, `
SELECT id, name, platform, type, notes
FROM accounts
WHERE proxy_id = $1 AND deleted_at IS NULL
ORDER BY id DESC
`, proxyID)
if err != nil {
return nil, err
}
defer func() { _ = rows.Close() }()
out := make([]service.ProxyAccountSummary, 0)
for rows.Next() {
var (
id int64
name string
platform string
accType string
notes sql.NullString
)
if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
return nil, err
}
var notesPtr *string
if notes.Valid {
notesPtr = &notes.String
}
out = append(out, service.ProxyAccountSummary{
ID: id,
Name: name,
Platform: platform,
Type: accType,
Notes: notesPtr,
})
}
if err := rows.Err(); err != nil {
return nil, err
}
return out, nil
}
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")

View File

@@ -0,0 +1,276 @@
package repository
import (
"context"
"encoding/json"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const (
schedulerBucketSetKey = "sched:buckets"
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
schedulerAccountPrefix = "sched:acc:"
schedulerActivePrefix = "sched:active:"
schedulerReadyPrefix = "sched:ready:"
schedulerVersionPrefix = "sched:ver:"
schedulerSnapshotPrefix = "sched:"
schedulerLockPrefix = "sched:lock:"
)
type schedulerCache struct {
rdb *redis.Client
}
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
return &schedulerCache{rdb: rdb}
}
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
readyVal, err := c.rdb.Get(ctx, readyKey).Result()
if err == redis.Nil {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
if readyVal != "1" {
return nil, false, nil
}
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
activeVal, err := c.rdb.Get(ctx, activeKey).Result()
if err == redis.Nil {
return nil, false, nil
}
if err != nil {
return nil, false, err
}
snapshotKey := schedulerSnapshotKey(bucket, activeVal)
ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
if err != nil {
return nil, false, err
}
if len(ids) == 0 {
return []*service.Account{}, true, nil
}
keys := make([]string, 0, len(ids))
for _, id := range ids {
keys = append(keys, schedulerAccountKey(id))
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, false, err
}
accounts := make([]*service.Account, 0, len(values))
for _, val := range values {
if val == nil {
return nil, false, nil
}
account, err := decodeCachedAccount(val)
if err != nil {
return nil, false, err
}
accounts = append(accounts, account)
}
return accounts, true, nil
}
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
version, err := c.rdb.Incr(ctx, versionKey).Result()
if err != nil {
return err
}
versionStr := strconv.FormatInt(version, 10)
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
pipe := c.rdb.Pipeline()
for _, account := range accounts {
payload, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
}
if len(accounts) > 0 {
// 使用序号作为 score保持数据库返回的排序语义。
members := make([]redis.Z, 0, len(accounts))
for idx, account := range accounts {
members = append(members, redis.Z{
Score: float64(idx),
Member: strconv.FormatInt(account.ID, 10),
})
}
pipe.ZAdd(ctx, snapshotKey, members...)
} else {
pipe.Del(ctx, snapshotKey)
}
pipe.Set(ctx, activeKey, versionStr, 0)
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
if _, err := pipe.Exec(ctx); err != nil {
return err
}
if oldActive != "" && oldActive != versionStr {
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
}
return nil
}
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
val, err := c.rdb.Get(ctx, key).Result()
if err == redis.Nil {
return nil, nil
}
if err != nil {
return nil, err
}
return decodeCachedAccount(val)
}
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
if account == nil || account.ID <= 0 {
return nil
}
payload, err := json.Marshal(account)
if err != nil {
return err
}
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
return c.rdb.Set(ctx, key, payload, 0).Err()
}
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
if accountID <= 0 {
return nil
}
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
return c.rdb.Del(ctx, key).Err()
}
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
if len(updates) == 0 {
return nil
}
keys := make([]string, 0, len(updates))
ids := make([]int64, 0, len(updates))
for id := range updates {
keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
ids = append(ids, id)
}
values, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return err
}
pipe := c.rdb.Pipeline()
for i, val := range values {
if val == nil {
continue
}
account, err := decodeCachedAccount(val)
if err != nil {
return err
}
account.LastUsedAt = ptrTime(updates[ids[i]])
updated, err := json.Marshal(account)
if err != nil {
return err
}
pipe.Set(ctx, keys[i], updated, 0)
}
_, err = pipe.Exec(ctx)
return err
}
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
key := schedulerBucketKey(schedulerLockPrefix, bucket)
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
}
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
if err != nil {
return nil, err
}
out := make([]service.SchedulerBucket, 0, len(raw))
for _, entry := range raw {
bucket, ok := service.ParseSchedulerBucket(entry)
if !ok {
continue
}
out = append(out, bucket)
}
return out, nil
}
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, err
}
id, err := strconv.ParseInt(val, 10, 64)
if err != nil {
return 0, err
}
return id, nil
}
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
}
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
}
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
}
func schedulerAccountKey(id string) string {
return schedulerAccountPrefix + id
}
func ptrTime(t time.Time) *time.Time {
return &t
}
func decodeCachedAccount(val any) (*service.Account, error) {
var payload []byte
switch raw := val.(type) {
case string:
payload = []byte(raw)
case []byte:
payload = raw
default:
return nil, fmt.Errorf("unexpected account cache type: %T", val)
}
var account service.Account
if err := json.Unmarshal(payload, &account); err != nil {
return nil, err
}
return &account, nil
}

View File

@@ -0,0 +1,96 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type schedulerOutboxRepository struct {
db *sql.DB
}
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
return &schedulerOutboxRepository{db: db}
}
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
if limit <= 0 {
limit = 100
}
rows, err := r.db.QueryContext(ctx, `
SELECT id, event_type, account_id, group_id, payload, created_at
FROM scheduler_outbox
WHERE id > $1
ORDER BY id ASC
LIMIT $2
`, afterID, limit)
if err != nil {
return nil, err
}
defer func() {
_ = rows.Close()
}()
events := make([]service.SchedulerOutboxEvent, 0, limit)
for rows.Next() {
var (
payloadRaw []byte
accountID sql.NullInt64
groupID sql.NullInt64
event service.SchedulerOutboxEvent
)
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
return nil, err
}
if accountID.Valid {
v := accountID.Int64
event.AccountID = &v
}
if groupID.Valid {
v := groupID.Int64
event.GroupID = &v
}
if len(payloadRaw) > 0 {
var payload map[string]any
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
return nil, err
}
event.Payload = payload
}
events = append(events, event)
}
if err := rows.Err(); err != nil {
return nil, err
}
return events, nil
}
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
var maxID int64
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
return 0, err
}
return maxID, nil
}
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
if exec == nil {
return nil
}
var payloadArg any
if payload != nil {
encoded, err := json.Marshal(payload)
if err != nil {
return err
}
payloadArg = encoded
}
_, err := exec.ExecContext(ctx, `
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
VALUES ($1, $2, $3, $4)
`, eventType, accountID, groupID, payloadArg)
return err
}

View File

@@ -0,0 +1,68 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
ctx := context.Background()
rdb := testRedis(t)
client := testEntClient(t)
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)
cfg := &config.Config{
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},
},
}
account := &service.Account{
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 1,
Credentials: map[string]any{},
Extra: map[string]any{},
}
require.NoError(t, accountRepo.Create(ctx, account))
require.NoError(t, cache.SetAccount(ctx, account))
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
svc.Start()
t.Cleanup(svc.Stop)
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
updated, err := accountRepo.GetByID(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, updated.LastUsedAt)
expectedUnix := updated.LastUsedAt.Unix()
require.Eventually(t, func() bool {
cached, err := cache.GetAccount(ctx, account.ID)
if err != nil || cached == nil || cached.LastUsedAt == nil {
return false
}
return cached.LastUsedAt.Unix() == expectedUnix
}, 5*time.Second, 100*time.Millisecond)
}

View File

@@ -0,0 +1,321 @@
package repository
import (
"context"
"fmt"
"strconv"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
// 会话限制缓存常量定义
//
// 设计说明:
// 使用 Redis 有序集合Sorted Set跟踪每个账号的活跃会话
// - Key: session_limit:account:{accountID}
// - Member: sessionUUID从 metadata.user_id 中提取)
// - Score: Unix 时间戳(会话最后活跃时间)
//
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
const (
// 会话限制键前缀
// 格式: session_limit:account:{accountID}
sessionLimitKeyPrefix = "session_limit:account:"
// 窗口费用缓存键前缀
// 格式: window_cost:account:{accountID}
windowCostKeyPrefix = "window_cost:account:"
// 窗口费用缓存 TTL30秒
windowCostCacheTTL = 30 * time.Second
)
var (
// registerSessionScript 注册会话活动
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = maxSessions
// ARGV[2] = idleTimeout
// ARGV[3] = sessionUUID
// 返回: 1 = 允许, 0 = 拒绝
registerSessionScript = redis.NewScript(`
local key = KEYS[1]
local maxSessions = tonumber(ARGV[1])
local idleTimeout = tonumber(ARGV[2])
local sessionUUID = ARGV[3]
-- 使用 Redis 服务器时间,确保多实例时钟一致
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
-- 检查会话是否已存在(支持刷新时间戳)
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
-- 会话已存在,刷新时间戳
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 检查是否达到会话数量上限
local count = redis.call('ZCARD', key)
if count < maxSessions then
-- 未达上限,添加新会话
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
return 1
end
-- 达到上限,拒绝新会话
return 0
`)
// refreshSessionScript 刷新会话时间戳
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout
// ARGV[2] = sessionUUID
refreshSessionScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
-- 检查会话是否存在
local exists = redis.call('ZSCORE', key, sessionUUID)
if exists ~= false then
redis.call('ZADD', key, now, sessionUUID)
redis.call('EXPIRE', key, idleTimeout + 60)
end
return 1
`)
// getActiveSessionCountScript 获取活跃会话数
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout
getActiveSessionCountScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 清理过期会话
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
return redis.call('ZCARD', key)
`)
// isSessionActiveScript 检查会话是否活跃
// KEYS[1] = session_limit:account:{accountID}
// ARGV[1] = idleTimeout
// ARGV[2] = sessionUUID
isSessionActiveScript = redis.NewScript(`
local key = KEYS[1]
local idleTimeout = tonumber(ARGV[1])
local sessionUUID = ARGV[2]
local timeResult = redis.call('TIME')
local now = tonumber(timeResult[1])
local expireBefore = now - idleTimeout
-- 获取会话的时间戳
local score = redis.call('ZSCORE', key, sessionUUID)
if score == false then
return 0
end
-- 检查是否过期
if tonumber(score) <= expireBefore then
return 0
end
return 1
`)
)
type sessionLimitCache struct {
rdb *redis.Client
defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount
}
// NewSessionLimitCache 创建会话限制缓存
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
if defaultIdleTimeoutMinutes <= 0 {
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
}
return &sessionLimitCache{
rdb: rdb,
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
}
}
// sessionLimitKey 生成会话限制的 Redis 键
func sessionLimitKey(accountID int64) string {
return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
}
// windowCostKey 生成窗口费用缓存的 Redis 键
func windowCostKey(accountID int64) string {
return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
}
// RegisterSession 注册会话活动
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
if sessionUUID == "" || maxSessions <= 0 {
return true, nil // 无效参数,默认允许
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return true, err // 失败开放:缓存错误时允许请求通过
}
return result == 1, nil
}
// RefreshSession 刷新会话时间戳
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
if sessionUUID == "" {
return nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(idleTimeout.Seconds())
if idleTimeoutSeconds <= 0 {
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
}
_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
return err
}
// GetActiveSessionCount 获取活跃会话数
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
if err != nil {
return 0, err
}
return result, nil
}
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return make(map[int64]int), nil
}
results := make(map[int64]int, len(accountIDs))
// 使用 pipeline 批量执行
pipe := c.rdb.Pipeline()
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
for _, accountID := range accountIDs {
key := sessionLimitKey(accountID)
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
}
// 执行 pipeline即使部分失败也尝试获取成功的结果
_, _ = pipe.Exec(ctx)
for accountID, cmd := range cmds {
if result, err := cmd.Int(); err == nil {
results[accountID] = result
}
}
return results, nil
}
// IsSessionActive 检查会话是否活跃
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
if sessionUUID == "" {
return false, nil
}
key := sessionLimitKey(accountID)
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
if err != nil {
return false, err
}
return result == 1, nil
}
// ========== 5h窗口费用缓存实现 ==========
// GetWindowCost 获取缓存的窗口费用
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
key := windowCostKey(accountID)
val, err := c.rdb.Get(ctx, key).Float64()
if err == redis.Nil {
return 0, false, nil // 缓存未命中
}
if err != nil {
return 0, false, err
}
return val, true, nil
}
// SetWindowCost 设置窗口费用缓存
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
key := windowCostKey(accountID)
return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
}
// GetWindowCostBatch 批量获取窗口费用缓存
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
if len(accountIDs) == 0 {
return make(map[int64]float64), nil
}
// 构建批量查询的 keys
keys := make([]string, len(accountIDs))
for i, accountID := range accountIDs {
keys[i] = windowCostKey(accountID)
}
// 使用 MGET 批量获取
vals, err := c.rdb.MGet(ctx, keys...).Result()
if err != nil {
return nil, err
}
results := make(map[int64]float64, len(accountIDs))
for i, val := range vals {
if val == nil {
continue // 缓存未命中
}
// 尝试解析为 float64
switch v := val.(type) {
case string:
if cost, err := strconv.ParseFloat(v, 64); err == nil {
results[accountIDs[i]] = cost
}
case float64:
results[accountIDs[i]] = v
}
}
return results, nil
}

View File

@@ -0,0 +1,80 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const timeoutCounterPrefix = "timeout_count:account:"
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
// 如果 key 不存在,则创建并设置过期时间
var timeoutCounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
type timeoutCounterCache struct {
rdb *redis.Client
}
// NewTimeoutCounterCache 创建超时计数器缓存实例
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
return &timeoutCounterCache{rdb: rdb}
}
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
ttlSeconds := windowMinutes * 60
if ttlSeconds < 60 {
ttlSeconds = 60 // 最小1分钟
}
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
if err != nil {
return 0, fmt.Errorf("increment timeout count: %w", err)
}
return result, nil
}
// GetTimeoutCount 获取账户当前的超时计数
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("get timeout count: %w", err)
}
return val, nil
}
// ResetTimeoutCount 重置账户的超时计数
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}
// GetTimeoutCountTTL 获取计数器剩余过期时间
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.TTL(ctx, key).Result()
}

View File

@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -105,11 +105,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
account_rate_multiplier,
billing_type,
stream,
duration_ms,
first_token_ms,
user_agent,
ip_address,
image_count,
image_size,
created_at
@@ -119,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
$20, $21, $22, $23, $24, $25, $26, $27, $28
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -130,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
duration := nullInt(log.DurationMs)
firstToken := nullInt(log.FirstTokenMs)
userAgent := nullString(log.UserAgent)
ipAddress := nullString(log.IPAddress)
imageSize := nullString(log.ImageSize)
var requestIDArg any
@@ -158,11 +161,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost,
log.ActualCost,
rateMultiplier,
log.AccountRateMultiplier,
log.BillingType,
log.Stream,
duration,
firstToken,
userAgent,
ipAddress,
log.ImageCount,
imageSize,
createdAt,
@@ -266,16 +271,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
var stats DashboardStats
today := timezone.Today()
now := time.Now()
stats := &DashboardStats{}
now := timezone.Now()
todayStart := timezone.Today()
// 合并用户统计查询
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
return nil, err
}
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
startUTC := start.UTC()
endUTC := end.UTC()
if !endUTC.After(startUTC) {
return nil, errors.New("统计时间范围无效")
}
stats := &DashboardStats{}
now := timezone.Now()
todayStart := timezone.Today()
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
return nil, err
}
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
if err != nil {
return nil, err
}
stats.Rpm = rpm
stats.Tpm = tpm
return stats, nil
}
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
userStatsQuery := `
SELECT
COUNT(*) as total_users,
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
FROM users
WHERE deleted_at IS NULL
`
@@ -283,15 +332,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
ctx,
r.sql,
userStatsQuery,
[]any{today, today},
[]any{todayUTC},
&stats.TotalUsers,
&stats.TodayNewUsers,
&stats.ActiveUsers,
); err != nil {
return nil, err
return err
}
// 合并API Key统计查询
apiKeyStatsQuery := `
SELECT
COUNT(*) as total_api_keys,
@@ -307,10 +354,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&stats.TotalAPIKeys,
&stats.ActiveAPIKeys,
); err != nil {
return nil, err
return err
}
// 合并账户统计查询
accountStatsQuery := `
SELECT
COUNT(*) as total_accounts,
@@ -332,22 +378,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&stats.RateLimitAccounts,
&stats.OverloadAccounts,
); err != nil {
return nil, err
return err
}
// 累计 Token 统计
return nil
}
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(total_requests), 0) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
FROM usage_dashboard_daily
`
var totalDurationMs int64
if err := scanSingleRow(
ctx,
r.sql,
@@ -360,13 +410,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&stats.AverageDurationMs,
&totalDurationMs,
); err != nil {
return nil, err
return err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
// 今日 Token 统计
todayStatsQuery := `
SELECT
total_requests as today_requests,
input_tokens as today_input_tokens,
output_tokens as today_output_tokens,
cache_creation_tokens as today_cache_creation_tokens,
cache_read_tokens as today_cache_read_tokens,
total_cost as today_cost,
actual_cost as today_actual_cost,
active_users as active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{todayUTC},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
&stats.TodayCacheCreationTokens,
&stats.TodayCacheReadTokens,
&stats.TodayCost,
&stats.TodayActualCost,
&stats.ActiveUsers,
); err != nil {
if err != sql.ErrNoRows {
return err
}
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
hourlyActiveQuery := `
SELECT active_users
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
if err != sql.ErrNoRows {
return err
}
}
return nil
}
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
totalStatsQuery := `
SELECT
COUNT(*) as total_requests,
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
var totalDurationMs int64
if err := scanSingleRow(
ctx,
r.sql,
totalStatsQuery,
[]any{startUTC, endUTC},
&stats.TotalRequests,
&stats.TotalInputTokens,
&stats.TotalOutputTokens,
&stats.TotalCacheCreationTokens,
&stats.TotalCacheReadTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&totalDurationMs,
); err != nil {
return err
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
if stats.TotalRequests > 0 {
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
}
todayEnd := todayUTC.Add(24 * time.Hour)
todayStatsQuery := `
SELECT
COUNT(*) as today_requests,
@@ -377,13 +514,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
COALESCE(SUM(total_cost), 0) as today_cost,
COALESCE(SUM(actual_cost), 0) as today_actual_cost
FROM usage_logs
WHERE created_at >= $1
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(
ctx,
r.sql,
todayStatsQuery,
[]any{today},
[]any{todayUTC, todayEnd},
&stats.TodayRequests,
&stats.TodayInputTokens,
&stats.TodayOutputTokens,
@@ -392,19 +529,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
&stats.TodayCost,
&stats.TodayActualCost,
); err != nil {
return nil, err
return err
}
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
// 性能指标RPM 和 TPM最近1分钟全局
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
if err != nil {
return nil, err
activeUsersQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
return err
}
stats.Rpm = rpm
stats.Tpm = tpm
return &stats, nil
hourStart := now.UTC().Truncate(time.Hour)
hourEnd := hourStart.Add(time.Hour)
hourlyActiveQuery := `
SELECT COUNT(DISTINCT user_id) as active_users
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
return err
}
return nil
}
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
@@ -688,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -702,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
@@ -714,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(actual_cost), 0) as cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
COALESCE(SUM(total_cost), 0) as standard_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -728,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests,
&stats.Tokens,
&stats.Cost,
&stats.StandardCost,
&stats.UserCost,
); err != nil {
return nil, err
}
@@ -1253,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return result, nil
}
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
// GetUsageTrendWithFilters returns usage trend data with optional filters
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1283,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
if accountID > 0 {
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if model != "" {
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
args = append(args, model)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1305,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil
}
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
query := `
// GetModelStatsWithFilters returns model statistics with optional filters
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时实际费用使用账号倍率total_cost * account_rate_multiplier
if accountID > 0 && userID == 0 && apiKeyID == 0 {
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
}
query := fmt.Sprintf(`
SELECT
model,
COUNT(*) as requests,
@@ -1315,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
%s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
`
`, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -1333,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
if groupID > 0 {
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
args = append(args, groupID)
}
if stream != nil {
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1440,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
var totalAccountCost float64
if err := scanSingleRow(
ctx,
r.sql,
@@ -1457,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
&totalAccountCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
if filters.AccountID > 0 {
stats.TotalAccountCost = &totalAccountCost
}
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
@@ -1487,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
COALESCE(SUM(actual_cost), 0) as actual_cost
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
@@ -1514,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
var userCost float64
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
@@ -1525,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
UserCost: userCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
var totalActualCost, totalStandardCost float64
var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
totalActualCost += h.ActualCost
totalAccountCost += h.ActualCost
totalUserCost += h.UserCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
@@ -1564,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
TotalCost: totalActualCost,
TotalCost: totalAccountCost,
TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
@@ -1580,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
UserCost: history[i].UserCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
@@ -1597,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests,
}
}
@@ -1612,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
UserCost float64 `json:"user_cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
UserCost: highestRequestDay.UserCost,
}
}
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
if err != nil {
models = []ModelStat{}
}
@@ -1847,35 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var (
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
id int64
userID int64
apiKeyID int64
accountID int64
requestID sql.NullString
model string
groupID sql.NullInt64
subscriptionID sql.NullInt64
inputTokens int
outputTokens int
cacheCreationTokens int
cacheReadTokens int
cacheCreation5m int
cacheCreation1h int
inputCost float64
outputCost float64
cacheCreationCost float64
cacheReadCost float64
totalCost float64
actualCost float64
rateMultiplier float64
accountRateMultiplier sql.NullFloat64
billingType int16
stream bool
durationMs sql.NullInt64
firstTokenMs sql.NullInt64
userAgent sql.NullString
ipAddress sql.NullString
imageCount int
imageSize sql.NullString
createdAt time.Time
)
if err := scanner.Scan(
@@ -1900,11 +2107,13 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost,
&actualCost,
&rateMultiplier,
&accountRateMultiplier,
&billingType,
&stream,
&durationMs,
&firstTokenMs,
&userAgent,
&ipAddress,
&imageCount,
&imageSize,
&createdAt,
@@ -1931,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
@@ -1959,6 +2169,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
if userAgent.Valid {
log.UserAgent = &userAgent.String
}
if ipAddress.Valid {
log.IPAddress = &ipAddress.String
}
if imageSize.Valid {
log.ImageSize = &imageSize.String
}
@@ -2034,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
if !v.Valid {
return nil
}
out := v.Float64
return &out
}
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}

View File

@@ -37,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
@@ -96,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
m := 0.5
log := &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m,
CreatedAt: timezone.Today().Add(2 * time.Hour),
}
_, err := s.repo.Create(s.ctx, log)
s.Require().NoError(err)
got, err := s.repo.GetByID(s.ctx, log.ID)
s.Require().NoError(err)
s.Require().NotNil(got.AccountRateMultiplier)
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
}
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -198,14 +232,14 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
// --- GetDashboardStats ---
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
now := time.Now()
todayStart := timezone.Today()
now := time.Now().UTC()
todayStart := truncateToDayUTC(now)
baseStats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats base")
userToday := mustCreateUser(s.T(), s.client, &service.User{
Email: "today@example.com",
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
CreatedAt: testMaxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
UpdatedAt: now,
})
userOld := mustCreateUser(s.T(), s.client, &service.User{
@@ -238,7 +272,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
TotalCost: 1.5,
ActualCost: 1.2,
DurationMs: &d1,
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
CreatedAt: testMaxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
}
_, err = s.repo.Create(s.ctx, logToday)
s.Require().NoError(err, "Create logToday")
@@ -273,6 +307,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
_, err = s.repo.Create(s.ctx, logPerf)
s.Require().NoError(err, "Create logPerf")
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
aggStart := todayStart.Add(-2 * time.Hour)
aggEnd := now.Add(2 * time.Minute)
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange")
stats, err := s.repo.GetDashboardStats(s.ctx)
s.Require().NoError(err, "GetDashboardStats")
@@ -303,6 +342,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
}
func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
now := time.Now().UTC()
todayStart := truncateToDayUTC(now)
rangeStart := todayStart.Add(-24 * time.Hour)
rangeEnd := now.Add(1 * time.Second)
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"})
d1, d2, d3 := 100, 200, 300
logOutside := &service.UsageLog{
UserID: user1.ID,
APIKeyID: apiKey1.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 7,
OutputTokens: 8,
TotalCost: 0.8,
ActualCost: 0.7,
DurationMs: &d3,
CreatedAt: rangeStart.Add(-1 * time.Hour),
}
_, err := s.repo.Create(s.ctx, logOutside)
s.Require().NoError(err)
logRange := &service.UsageLog{
UserID: user1.ID,
APIKeyID: apiKey1.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 1,
CacheReadTokens: 2,
TotalCost: 1.0,
ActualCost: 0.9,
DurationMs: &d1,
CreatedAt: rangeStart.Add(2 * time.Hour),
}
_, err = s.repo.Create(s.ctx, logRange)
s.Require().NoError(err)
logToday := &service.UsageLog{
UserID: user2.ID,
APIKeyID: apiKey2.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 6,
CacheReadTokens: 1,
TotalCost: 0.5,
ActualCost: 0.5,
DurationMs: &d2,
CreatedAt: now,
}
_, err = s.repo.Create(s.ctx, logToday)
s.Require().NoError(err)
stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd)
s.Require().NoError(err)
s.Require().Equal(int64(2), stats.TotalRequests)
s.Require().Equal(int64(15), stats.TotalInputTokens)
s.Require().Equal(int64(26), stats.TotalOutputTokens)
s.Require().Equal(int64(1), stats.TotalCacheCreationTokens)
s.Require().Equal(int64(3), stats.TotalCacheReadTokens)
s.Require().Equal(int64(45), stats.TotalTokens)
s.Require().Equal(1.5, stats.TotalCost)
s.Require().Equal(1.4, stats.TotalActualCost)
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
}
// --- GetUserDashboardStats ---
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
@@ -325,12 +438,202 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
createdAt := timezone.Today().Add(1 * time.Hour)
m1 := 1.5
m2 := 0.0
_, err := s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
TotalCost: 1.0,
ActualCost: 2.0,
AccountRateMultiplier: &m1,
CreatedAt: createdAt,
})
s.Require().NoError(err)
_, err = s.repo.Create(s.ctx, &service.UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: uuid.New().String(),
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 1.0,
AccountRateMultiplier: &m2,
CreatedAt: createdAt,
})
s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
s.Require().Equal(int64(1), stats.Requests)
s.Require().Equal(int64(30), stats.Tokens)
s.Require().Equal(int64(2), stats.Requests)
s.Require().Equal(int64(40), stats.Tokens)
// account cost = SUM(total_cost * account_rate_multiplier)
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
// standard cost = SUM(total_cost)
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
// user cost = SUM(actual_cost)
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
}
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
now := time.Now().UTC().Truncate(time.Second)
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart := truncateToDayUTC(now)
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
// 如果当前时间早于 hour2则使用昨天的时间
if now.Before(hour2.Add(time.Hour)) {
dayStart = dayStart.Add(-24 * time.Hour)
hour1 = dayStart.Add(2 * time.Hour)
hour2 = dayStart.Add(3 * time.Hour)
}
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"})
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"})
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"})
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"})
d1, d2, d3 := 100, 200, 150
log1 := &service.UsageLog{
UserID: user1.ID,
APIKeyID: apiKey1.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 10,
OutputTokens: 20,
CacheCreationTokens: 2,
CacheReadTokens: 1,
TotalCost: 1.0,
ActualCost: 0.9,
DurationMs: &d1,
CreatedAt: hour1.Add(5 * time.Minute),
}
_, err := s.repo.Create(s.ctx, log1)
s.Require().NoError(err)
log2 := &service.UsageLog{
UserID: user1.ID,
APIKeyID: apiKey1.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 5,
OutputTokens: 5,
TotalCost: 0.5,
ActualCost: 0.5,
DurationMs: &d2,
CreatedAt: hour1.Add(20 * time.Minute),
}
_, err = s.repo.Create(s.ctx, log2)
s.Require().NoError(err)
log3 := &service.UsageLog{
UserID: user2.ID,
APIKeyID: apiKey2.ID,
AccountID: account.ID,
Model: "claude-3",
InputTokens: 7,
OutputTokens: 8,
TotalCost: 0.7,
ActualCost: 0.7,
DurationMs: &d3,
CreatedAt: hour2.Add(10 * time.Minute),
}
_, err = s.repo.Create(s.ctx, log3)
s.Require().NoError(err)
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
aggStart := hour1.Add(-5 * time.Minute)
aggEnd := hour2.Add(time.Hour) // 确保覆盖 hour2 的所有数据
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd))
type hourlyRow struct {
totalRequests int64
inputTokens int64
outputTokens int64
cacheCreationTokens int64
cacheReadTokens int64
totalCost float64
actualCost float64
totalDurationMs int64
activeUsers int64
}
fetchHourly := func(bucketStart time.Time) hourlyRow {
var row hourlyRow
err := scanSingleRow(s.ctx, s.tx, `
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
total_cost, actual_cost, total_duration_ms, active_users
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens,
&row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost,
&row.totalDurationMs, &row.activeUsers,
)
s.Require().NoError(err)
return row
}
hour1Row := fetchHourly(hour1)
s.Require().Equal(int64(2), hour1Row.totalRequests)
s.Require().Equal(int64(15), hour1Row.inputTokens)
s.Require().Equal(int64(25), hour1Row.outputTokens)
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens)
s.Require().Equal(int64(1), hour1Row.cacheReadTokens)
s.Require().Equal(1.5, hour1Row.totalCost)
s.Require().Equal(1.4, hour1Row.actualCost)
s.Require().Equal(int64(300), hour1Row.totalDurationMs)
s.Require().Equal(int64(1), hour1Row.activeUsers)
hour2Row := fetchHourly(hour2)
s.Require().Equal(int64(1), hour2Row.totalRequests)
s.Require().Equal(int64(7), hour2Row.inputTokens)
s.Require().Equal(int64(8), hour2Row.outputTokens)
s.Require().Equal(int64(0), hour2Row.cacheCreationTokens)
s.Require().Equal(int64(0), hour2Row.cacheReadTokens)
s.Require().Equal(0.7, hour2Row.totalCost)
s.Require().Equal(0.7, hour2Row.actualCost)
s.Require().Equal(int64(150), hour2Row.totalDurationMs)
s.Require().Equal(int64(1), hour2Row.activeUsers)
var daily struct {
totalRequests int64
inputTokens int64
outputTokens int64
cacheCreationTokens int64
cacheReadTokens int64
totalCost float64
actualCost float64
totalDurationMs int64
activeUsers int64
}
err = scanSingleRow(s.ctx, s.tx, `
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
total_cost, actual_cost, total_duration_ms, active_users
FROM usage_dashboard_daily
WHERE bucket_date = $1::date
`, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens,
&daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost,
&daily.totalDurationMs, &daily.activeUsers,
)
s.Require().NoError(err)
s.Require().Equal(int64(3), daily.totalRequests)
s.Require().Equal(int64(22), daily.inputTokens)
s.Require().Equal(int64(33), daily.outputTokens)
s.Require().Equal(int64(2), daily.cacheCreationTokens)
s.Require().Equal(int64(1), daily.cacheReadTokens)
s.Require().Equal(2.2, daily.totalCost)
s.Require().Equal(2.1, daily.actualCost)
s.Require().Equal(int64(450), daily.totalDurationMs)
s.Require().Equal(int64(2), daily.activeUsers)
}
// --- GetBatchUserUsageStats ---
@@ -398,7 +701,7 @@ func (s *UsageLogRepoSuite) TestGetGlobalStats() {
s.Require().Equal(int64(45), stats.TotalOutputTokens)
}
func maxTime(a, b time.Time) time.Time {
func testMaxTime(a, b time.Time) time.Time {
if a.After(b) {
return a
}
@@ -641,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -668,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -714,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}

View File

@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProvideSessionLimitCache 创建会话限制缓存
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
}
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
}
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -45,8 +55,11 @@ var ProviderSet = wire.NewSet(
NewAccountRepository,
NewProxyRepository,
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository,
NewDashboardAggregationRepository,
NewSettingRepository,
NewOpsRepository,
NewUserSubscriptionRepository,
NewUserAttributeDefinitionRepository,
NewUserAttributeValueRepository,
@@ -56,12 +69,18 @@ var ProviderSet = wire.NewSet(
NewBillingCache,
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
NewRedeemCache,
NewUpdateCache,
NewGeminiTokenCache,
NewSchedulerCache,
NewSchedulerOutboxRepository,
NewProxyLatencyCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,