Merge upstream/main
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -79,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
}
|
||||
@@ -115,6 +120,9 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
|
||||
account.ID = created.ID
|
||||
account.CreatedAt = created.CreatedAt
|
||||
account.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -287,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
SetSchedulable(account.Schedulable).
|
||||
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
|
||||
|
||||
if account.RateMultiplier != nil {
|
||||
builder.SetRateMultiplier(*account.RateMultiplier)
|
||||
}
|
||||
|
||||
if account.ProxyID != nil {
|
||||
builder.SetProxyID(*account.ProxyID)
|
||||
} else {
|
||||
@@ -341,10 +353,17 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
|
||||
return translatePersistenceError(err, service.ErrAccountNotFound, nil)
|
||||
}
|
||||
account.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
groupIDs, err := r.loadAccountGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 使用事务保证账号与关联分组的删除原子性
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
@@ -368,7 +387,12 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error {
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
return tx.Commit()
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -455,7 +479,18 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetLastUsedAt(now).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := map[string]any{
|
||||
"last_used": map[string]int64{
|
||||
strconv.FormatInt(id, 10): now.Unix(),
|
||||
},
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
@@ -479,7 +514,18 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map
|
||||
args = append(args, pq.Array(ids))
|
||||
|
||||
_, err := r.sql.ExecContext(ctx, caseSQL, args...)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lastUsedPayload := make(map[string]int64, len(updates))
|
||||
for id, ts := range updates {
|
||||
lastUsedPayload[strconv.FormatInt(id, 10)] = ts.Unix()
|
||||
}
|
||||
payload := map[string]any{"last_used": lastUsedPayload}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
@@ -488,7 +534,13 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str
|
||||
SetStatus(service.StatusError).
|
||||
SetErrorMessage(errorMsg).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearError(ctx context.Context, id int64) error {
|
||||
@@ -506,7 +558,14 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i
|
||||
SetGroupID(groupID).
|
||||
SetPriority(priority).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, groupID int64) error {
|
||||
@@ -516,7 +575,14 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou
|
||||
dbaccountgroup.GroupIDEQ(groupID),
|
||||
).
|
||||
Exec(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payload := buildSchedulerGroupPayload([]int64{groupID})
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]service.Group, error) {
|
||||
@@ -537,6 +603,10 @@ func (r *accountRepository) GetGroups(ctx context.Context, accountID int64) ([]s
|
||||
}
|
||||
|
||||
func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
existingGroupIDs, err := r.loadAccountGroupIDs(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 使用事务保证删除旧绑定与创建新绑定的原子性
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
@@ -577,7 +647,13 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
return tx.Commit()
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs))
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -681,7 +757,13 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
SetRateLimitedAt(now).
|
||||
SetRateLimitResetAt(resetAt).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
@@ -715,6 +797,49 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{model_rate_limits," + scope + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -723,7 +848,13 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetOverloadUntil(until).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
@@ -736,7 +867,13 @@ func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64,
|
||||
AND deleted_at IS NULL
|
||||
AND (temp_unschedulable_until IS NULL OR temp_unschedulable_until < $1)
|
||||
`, until, reason, id)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
@@ -748,7 +885,13 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64
|
||||
WHERE id = $1
|
||||
AND deleted_at IS NULL
|
||||
`, id)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
@@ -758,7 +901,13 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearRateLimitResetAt().
|
||||
ClearOverloadUntil().
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
@@ -779,6 +928,33 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -801,7 +977,13 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu
|
||||
Where(dbaccount.IDEQ(id)).
|
||||
SetSchedulable(schedulable).
|
||||
Save(ctx)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
@@ -822,6 +1004,11 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if rows > 0 {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -853,6 +1040,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -890,6 +1080,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
args = append(args, *updates.Priority)
|
||||
idx++
|
||||
}
|
||||
if updates.RateMultiplier != nil {
|
||||
setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
|
||||
args = append(args, *updates.RateMultiplier)
|
||||
idx++
|
||||
}
|
||||
if updates.Status != nil {
|
||||
setClauses = append(setClauses, "status = $"+itoa(idx))
|
||||
args = append(args, *updates.Status)
|
||||
@@ -937,6 +1132,12 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if rows > 0 {
|
||||
payload := map[string]any{"account_ids": ids}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err)
|
||||
}
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
@@ -1179,11 +1380,61 @@ func (r *accountRepository) loadAccountGroups(ctx context.Context, accountIDs []
|
||||
return groupsByAccount, groupIDsByAccount, accountGroupsByAccount, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) loadAccountGroupIDs(ctx context.Context, accountID int64) ([]int64, error) {
|
||||
entries, err := r.client.AccountGroup.
|
||||
Query().
|
||||
Where(dbaccountgroup.AccountIDEQ(accountID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids := make([]int64, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
ids = append(ids, entry.GroupID)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func mergeGroupIDs(a []int64, b []int64) []int64 {
|
||||
seen := make(map[int64]struct{}, len(a)+len(b))
|
||||
out := make([]int64, 0, len(a)+len(b))
|
||||
for _, id := range a {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
for _, id := range b {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, id)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildSchedulerGroupPayload(groupIDs []int64) map[string]any {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return map[string]any{"group_ids": groupIDs}
|
||||
}
|
||||
|
||||
func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rateMultiplier := m.RateMultiplier
|
||||
|
||||
return &service.Account{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
@@ -1195,6 +1446,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
|
||||
ProxyID: m.ProxyID,
|
||||
Concurrency: m.Concurrency,
|
||||
Priority: m.Priority,
|
||||
RateMultiplier: &rateMultiplier,
|
||||
Status: m.Status,
|
||||
ErrorMessage: derefString(m.ErrorMessage),
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
|
||||
@@ -2,6 +2,7 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -13,6 +14,7 @@ import (
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
@@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func apiKeyAuthCacheKey(key string) string {
|
||||
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
@@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entry service.APIKeyAuthCacheEntry
|
||||
if err := json.Unmarshal(val, &entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
@@ -6,7 +6,9 @@ import (
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -26,13 +28,21 @@ func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
created, err := r.client.APIKey.Create().
|
||||
builder := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
Save(ctx)
|
||||
SetNillableGroupID(key.GroupID)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.CreatedAt = created.CreatedAt
|
||||
@@ -56,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。
|
||||
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询 user_id 字段,减少数据传输量
|
||||
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查)
|
||||
func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) {
|
||||
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldUserID).
|
||||
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
return "", 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
return "", 0, err
|
||||
}
|
||||
return m.UserID, nil
|
||||
return m.Key, m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
@@ -90,6 +100,56 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
Select(
|
||||
apikey.FieldID,
|
||||
apikey.FieldUserID,
|
||||
apikey.FieldGroupID,
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
user.FieldID,
|
||||
user.FieldStatus,
|
||||
user.FieldRole,
|
||||
user.FieldBalance,
|
||||
user.FieldConcurrency,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
q.Select(
|
||||
group.FieldID,
|
||||
group.FieldName,
|
||||
group.FieldPlatform,
|
||||
group.FieldStatus,
|
||||
group.FieldSubscriptionType,
|
||||
group.FieldRateMultiplier,
|
||||
group.FieldDailyLimitUsd,
|
||||
group.FieldWeeklyLimitUsd,
|
||||
group.FieldMonthlyLimitUsd,
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
@@ -108,6 +168,18 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
} else {
|
||||
builder.ClearIPWhitelist()
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
} else {
|
||||
builder.ClearIPBlacklist()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -263,19 +335,43 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.UserIDEQ(userID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
@@ -317,6 +413,7 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
@@ -327,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -93,7 +93,7 @@ var (
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - only sets TTL on first creation to avoid refreshing
|
||||
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
@@ -111,15 +111,13 @@ var (
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementAccountWaitScript - account-level wait queue count
|
||||
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
|
||||
incrementAccountWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
@@ -134,10 +132,8 @@ var (
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Only set TTL on first creation to avoid refreshing zombie data
|
||||
if newVal == 1 then
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
end
|
||||
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
392
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
392
backend/internal/repository/dashboard_aggregation_repo.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
return nil
|
||||
}
|
||||
if !isPostgresDriver(sqlDB) {
|
||||
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
|
||||
return nil
|
||||
}
|
||||
return newDashboardAggregationRepositoryWithSQL(sqlDB)
|
||||
}
|
||||
|
||||
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||
return &dashboardAggregationRepository{sql: sqlq}
|
||||
}
|
||||
|
||||
func isPostgresDriver(db *sql.DB) bool {
|
||||
if db == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := db.Driver().(*pq.Driver)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
if !endLocal.After(startLocal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startLocal.Truncate(time.Hour)
|
||||
hourEnd := endLocal.Truncate(time.Hour)
|
||||
if endLocal.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDay(startLocal)
|
||||
dayEnd := truncateToDay(endLocal)
|
||||
if endLocal.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return time.Unix(0, 0).UTC(), nil
|
||||
}
|
||||
return time.Time{}, err
|
||||
}
|
||||
return ts.UTC(), nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||
VALUES (1, $1, NOW())
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
hourlyCutoffUTC := hourlyCutoff.UTC()
|
||||
dailyCutoffUTC := dailyCutoff.UTC()
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
_, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil || !isPartitioned {
|
||||
return err
|
||||
}
|
||||
monthStart := truncateToMonthUTC(now)
|
||||
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||
SELECT DISTINCT
|
||||
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||
user_id
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||
SELECT DISTINCT
|
||||
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
|
||||
user_id
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
WITH hourly AS (
|
||||
SELECT
|
||||
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||
COUNT(*) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
GROUP BY 1
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_start, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY bucket_start
|
||||
)
|
||||
INSERT INTO usage_dashboard_hourly (
|
||||
bucket_start,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
hourly.bucket_start,
|
||||
hourly.total_requests,
|
||||
hourly.input_tokens,
|
||||
hourly.output_tokens,
|
||||
hourly.cache_creation_tokens,
|
||||
hourly.cache_read_tokens,
|
||||
hourly.total_cost,
|
||||
hourly.actual_cost,
|
||||
hourly.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM hourly
|
||||
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
|
||||
ON CONFLICT (bucket_start)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
WITH daily AS (
|
||||
SELECT
|
||||
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
|
||||
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY (bucket_start AT TIME ZONE $5)::date
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_date, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_daily_users
|
||||
WHERE bucket_date >= $3::date AND bucket_date < $4::date
|
||||
GROUP BY bucket_date
|
||||
)
|
||||
INSERT INTO usage_dashboard_daily (
|
||||
bucket_date,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
daily.bucket_date,
|
||||
daily.total_requests,
|
||||
daily.input_tokens,
|
||||
daily.output_tokens,
|
||||
daily.cache_creation_tokens,
|
||||
daily.cache_read_tokens,
|
||||
daily.total_cost,
|
||||
daily.actual_cost,
|
||||
daily.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM daily
|
||||
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
|
||||
ON CONFLICT (bucket_date)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_partitioned_table pt
|
||||
JOIN pg_class c ON c.oid = pt.partrelid
|
||||
WHERE c.relname = 'usage_logs'
|
||||
)
|
||||
`
|
||||
var partitioned bool
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return partitioned, nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT c.relname
|
||||
FROM pg_inherits
|
||||
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||
WHERE p.relname = 'usage_logs'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.HasPrefix(name, "usage_logs_") {
|
||||
continue
|
||||
}
|
||||
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||
month, err := time.Parse("200601", suffix)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
month = month.UTC()
|
||||
if month.Before(cutoffMonth) {
|
||||
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||
monthStart := truncateToMonthUTC(month)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||
query := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||
pq.QuoteIdentifier(name),
|
||||
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||
)
|
||||
_, err := r.sql.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
func truncateToDay(t time.Time) time.Time {
|
||||
return timezone.StartOfDay(t)
|
||||
}
|
||||
|
||||
func truncateToMonthUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
58
backend/internal/repository/dashboard_cache.go
Normal file
58
backend/internal/repository/dashboard_cache.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const dashboardStatsCacheKey = "dashboard:stats:v1"
|
||||
|
||||
type dashboardCache struct {
|
||||
rdb *redis.Client
|
||||
keyPrefix string
|
||||
}
|
||||
|
||||
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||
prefix := "sub2api:"
|
||||
if cfg != nil {
|
||||
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||
}
|
||||
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||
prefix += ":"
|
||||
}
|
||||
return &dashboardCache{
|
||||
rdb: rdb,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", service.ErrDashboardStatsCacheMiss
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *dashboardCache) buildKey() string {
|
||||
if c.keyPrefix == "" {
|
||||
return dashboardStatsCacheKey
|
||||
}
|
||||
return c.keyPrefix + dashboardStatsCacheKey
|
||||
}
|
||||
|
||||
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||
}
|
||||
28
backend/internal/repository/dashboard_cache_test.go
Normal file
28
backend/internal/repository/dashboard_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||
cache := NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "prod",
|
||||
},
|
||||
})
|
||||
impl, ok := cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "prod:", impl.keyPrefix)
|
||||
|
||||
cache = NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "staging:",
|
||||
},
|
||||
})
|
||||
impl, ok = cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "staging:", impl.keyPrefix)
|
||||
}
|
||||
@@ -11,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
geminiTokenKeyPrefix = "gemini:token:"
|
||||
geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
|
||||
oauthTokenKeyPrefix = "oauth:token:"
|
||||
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
|
||||
)
|
||||
|
||||
type geminiTokenCache struct {
|
||||
@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GeminiTokenCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GeminiTokenCache
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGeminiTokenCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||
cacheKey := "project-123"
|
||||
token := "token-value"
|
||||
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||
|
||||
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), token, got)
|
||||
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||
|
||||
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||
}
|
||||
|
||||
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||
}
|
||||
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
28
backend/internal/repository/gemini_token_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
cache := NewGeminiTokenCache(rdb)
|
||||
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
@@ -48,18 +49,38 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
groupIn.ID = created.ID
|
||||
groupIn.CreatedAt = created.CreatedAt
|
||||
groupIn.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
out, err := r.GetByIDLite(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
@@ -67,10 +88,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
out := groupEntityToService(m)
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
@@ -89,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
@@ -98,17 +117,33 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
|
||||
// 处理 ModelRouting:nil 时清除,否则设置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
} else {
|
||||
builder = builder.ClearModelRouting()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
}
|
||||
groupIn.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
@@ -238,6 +273,9 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou
|
||||
return 0, err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
@@ -345,6 +383,9 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64,
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,8 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
@@ -19,6 +21,20 @@ type GroupRepoSuite struct {
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
@@ -57,6 +73,26 @@ func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
|
||||
@@ -28,6 +28,23 @@ CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
);
|
||||
`
|
||||
|
||||
const atlasSchemaRevisionsTableDDL = `
|
||||
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
||||
version TEXT PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
type INTEGER NOT NULL,
|
||||
applied INTEGER NOT NULL DEFAULT 0,
|
||||
total INTEGER NOT NULL DEFAULT 0,
|
||||
executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
execution_time BIGINT NOT NULL DEFAULT 0,
|
||||
error TEXT NULL,
|
||||
error_stmt TEXT NULL,
|
||||
hash TEXT NOT NULL DEFAULT '',
|
||||
partial_hashes TEXT[] NULL,
|
||||
operator_version TEXT NULL
|
||||
);
|
||||
`
|
||||
|
||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||
@@ -94,6 +111,11 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
return fmt.Errorf("create schema_migrations: %w", err)
|
||||
}
|
||||
|
||||
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions)。
|
||||
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
@@ -172,6 +194,80 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("check schema_migrations: %w", err)
|
||||
}
|
||||
if !hasLegacy {
|
||||
return nil
|
||||
}
|
||||
|
||||
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
|
||||
if err != nil {
|
||||
return fmt.Errorf("check atlas_schema_revisions: %w", err)
|
||||
}
|
||||
if !hasAtlas {
|
||||
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
|
||||
return fmt.Errorf("create atlas_schema_revisions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
|
||||
return fmt.Errorf("count atlas_schema_revisions: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("atlas baseline version: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.ExecContext(ctx, `
|
||||
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
|
||||
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
|
||||
`, version, description, 1, hash); err != nil {
|
||||
return fmt.Errorf("insert atlas baseline: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
|
||||
var exists bool
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = $1
|
||||
)
|
||||
`, tableName).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
if len(files) == 0 {
|
||||
return "baseline", "baseline", "", nil
|
||||
}
|
||||
sort.Strings(files)
|
||||
name := files[len(files)-1]
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
hash := hex.EncodeToString(sum[:])
|
||||
version := strings.TrimSuffix(name, ".sql")
|
||||
return version, version, hash, nil
|
||||
}
|
||||
|
||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||
|
||||
1098
backend/internal/repository/ops_repo.go
Normal file
1098
backend/internal/repository/ops_repo.go
Normal file
File diff suppressed because it is too large
Load Diff
853
backend/internal/repository/ops_repo_alerts.go
Normal file
853
backend/internal/repository/ops_repo_alerts.go
Normal file
@@ -0,0 +1,853 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM ops_alert_rules
|
||||
ORDER BY id DESC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := []*service.OpsAlertRule{}
|
||||
for rows.Next() {
|
||||
var rule service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&rule.ID,
|
||||
&rule.Name,
|
||||
&rule.Description,
|
||||
&rule.Enabled,
|
||||
&rule.Severity,
|
||||
&rule.MetricType,
|
||||
&rule.Operator,
|
||||
&rule.Threshold,
|
||||
&rule.WindowMinutes,
|
||||
&rule.SustainedMinutes,
|
||||
&rule.CooldownMinutes,
|
||||
&rule.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&rule.CreatedAt,
|
||||
&rule.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
rule.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
rule.Filters = decoded
|
||||
}
|
||||
}
|
||||
out = append(out, &rule)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
filtersArg, err := opsNullJSONMap(input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_rules (
|
||||
name,
|
||||
description,
|
||||
enabled,
|
||||
severity,
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
notify_email,
|
||||
filters,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
|
||||
)
|
||||
RETURNING
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at`
|
||||
|
||||
var out service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
|
||||
if err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
strings.TrimSpace(input.Name),
|
||||
strings.TrimSpace(input.Description),
|
||||
input.Enabled,
|
||||
strings.TrimSpace(input.Severity),
|
||||
strings.TrimSpace(input.MetricType),
|
||||
strings.TrimSpace(input.Operator),
|
||||
input.Threshold,
|
||||
input.WindowMinutes,
|
||||
input.SustainedMinutes,
|
||||
input.CooldownMinutes,
|
||||
input.NotifyEmail,
|
||||
filtersArg,
|
||||
).Scan(
|
||||
&out.ID,
|
||||
&out.Name,
|
||||
&out.Description,
|
||||
&out.Enabled,
|
||||
&out.Severity,
|
||||
&out.MetricType,
|
||||
&out.Operator,
|
||||
&out.Threshold,
|
||||
&out.WindowMinutes,
|
||||
&out.SustainedMinutes,
|
||||
&out.CooldownMinutes,
|
||||
&out.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
out.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
out.Filters = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
if input.ID <= 0 {
|
||||
return nil, fmt.Errorf("invalid id")
|
||||
}
|
||||
|
||||
filtersArg, err := opsNullJSONMap(input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
UPDATE ops_alert_rules
|
||||
SET
|
||||
name = $2,
|
||||
description = $3,
|
||||
enabled = $4,
|
||||
severity = $5,
|
||||
metric_type = $6,
|
||||
operator = $7,
|
||||
threshold = $8,
|
||||
window_minutes = $9,
|
||||
sustained_minutes = $10,
|
||||
cooldown_minutes = $11,
|
||||
notify_email = $12,
|
||||
filters = $13,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at`
|
||||
|
||||
var out service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
|
||||
if err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
input.ID,
|
||||
strings.TrimSpace(input.Name),
|
||||
strings.TrimSpace(input.Description),
|
||||
input.Enabled,
|
||||
strings.TrimSpace(input.Severity),
|
||||
strings.TrimSpace(input.MetricType),
|
||||
strings.TrimSpace(input.Operator),
|
||||
input.Threshold,
|
||||
input.WindowMinutes,
|
||||
input.SustainedMinutes,
|
||||
input.CooldownMinutes,
|
||||
input.NotifyEmail,
|
||||
filtersArg,
|
||||
).Scan(
|
||||
&out.ID,
|
||||
&out.Name,
|
||||
&out.Description,
|
||||
&out.Enabled,
|
||||
&out.Severity,
|
||||
&out.MetricType,
|
||||
&out.Operator,
|
||||
&out.Threshold,
|
||||
&out.WindowMinutes,
|
||||
&out.SustainedMinutes,
|
||||
&out.CooldownMinutes,
|
||||
&out.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
out.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
out.Filters = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("invalid id")
|
||||
}
|
||||
|
||||
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsAlertEventFilter{}
|
||||
}
|
||||
|
||||
limit := filter.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
|
||||
where, args := buildOpsAlertEventsWhere(filter)
|
||||
args = append(args, limit)
|
||||
limitArg := "$" + itoa(len(args))
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
` + where + `
|
||||
ORDER BY fired_at DESC, id DESC
|
||||
LIMIT ` + limitArg
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := []*service.OpsAlertEvent{}
|
||||
for rows.Next() {
|
||||
var ev service.OpsAlertEvent
|
||||
var metricValue sql.NullFloat64
|
||||
var thresholdValue sql.NullFloat64
|
||||
var dimensionsRaw []byte
|
||||
var resolvedAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&ev.ID,
|
||||
&ev.RuleID,
|
||||
&ev.Severity,
|
||||
&ev.Status,
|
||||
&ev.Title,
|
||||
&ev.Description,
|
||||
&metricValue,
|
||||
&thresholdValue,
|
||||
&dimensionsRaw,
|
||||
&ev.FiredAt,
|
||||
&resolvedAt,
|
||||
&ev.EmailSent,
|
||||
&ev.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if metricValue.Valid {
|
||||
v := metricValue.Float64
|
||||
ev.MetricValue = &v
|
||||
}
|
||||
if thresholdValue.Valid {
|
||||
v := thresholdValue.Float64
|
||||
ev.ThresholdValue = &v
|
||||
}
|
||||
if resolvedAt.Valid {
|
||||
v := resolvedAt.Time
|
||||
ev.ResolvedAt = &v
|
||||
}
|
||||
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||
ev.Dimensions = decoded
|
||||
}
|
||||
}
|
||||
out = append(out, &ev)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE id = $1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, eventID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1 AND status = $2
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if event == nil {
|
||||
return nil, fmt.Errorf("nil event")
|
||||
}
|
||||
|
||||
dimensionsArg, err := opsNullJSONMap(event.Dimensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_events (
|
||||
rule_id,
|
||||
severity,
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
|
||||
)
|
||||
RETURNING
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at`
|
||||
|
||||
row := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
opsNullInt64(&event.RuleID),
|
||||
opsNullString(event.Severity),
|
||||
opsNullString(event.Status),
|
||||
opsNullString(event.Title),
|
||||
opsNullString(event.Description),
|
||||
opsNullFloat64(event.MetricValue),
|
||||
opsNullFloat64(event.ThresholdValue),
|
||||
dimensionsArg,
|
||||
event.FiredAt,
|
||||
opsNullTime(event.ResolvedAt),
|
||||
event.EmailSent,
|
||||
)
|
||||
return scanOpsAlertEvent(row)
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
return fmt.Errorf("invalid status")
|
||||
}
|
||||
|
||||
q := `
|
||||
UPDATE ops_alert_events
|
||||
SET status = $2,
|
||||
resolved_at = $3
|
||||
WHERE id = $1`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
|
||||
return err
|
||||
}
|
||||
|
||||
type opsAlertEventRow interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
if input.RuleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule_id")
|
||||
}
|
||||
platform := strings.TrimSpace(input.Platform)
|
||||
if platform == "" {
|
||||
return nil, fmt.Errorf("invalid platform")
|
||||
}
|
||||
if input.Until.IsZero() {
|
||||
return nil, fmt.Errorf("invalid until")
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_silences (
|
||||
rule_id,
|
||||
platform,
|
||||
group_id,
|
||||
region,
|
||||
until,
|
||||
reason,
|
||||
created_by,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||
)
|
||||
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
|
||||
|
||||
row := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
input.RuleID,
|
||||
platform,
|
||||
opsNullInt64(input.GroupID),
|
||||
opsNullString(input.Region),
|
||||
input.Until,
|
||||
opsNullString(input.Reason),
|
||||
opsNullInt64(input.CreatedBy),
|
||||
)
|
||||
|
||||
var out service.OpsAlertSilence
|
||||
var groupID sql.NullInt64
|
||||
var region sql.NullString
|
||||
var createdBy sql.NullInt64
|
||||
if err := row.Scan(
|
||||
&out.ID,
|
||||
&out.RuleID,
|
||||
&out.Platform,
|
||||
&groupID,
|
||||
®ion,
|
||||
&out.Until,
|
||||
&out.Reason,
|
||||
&createdBy,
|
||||
&out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
out.GroupID = &v
|
||||
}
|
||||
if region.Valid {
|
||||
v := strings.TrimSpace(region.String)
|
||||
if v != "" {
|
||||
out.Region = &v
|
||||
}
|
||||
}
|
||||
if createdBy.Valid {
|
||||
v := createdBy.Int64
|
||||
out.CreatedBy = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
platform = strings.TrimSpace(platform)
|
||||
if platform == "" {
|
||||
return false, nil
|
||||
}
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT 1
|
||||
FROM ops_alert_silences
|
||||
WHERE rule_id = $1
|
||||
AND platform = $2
|
||||
AND (group_id IS NOT DISTINCT FROM $3)
|
||||
AND (region IS NOT DISTINCT FROM $4)
|
||||
AND until > $5
|
||||
LIMIT 1`
|
||||
|
||||
var dummy int
|
||||
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
|
||||
var ev service.OpsAlertEvent
|
||||
var metricValue sql.NullFloat64
|
||||
var thresholdValue sql.NullFloat64
|
||||
var dimensionsRaw []byte
|
||||
var resolvedAt sql.NullTime
|
||||
|
||||
if err := row.Scan(
|
||||
&ev.ID,
|
||||
&ev.RuleID,
|
||||
&ev.Severity,
|
||||
&ev.Status,
|
||||
&ev.Title,
|
||||
&ev.Description,
|
||||
&metricValue,
|
||||
&thresholdValue,
|
||||
&dimensionsRaw,
|
||||
&ev.FiredAt,
|
||||
&resolvedAt,
|
||||
&ev.EmailSent,
|
||||
&ev.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if metricValue.Valid {
|
||||
v := metricValue.Float64
|
||||
ev.MetricValue = &v
|
||||
}
|
||||
if thresholdValue.Valid {
|
||||
v := thresholdValue.Float64
|
||||
ev.ThresholdValue = &v
|
||||
}
|
||||
if resolvedAt.Valid {
|
||||
v := resolvedAt.Time
|
||||
ev.ResolvedAt = &v
|
||||
}
|
||||
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||
ev.Dimensions = decoded
|
||||
}
|
||||
}
|
||||
return &ev, nil
|
||||
}
|
||||
|
||||
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
|
||||
clauses := []string{"1=1"}
|
||||
args := []any{}
|
||||
|
||||
if filter == nil {
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
if status := strings.TrimSpace(filter.Status); status != "" {
|
||||
args = append(args, status)
|
||||
clauses = append(clauses, "status = $"+itoa(len(args)))
|
||||
}
|
||||
if severity := strings.TrimSpace(filter.Severity); severity != "" {
|
||||
args = append(args, severity)
|
||||
clauses = append(clauses, "severity = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.EmailSent != nil {
|
||||
args = append(args, *filter.EmailSent)
|
||||
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
args = append(args, *filter.StartTime)
|
||||
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
|
||||
}
|
||||
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||
args = append(args, *filter.EndTime)
|
||||
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
// Cursor pagination (descending by fired_at, then id)
|
||||
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
|
||||
args = append(args, *filter.BeforeFiredAt)
|
||||
tsArg := "$" + itoa(len(args))
|
||||
args = append(args, *filter.BeforeID)
|
||||
idArg := "$" + itoa(len(args))
|
||||
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
|
||||
}
|
||||
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
|
||||
if platform := strings.TrimSpace(filter.Platform); platform != "" {
|
||||
args = append(args, platform)
|
||||
clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
args = append(args, fmt.Sprintf("%d", *filter.GroupID))
|
||||
clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func opsNullJSONMap(v map[string]any) (any, error) {
|
||||
if v == nil {
|
||||
return sql.NullString{}, nil
|
||||
}
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return sql.NullString{}, nil
|
||||
}
|
||||
return sql.NullString{String: string(b), Valid: true}, nil
|
||||
}
|
||||
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
1015
backend/internal/repository/ops_repo_dashboard.go
Normal file
File diff suppressed because it is too large
Load Diff
79
backend/internal/repository/ops_repo_histograms.go
Normal file
79
backend/internal/repository/ops_repo_histograms.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
|
||||
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
|
||||
rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
|
||||
orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
` + rangeExpr + ` AS range,
|
||||
COALESCE(COUNT(*), 0) AS count,
|
||||
` + orderExpr + ` AS ord
|
||||
FROM usage_logs ul
|
||||
` + join + `
|
||||
` + where + `
|
||||
AND ul.duration_ms IS NOT NULL
|
||||
GROUP BY 1, 3
|
||||
ORDER BY 3 ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
|
||||
var total int64
|
||||
for rows.Next() {
|
||||
var label string
|
||||
var count int64
|
||||
var _ord int
|
||||
if err := rows.Scan(&label, &count, &_ord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[label] = count
|
||||
total += count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
|
||||
for _, label := range latencyHistogramOrderedRanges {
|
||||
buckets = append(buckets, &service.OpsLatencyHistogramBucket{
|
||||
Range: label,
|
||||
Count: counts[label],
|
||||
})
|
||||
}
|
||||
|
||||
return &service.OpsLatencyHistogramResponse{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(filter.Platform),
|
||||
GroupID: filter.GroupID,
|
||||
TotalRequests: total,
|
||||
Buckets: buckets,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type latencyHistogramBucket struct {
|
||||
upperMs int
|
||||
label string
|
||||
}
|
||||
|
||||
var latencyHistogramBuckets = []latencyHistogramBucket{
|
||||
{upperMs: 100, label: "0-100ms"},
|
||||
{upperMs: 200, label: "100-200ms"},
|
||||
{upperMs: 500, label: "200-500ms"},
|
||||
{upperMs: 1000, label: "500-1000ms"},
|
||||
{upperMs: 2000, label: "1000-2000ms"},
|
||||
{upperMs: 0, label: "2000ms+"}, // default bucket
|
||||
}
|
||||
|
||||
var latencyHistogramOrderedRanges = func() []string {
|
||||
out := make([]string, 0, len(latencyHistogramBuckets))
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
out = append(out, b.label)
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
func latencyHistogramRangeCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label))
|
||||
}
|
||||
|
||||
// Default bucket.
|
||||
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE '%s'\n", last.label))
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func latencyHistogramRangeOrderCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
order := 1
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tWHEN %s < %d THEN %d\n", column, b.upperMs, order))
|
||||
order++
|
||||
}
|
||||
|
||||
_, _ = sb.WriteString(fmt.Sprintf("\tELSE %d\n", order))
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
|
||||
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
|
||||
for i, b := range latencyHistogramBuckets {
|
||||
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
|
||||
}
|
||||
}
|
||||
436
backend/internal/repository/ops_repo_metrics.go
Normal file
436
backend/internal/repository/ops_repo_metrics.go
Normal file
@@ -0,0 +1,436 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
window := input.WindowMinutes
|
||||
if window <= 0 {
|
||||
window = 1
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_system_metrics (
|
||||
created_at,
|
||||
window_minutes,
|
||||
platform,
|
||||
group_id,
|
||||
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
|
||||
token_consumed,
|
||||
qps,
|
||||
tps,
|
||||
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
|
||||
cpu_usage_percent,
|
||||
memory_used_mb,
|
||||
memory_total_mb,
|
||||
memory_usage_percent,
|
||||
|
||||
db_ok,
|
||||
redis_ok,
|
||||
|
||||
redis_conn_total,
|
||||
redis_conn_idle,
|
||||
|
||||
db_conn_active,
|
||||
db_conn_idle,
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,
|
||||
$5,$6,$7,$8,
|
||||
$9,$10,$11,
|
||||
$12,$13,$14,
|
||||
$15,$16,$17,$18,$19,$20,
|
||||
$21,$22,$23,$24,$25,$26,
|
||||
$27,$28,$29,$30,
|
||||
$31,$32,
|
||||
$33,$34,
|
||||
$35,$36,$37,
|
||||
$38,$39
|
||||
)`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
q,
|
||||
createdAt,
|
||||
window,
|
||||
opsNullString(input.Platform),
|
||||
opsNullInt64(input.GroupID),
|
||||
|
||||
input.SuccessCount,
|
||||
input.ErrorCountTotal,
|
||||
input.BusinessLimitedCount,
|
||||
input.ErrorCountSLA,
|
||||
|
||||
input.UpstreamErrorCountExcl429529,
|
||||
input.Upstream429Count,
|
||||
input.Upstream529Count,
|
||||
|
||||
input.TokenConsumed,
|
||||
opsNullFloat64(input.QPS),
|
||||
opsNullFloat64(input.TPS),
|
||||
|
||||
opsNullInt(input.DurationP50Ms),
|
||||
opsNullInt(input.DurationP90Ms),
|
||||
opsNullInt(input.DurationP95Ms),
|
||||
opsNullInt(input.DurationP99Ms),
|
||||
opsNullFloat64(input.DurationAvgMs),
|
||||
opsNullInt(input.DurationMaxMs),
|
||||
|
||||
opsNullInt(input.TTFTP50Ms),
|
||||
opsNullInt(input.TTFTP90Ms),
|
||||
opsNullInt(input.TTFTP95Ms),
|
||||
opsNullInt(input.TTFTP99Ms),
|
||||
opsNullFloat64(input.TTFTAvgMs),
|
||||
opsNullInt(input.TTFTMaxMs),
|
||||
|
||||
opsNullFloat64(input.CPUUsagePercent),
|
||||
opsNullInt(input.MemoryUsedMB),
|
||||
opsNullInt(input.MemoryTotalMB),
|
||||
opsNullFloat64(input.MemoryUsagePercent),
|
||||
|
||||
opsNullBool(input.DBOK),
|
||||
opsNullBool(input.RedisOK),
|
||||
|
||||
opsNullInt(input.RedisConnTotal),
|
||||
opsNullInt(input.RedisConnIdle),
|
||||
|
||||
opsNullInt(input.DBConnActive),
|
||||
opsNullInt(input.DBConnIdle),
|
||||
opsNullInt(input.DBConnWaiting),
|
||||
|
||||
opsNullInt(input.GoroutineCount),
|
||||
opsNullInt(input.ConcurrencyQueueDepth),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
window_minutes,
|
||||
|
||||
cpu_usage_percent,
|
||||
memory_used_mb,
|
||||
memory_total_mb,
|
||||
memory_usage_percent,
|
||||
|
||||
db_ok,
|
||||
redis_ok,
|
||||
|
||||
redis_conn_total,
|
||||
redis_conn_idle,
|
||||
|
||||
db_conn_active,
|
||||
db_conn_idle,
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth
|
||||
FROM ops_system_metrics
|
||||
WHERE window_minutes = $1
|
||||
AND platform IS NULL
|
||||
AND group_id IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
var out service.OpsSystemMetricsSnapshot
|
||||
var cpu sql.NullFloat64
|
||||
var memUsed sql.NullInt64
|
||||
var memTotal sql.NullInt64
|
||||
var memPct sql.NullFloat64
|
||||
var dbOK sql.NullBool
|
||||
var redisOK sql.NullBool
|
||||
var redisTotal sql.NullInt64
|
||||
var redisIdle sql.NullInt64
|
||||
var dbActive sql.NullInt64
|
||||
var dbIdle sql.NullInt64
|
||||
var dbWaiting sql.NullInt64
|
||||
var goroutines sql.NullInt64
|
||||
var queueDepth sql.NullInt64
|
||||
|
||||
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||
&out.ID,
|
||||
&out.CreatedAt,
|
||||
&out.WindowMinutes,
|
||||
&cpu,
|
||||
&memUsed,
|
||||
&memTotal,
|
||||
&memPct,
|
||||
&dbOK,
|
||||
&redisOK,
|
||||
&redisTotal,
|
||||
&redisIdle,
|
||||
&dbActive,
|
||||
&dbIdle,
|
||||
&dbWaiting,
|
||||
&goroutines,
|
||||
&queueDepth,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cpu.Valid {
|
||||
v := cpu.Float64
|
||||
out.CPUUsagePercent = &v
|
||||
}
|
||||
if memUsed.Valid {
|
||||
v := memUsed.Int64
|
||||
out.MemoryUsedMB = &v
|
||||
}
|
||||
if memTotal.Valid {
|
||||
v := memTotal.Int64
|
||||
out.MemoryTotalMB = &v
|
||||
}
|
||||
if memPct.Valid {
|
||||
v := memPct.Float64
|
||||
out.MemoryUsagePercent = &v
|
||||
}
|
||||
if dbOK.Valid {
|
||||
v := dbOK.Bool
|
||||
out.DBOK = &v
|
||||
}
|
||||
if redisOK.Valid {
|
||||
v := redisOK.Bool
|
||||
out.RedisOK = &v
|
||||
}
|
||||
if redisTotal.Valid {
|
||||
v := int(redisTotal.Int64)
|
||||
out.RedisConnTotal = &v
|
||||
}
|
||||
if redisIdle.Valid {
|
||||
v := int(redisIdle.Int64)
|
||||
out.RedisConnIdle = &v
|
||||
}
|
||||
if dbActive.Valid {
|
||||
v := int(dbActive.Int64)
|
||||
out.DBConnActive = &v
|
||||
}
|
||||
if dbIdle.Valid {
|
||||
v := int(dbIdle.Int64)
|
||||
out.DBConnIdle = &v
|
||||
}
|
||||
if dbWaiting.Valid {
|
||||
v := int(dbWaiting.Int64)
|
||||
out.DBConnWaiting = &v
|
||||
}
|
||||
if goroutines.Valid {
|
||||
v := int(goroutines.Int64)
|
||||
out.GoroutineCount = &v
|
||||
}
|
||||
if queueDepth.Valid {
|
||||
v := int(queueDepth.Int64)
|
||||
out.ConcurrencyQueueDepth = &v
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
if input.JobName == "" {
|
||||
return fmt.Errorf("job_name required")
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_job_heartbeats (
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
last_result,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||
)
|
||||
ON CONFLICT (job_name) DO UPDATE SET
|
||||
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
|
||||
last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
|
||||
last_error_at = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
|
||||
END,
|
||||
last_error = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
|
||||
END,
|
||||
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
|
||||
last_result = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
|
||||
ELSE ops_job_heartbeats.last_result
|
||||
END,
|
||||
updated_at = NOW()`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
q,
|
||||
input.JobName,
|
||||
opsNullTime(input.LastRunAt),
|
||||
opsNullTime(input.LastSuccessAt),
|
||||
opsNullTime(input.LastErrorAt),
|
||||
opsNullString(input.LastError),
|
||||
opsNullInt(input.LastDurationMs),
|
||||
opsNullString(input.LastResult),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
last_result,
|
||||
updated_at
|
||||
FROM ops_job_heartbeats
|
||||
ORDER BY job_name ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]*service.OpsJobHeartbeat, 0, 8)
|
||||
for rows.Next() {
|
||||
var item service.OpsJobHeartbeat
|
||||
var lastRun sql.NullTime
|
||||
var lastSuccess sql.NullTime
|
||||
var lastErrorAt sql.NullTime
|
||||
var lastError sql.NullString
|
||||
var lastDuration sql.NullInt64
|
||||
|
||||
var lastResult sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&item.JobName,
|
||||
&lastRun,
|
||||
&lastSuccess,
|
||||
&lastErrorAt,
|
||||
&lastError,
|
||||
&lastDuration,
|
||||
&lastResult,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastRun.Valid {
|
||||
v := lastRun.Time
|
||||
item.LastRunAt = &v
|
||||
}
|
||||
if lastSuccess.Valid {
|
||||
v := lastSuccess.Time
|
||||
item.LastSuccessAt = &v
|
||||
}
|
||||
if lastErrorAt.Valid {
|
||||
v := lastErrorAt.Time
|
||||
item.LastErrorAt = &v
|
||||
}
|
||||
if lastError.Valid {
|
||||
v := lastError.String
|
||||
item.LastError = &v
|
||||
}
|
||||
if lastDuration.Valid {
|
||||
v := lastDuration.Int64
|
||||
item.LastDurationMs = &v
|
||||
}
|
||||
if lastResult.Valid {
|
||||
v := lastResult.String
|
||||
item.LastResult = &v
|
||||
}
|
||||
|
||||
out = append(out, &item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func opsNullBool(v *bool) any {
|
||||
if v == nil {
|
||||
return sql.NullBool{}
|
||||
}
|
||||
return sql.NullBool{Bool: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullFloat64(v *float64) any {
|
||||
if v == nil {
|
||||
return sql.NullFloat64{}
|
||||
}
|
||||
return sql.NullFloat64{Float64: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullTime(v *time.Time) any {
|
||||
if v == nil || v.IsZero() {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *v, Valid: true}
|
||||
}
|
||||
363
backend/internal/repository/ops_repo_preagg.go
Normal file
363
backend/internal/repository/ops_repo_preagg.go
Normal file
@@ -0,0 +1,363 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := startTime.UTC()
|
||||
end := endTime.UTC()
|
||||
|
||||
// NOTE:
|
||||
// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
|
||||
// - We emit three dimension granularities via GROUPING SETS:
|
||||
// 1) overall: (bucket_start)
|
||||
// 2) platform: (bucket_start, platform)
|
||||
// 3) group: (bucket_start, platform, group_id)
|
||||
//
|
||||
// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
|
||||
// unique index; our ON CONFLICT target must match that expression set.
|
||||
q := `
|
||||
WITH usage_base AS (
|
||||
SELECT
|
||||
date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
g.platform AS platform,
|
||||
ul.group_id AS group_id,
|
||||
ul.duration_ms AS duration_ms,
|
||||
ul.first_token_ms AS first_token_ms,
|
||||
(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
|
||||
FROM usage_logs ul
|
||||
JOIN groups g ON g.id = ul.group_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
),
|
||||
usage_agg AS (
|
||||
SELECT
|
||||
bucket_start,
|
||||
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
|
||||
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(tokens), 0) AS token_consumed,
|
||||
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
|
||||
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
|
||||
MAX(duration_ms) AS duration_max_ms,
|
||||
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
|
||||
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
|
||||
MAX(first_token_ms) AS ttft_max_ms
|
||||
FROM usage_base
|
||||
GROUP BY GROUPING SETS (
|
||||
(bucket_start),
|
||||
(bucket_start, platform),
|
||||
(bucket_start, platform, group_id)
|
||||
)
|
||||
),
|
||||
error_base AS (
|
||||
SELECT
|
||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
|
||||
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
|
||||
COALESCE(platform, 'unknown') AS platform,
|
||||
group_id AS group_id,
|
||||
is_business_limited AS is_business_limited,
|
||||
error_owner AS error_owner,
|
||||
status_code AS client_status_code,
|
||||
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
|
||||
FROM ops_error_logs
|
||||
-- Exclude count_tokens requests from error metrics as they are informational probes
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND is_count_tokens = FALSE
|
||||
),
|
||||
error_agg AS (
|
||||
SELECT
|
||||
bucket_start,
|
||||
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
|
||||
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
|
||||
FROM error_base
|
||||
GROUP BY GROUPING SETS (
|
||||
(bucket_start),
|
||||
(bucket_start, platform),
|
||||
(bucket_start, platform, group_id)
|
||||
)
|
||||
HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
|
||||
),
|
||||
combined AS (
|
||||
SELECT
|
||||
COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
|
||||
COALESCE(u.platform, e.platform) AS platform,
|
||||
COALESCE(u.group_id, e.group_id) AS group_id,
|
||||
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count_total, 0) AS error_count_total,
|
||||
COALESCE(e.business_limited_count, 0) AS business_limited_count,
|
||||
COALESCE(e.error_count_sla, 0) AS error_count_sla,
|
||||
COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
|
||||
COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
|
||||
COALESCE(e.upstream_529_count, 0) AS upstream_529_count,
|
||||
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed,
|
||||
|
||||
u.duration_p50_ms,
|
||||
u.duration_p90_ms,
|
||||
u.duration_p95_ms,
|
||||
u.duration_p99_ms,
|
||||
u.duration_avg_ms,
|
||||
u.duration_max_ms,
|
||||
|
||||
u.ttft_p50_ms,
|
||||
u.ttft_p90_ms,
|
||||
u.ttft_p95_ms,
|
||||
u.ttft_p99_ms,
|
||||
u.ttft_avg_ms,
|
||||
u.ttft_max_ms
|
||||
FROM usage_agg u
|
||||
FULL OUTER JOIN error_agg e
|
||||
ON u.bucket_start = e.bucket_start
|
||||
AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
|
||||
AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
|
||||
)
|
||||
INSERT INTO ops_metrics_hourly (
|
||||
bucket_start,
|
||||
platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
bucket_start,
|
||||
NULLIF(platform, '') AS platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms::int,
|
||||
duration_p90_ms::int,
|
||||
duration_p95_ms::int,
|
||||
duration_p99_ms::int,
|
||||
duration_avg_ms,
|
||||
duration_max_ms::int,
|
||||
ttft_p50_ms::int,
|
||||
ttft_p90_ms::int,
|
||||
ttft_p95_ms::int,
|
||||
ttft_p99_ms::int,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms::int,
|
||||
NOW()
|
||||
FROM combined
|
||||
WHERE bucket_start IS NOT NULL
|
||||
AND (platform IS NULL OR platform <> '')
|
||||
ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
|
||||
success_count = EXCLUDED.success_count,
|
||||
error_count_total = EXCLUDED.error_count_total,
|
||||
business_limited_count = EXCLUDED.business_limited_count,
|
||||
error_count_sla = EXCLUDED.error_count_sla,
|
||||
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
|
||||
upstream_429_count = EXCLUDED.upstream_429_count,
|
||||
upstream_529_count = EXCLUDED.upstream_529_count,
|
||||
token_consumed = EXCLUDED.token_consumed,
|
||||
|
||||
duration_p50_ms = EXCLUDED.duration_p50_ms,
|
||||
duration_p90_ms = EXCLUDED.duration_p90_ms,
|
||||
duration_p95_ms = EXCLUDED.duration_p95_ms,
|
||||
duration_p99_ms = EXCLUDED.duration_p99_ms,
|
||||
duration_avg_ms = EXCLUDED.duration_avg_ms,
|
||||
duration_max_ms = EXCLUDED.duration_max_ms,
|
||||
|
||||
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
|
||||
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
|
||||
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
|
||||
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
|
||||
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
|
||||
ttft_max_ms = EXCLUDED.ttft_max_ms,
|
||||
|
||||
computed_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, start, end)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := startTime.UTC()
|
||||
end := endTime.UTC()
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_metrics_daily (
|
||||
bucket_date,
|
||||
platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||
platform,
|
||||
group_id,
|
||||
|
||||
COALESCE(SUM(success_count), 0) AS success_count,
|
||||
COALESCE(SUM(error_count_total), 0) AS error_count_total,
|
||||
COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
|
||||
COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
|
||||
COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
|
||||
COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
|
||||
COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
|
||||
COALESCE(SUM(token_consumed), 0) AS token_consumed,
|
||||
|
||||
-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
|
||||
ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
|
||||
ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
|
||||
MAX(duration_p95_ms) AS duration_p95_ms,
|
||||
MAX(duration_p99_ms) AS duration_p99_ms,
|
||||
SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
|
||||
MAX(duration_max_ms) AS duration_max_ms,
|
||||
|
||||
ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
|
||||
ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
|
||||
MAX(ttft_p95_ms) AS ttft_p95_ms,
|
||||
MAX(ttft_p99_ms) AS ttft_p99_ms,
|
||||
SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
|
||||
MAX(ttft_max_ms) AS ttft_max_ms,
|
||||
|
||||
NOW()
|
||||
FROM ops_metrics_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY 1, 2, 3
|
||||
ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
|
||||
success_count = EXCLUDED.success_count,
|
||||
error_count_total = EXCLUDED.error_count_total,
|
||||
business_limited_count = EXCLUDED.business_limited_count,
|
||||
error_count_sla = EXCLUDED.error_count_sla,
|
||||
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
|
||||
upstream_429_count = EXCLUDED.upstream_429_count,
|
||||
upstream_529_count = EXCLUDED.upstream_529_count,
|
||||
token_consumed = EXCLUDED.token_consumed,
|
||||
|
||||
duration_p50_ms = EXCLUDED.duration_p50_ms,
|
||||
duration_p90_ms = EXCLUDED.duration_p90_ms,
|
||||
duration_p95_ms = EXCLUDED.duration_p95_ms,
|
||||
duration_p99_ms = EXCLUDED.duration_p99_ms,
|
||||
duration_avg_ms = EXCLUDED.duration_avg_ms,
|
||||
duration_max_ms = EXCLUDED.duration_max_ms,
|
||||
|
||||
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
|
||||
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
|
||||
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
|
||||
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
|
||||
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
|
||||
ttft_max_ms = EXCLUDED.ttft_max_ms,
|
||||
|
||||
computed_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, start, end)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
return value.Time.UTC(), true, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
t := value.Time.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
|
||||
}
|
||||
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
129
backend/internal/repository/ops_repo_realtime_traffic.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
if start.After(end) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
|
||||
window := end.Sub(start)
|
||||
if window <= 0 {
|
||||
return nil, fmt.Errorf("invalid time window")
|
||||
}
|
||||
if window > time.Hour {
|
||||
return nil, fmt.Errorf("window too large")
|
||||
}
|
||||
|
||||
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||
|
||||
q := `
|
||||
WITH usage_buckets AS (
|
||||
SELECT
|
||||
date_trunc('minute', ul.created_at) AS bucket,
|
||||
COALESCE(COUNT(*), 0) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
|
||||
FROM usage_logs ul
|
||||
` + usageJoin + `
|
||||
` + usageWhere + `
|
||||
GROUP BY 1
|
||||
),
|
||||
error_buckets AS (
|
||||
SELECT
|
||||
date_trunc('minute', created_at) AS bucket,
|
||||
COALESCE(COUNT(*), 0) AS error_count
|
||||
FROM ops_error_logs
|
||||
` + errorWhere + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT
|
||||
COALESCE(u.bucket, e.bucket) AS bucket,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(u.token_sum, 0) AS token_sum,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
|
||||
FROM usage_buckets u
|
||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||
)
|
||||
SELECT
|
||||
COALESCE(SUM(success_count), 0) AS success_total,
|
||||
COALESCE(SUM(error_count), 0) AS error_total,
|
||||
COALESCE(SUM(token_sum), 0) AS token_total,
|
||||
COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
|
||||
COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
|
||||
FROM combined`
|
||||
|
||||
args := append(usageArgs, errorArgs...)
|
||||
var successCount int64
|
||||
var errorTotal int64
|
||||
var tokenConsumed int64
|
||||
var peakRequestsPerMin int64
|
||||
var peakTokensPerMin int64
|
||||
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
|
||||
&successCount,
|
||||
&errorTotal,
|
||||
&tokenConsumed,
|
||||
&peakRequestsPerMin,
|
||||
&peakTokensPerMin,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
windowSeconds := window.Seconds()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 1
|
||||
}
|
||||
|
||||
requestCountTotal := successCount + errorTotal
|
||||
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
|
||||
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
|
||||
|
||||
// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
|
||||
// This remains "within the selected window" since end=start+window.
|
||||
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
|
||||
tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)
|
||||
|
||||
return &service.OpsRealtimeTrafficSummary{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(filter.Platform),
|
||||
GroupID: filter.GroupID,
|
||||
QPS: service.OpsRateSummary{
|
||||
Current: qpsCurrent,
|
||||
Peak: qpsPeak,
|
||||
Avg: qpsAvg,
|
||||
},
|
||||
TPS: service.OpsRateSummary{
|
||||
Current: tpsCurrent,
|
||||
Peak: tpsPeak,
|
||||
Avg: tpsAvg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
286
backend/internal/repository/ops_repo_request_details.go
Normal file
286
backend/internal/repository/ops_repo_request_details.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
page, pageSize, startTime, endTime := filter.Normalize()
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
conditions := make([]string, 0, 16)
|
||||
args := make([]any, 0, 24)
|
||||
|
||||
// Placeholders $1/$2 reserved for time window inside the CTE.
|
||||
args = append(args, startTime.UTC(), endTime.UTC())
|
||||
|
||||
addCondition := func(condition string, values ...any) {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, values...)
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
|
||||
if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
|
||||
return nil, 0, fmt.Errorf("invalid kind")
|
||||
}
|
||||
addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
|
||||
}
|
||||
|
||||
if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
|
||||
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
|
||||
}
|
||||
|
||||
if filter.UserID != nil && *filter.UserID > 0 {
|
||||
addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
|
||||
}
|
||||
if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
|
||||
addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
|
||||
}
|
||||
|
||||
if model := strings.TrimSpace(filter.Model); model != "" {
|
||||
addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
|
||||
}
|
||||
if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
|
||||
addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
|
||||
}
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + strings.ToLower(q) + "%"
|
||||
startIdx := len(args) + 1
|
||||
addCondition(
|
||||
fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
|
||||
startIdx, startIdx+1, startIdx+2,
|
||||
),
|
||||
like, like, like,
|
||||
)
|
||||
}
|
||||
|
||||
if filter.MinDurationMs != nil {
|
||||
addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
|
||||
}
|
||||
if filter.MaxDurationMs != nil {
|
||||
addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
where := ""
|
||||
if len(conditions) > 0 {
|
||||
where = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
cte := `
|
||||
WITH combined AS (
|
||||
SELECT
|
||||
'success'::TEXT AS kind,
|
||||
ul.created_at AS created_at,
|
||||
ul.request_id AS request_id,
|
||||
COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
|
||||
ul.model AS model,
|
||||
ul.duration_ms AS duration_ms,
|
||||
NULL::INT AS status_code,
|
||||
NULL::BIGINT AS error_id,
|
||||
NULL::TEXT AS phase,
|
||||
NULL::TEXT AS severity,
|
||||
NULL::TEXT AS message,
|
||||
ul.user_id AS user_id,
|
||||
ul.api_key_id AS api_key_id,
|
||||
ul.account_id AS account_id,
|
||||
ul.group_id AS group_id,
|
||||
ul.stream AS stream
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT
|
||||
'error'::TEXT AS kind,
|
||||
o.created_at AS created_at,
|
||||
COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
|
||||
COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
|
||||
o.model AS model,
|
||||
o.duration_ms AS duration_ms,
|
||||
o.status_code AS status_code,
|
||||
o.id AS error_id,
|
||||
o.error_phase AS phase,
|
||||
o.severity AS severity,
|
||||
o.error_message AS message,
|
||||
o.user_id AS user_id,
|
||||
o.api_key_id AS api_key_id,
|
||||
o.account_id AS account_id,
|
||||
o.group_id AS group_id,
|
||||
o.stream AS stream
|
||||
FROM ops_error_logs o
|
||||
LEFT JOIN groups g ON g.id = o.group_id
|
||||
LEFT JOIN accounts a ON a.id = o.account_id
|
||||
WHERE o.created_at >= $1 AND o.created_at < $2
|
||||
AND COALESCE(o.status_code, 0) >= 400
|
||||
)
|
||||
`
|
||||
|
||||
countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
total = 0
|
||||
} else {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
sort := "ORDER BY created_at DESC"
|
||||
if filter != nil {
|
||||
switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
|
||||
case "", "created_at_desc":
|
||||
// default
|
||||
case "duration_desc":
|
||||
sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("invalid sort")
|
||||
}
|
||||
}
|
||||
|
||||
listQuery := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
kind,
|
||||
created_at,
|
||||
request_id,
|
||||
platform,
|
||||
model,
|
||||
duration_ms,
|
||||
status_code,
|
||||
error_id,
|
||||
phase,
|
||||
severity,
|
||||
message,
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
group_id,
|
||||
stream
|
||||
FROM combined
|
||||
%s
|
||||
%s
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, cte, where, sort, len(args)+1, len(args)+2)
|
||||
|
||||
listArgs := append(append([]any{}, args...), pageSize, offset)
|
||||
rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
toIntPtr := func(v sql.NullInt64) *int {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
i := int(v.Int64)
|
||||
return &i
|
||||
}
|
||||
toInt64Ptr := func(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
i := v.Int64
|
||||
return &i
|
||||
}
|
||||
|
||||
out := make([]*service.OpsRequestDetail, 0, pageSize)
|
||||
for rows.Next() {
|
||||
var (
|
||||
kind string
|
||||
createdAt time.Time
|
||||
requestID sql.NullString
|
||||
platform sql.NullString
|
||||
model sql.NullString
|
||||
|
||||
durationMs sql.NullInt64
|
||||
statusCode sql.NullInt64
|
||||
errorID sql.NullInt64
|
||||
|
||||
phase sql.NullString
|
||||
severity sql.NullString
|
||||
message sql.NullString
|
||||
|
||||
userID sql.NullInt64
|
||||
apiKeyID sql.NullInt64
|
||||
accountID sql.NullInt64
|
||||
groupID sql.NullInt64
|
||||
|
||||
stream bool
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&kind,
|
||||
&createdAt,
|
||||
&requestID,
|
||||
&platform,
|
||||
&model,
|
||||
&durationMs,
|
||||
&statusCode,
|
||||
&errorID,
|
||||
&phase,
|
||||
&severity,
|
||||
&message,
|
||||
&userID,
|
||||
&apiKeyID,
|
||||
&accountID,
|
||||
&groupID,
|
||||
&stream,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
item := &service.OpsRequestDetail{
|
||||
Kind: service.OpsRequestKind(kind),
|
||||
CreatedAt: createdAt,
|
||||
RequestID: strings.TrimSpace(requestID.String),
|
||||
Platform: strings.TrimSpace(platform.String),
|
||||
Model: strings.TrimSpace(model.String),
|
||||
|
||||
DurationMs: toIntPtr(durationMs),
|
||||
StatusCode: toIntPtr(statusCode),
|
||||
ErrorID: toInt64Ptr(errorID),
|
||||
Phase: phase.String,
|
||||
Severity: severity.String,
|
||||
Message: message.String,
|
||||
|
||||
UserID: toInt64Ptr(userID),
|
||||
APIKeyID: toInt64Ptr(apiKeyID),
|
||||
AccountID: toInt64Ptr(accountID),
|
||||
GroupID: toInt64Ptr(groupID),
|
||||
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
if item.Platform == "" {
|
||||
item.Platform = "unknown"
|
||||
}
|
||||
|
||||
out = append(out, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return out, total, nil
|
||||
}
|
||||
573
backend/internal/repository/ops_repo_trends.go
Normal file
573
backend/internal/repository/ops_repo_trends.go
Normal file
@@ -0,0 +1,573 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
|
||||
// Keep a small, predictable set of supported buckets for now.
|
||||
bucketSeconds = 60
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
|
||||
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||
|
||||
usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
|
||||
errorBucketExpr := opsBucketExprForError(bucketSeconds)
|
||||
|
||||
q := `
|
||||
WITH usage_buckets AS (
|
||||
SELECT ` + usageBucketExpr + ` AS bucket,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
` + usageJoin + `
|
||||
` + usageWhere + `
|
||||
GROUP BY 1
|
||||
),
|
||||
error_buckets AS (
|
||||
SELECT ` + errorBucketExpr + ` AS bucket,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
` + errorWhere + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_buckets u
|
||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||
)
|
||||
SELECT
|
||||
bucket,
|
||||
(success_count + error_count) AS request_count,
|
||||
token_consumed
|
||||
FROM combined
|
||||
ORDER BY bucket ASC`
|
||||
|
||||
args := append(usageArgs, errorArgs...)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
points := make([]*service.OpsThroughputTrendPoint, 0, 256)
|
||||
for rows.Next() {
|
||||
var bucket time.Time
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
|
||||
denom := float64(bucketSeconds)
|
||||
if denom <= 0 {
|
||||
denom = 60
|
||||
}
|
||||
qps := roundTo1DP(float64(requests) / denom)
|
||||
tps := roundTo1DP(float64(tokenConsumed) / denom)
|
||||
|
||||
points = append(points, &service.OpsThroughputTrendPoint{
|
||||
BucketStart: bucket.UTC(),
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
QPS: qps,
|
||||
TPS: tps,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fill missing buckets with zeros so charts render continuous timelines.
|
||||
points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)
|
||||
|
||||
var byPlatform []*service.OpsThroughputPlatformBreakdownItem
|
||||
var topGroups []*service.OpsThroughputGroupBreakdownItem
|
||||
|
||||
platform := ""
|
||||
if filter != nil {
|
||||
platform = strings.TrimSpace(strings.ToLower(filter.Platform))
|
||||
}
|
||||
groupID := (*int64)(nil)
|
||||
if filter != nil {
|
||||
groupID = filter.GroupID
|
||||
}
|
||||
|
||||
// Drilldown helpers:
|
||||
// - No platform/group: totals by platform
|
||||
// - Platform selected but no group: top groups in that platform
|
||||
if platform == "" && (groupID == nil || *groupID <= 0) {
|
||||
items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
byPlatform = items
|
||||
} else if platform != "" && (groupID == nil || *groupID <= 0) {
|
||||
items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topGroups = items
|
||||
}
|
||||
|
||||
return &service.OpsThroughputTrendResponse{
|
||||
Bucket: opsBucketLabel(bucketSeconds),
|
||||
Points: points,
|
||||
|
||||
ByPlatform: byPlatform,
|
||||
TopGroups: topGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
|
||||
q := `
|
||||
WITH usage_totals AS (
|
||||
SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
GROUP BY 1
|
||||
),
|
||||
error_totals AS (
|
||||
SELECT platform,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.platform, e.platform) AS platform,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_totals u
|
||||
FULL OUTER JOIN error_totals e ON u.platform = e.platform
|
||||
)
|
||||
SELECT platform, (success_count + error_count) AS request_count, token_consumed
|
||||
FROM combined
|
||||
WHERE platform IS NOT NULL AND platform <> ''
|
||||
ORDER BY request_count DESC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
|
||||
for rows.Next() {
|
||||
var platform string
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&platform, &requests, &tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
items = append(items, &service.OpsThroughputPlatformBreakdownItem{
|
||||
Platform: platform,
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
|
||||
if strings.TrimSpace(platform) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
q := `
|
||||
WITH usage_totals AS (
|
||||
SELECT ul.group_id AS group_id,
|
||||
g.name AS group_name,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
JOIN groups g ON g.id = ul.group_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
AND g.platform = $3
|
||||
GROUP BY 1, 2
|
||||
),
|
||||
error_totals AS (
|
||||
SELECT group_id,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND platform = $3
|
||||
AND group_id IS NOT NULL
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.group_id, e.group_id) AS group_id,
|
||||
COALESCE(u.group_name, g2.name, '') AS group_name,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_totals u
|
||||
FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
|
||||
LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
|
||||
)
|
||||
SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
|
||||
FROM combined
|
||||
WHERE group_id IS NOT NULL
|
||||
ORDER BY request_count DESC
|
||||
LIMIT $4`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var groupName sql.NullString
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
name := ""
|
||||
if groupName.Valid {
|
||||
name = groupName.String
|
||||
}
|
||||
items = append(items, &service.OpsThroughputGroupBreakdownItem{
|
||||
GroupID: groupID,
|
||||
GroupName: name,
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func opsBucketExprForUsage(bucketSeconds int) string {
|
||||
switch bucketSeconds {
|
||||
case 3600:
|
||||
return "date_trunc('hour', ul.created_at)"
|
||||
case 300:
|
||||
// 5-minute buckets in UTC.
|
||||
return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
|
||||
default:
|
||||
return "date_trunc('minute', ul.created_at)"
|
||||
}
|
||||
}
|
||||
|
||||
func opsBucketExprForError(bucketSeconds int) string {
|
||||
switch bucketSeconds {
|
||||
case 3600:
|
||||
return "date_trunc('hour', created_at)"
|
||||
case 300:
|
||||
return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
|
||||
default:
|
||||
return "date_trunc('minute', created_at)"
|
||||
}
|
||||
}
|
||||
|
||||
func opsBucketLabel(bucketSeconds int) string {
|
||||
if bucketSeconds <= 0 {
|
||||
return "1m"
|
||||
}
|
||||
if bucketSeconds%3600 == 0 {
|
||||
h := bucketSeconds / 3600
|
||||
if h <= 0 {
|
||||
h = 1
|
||||
}
|
||||
return fmt.Sprintf("%dh", h)
|
||||
}
|
||||
m := bucketSeconds / 60
|
||||
if m <= 0 {
|
||||
m = 1
|
||||
}
|
||||
return fmt.Sprintf("%dm", m)
|
||||
}
|
||||
|
||||
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
|
||||
t = t.UTC()
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
secs := t.Unix()
|
||||
floored := secs - (secs % int64(bucketSeconds))
|
||||
return time.Unix(floored, 0).UTC()
|
||||
}
|
||||
|
||||
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsThroughputTrendPoint{
|
||||
BucketStart: cursor,
|
||||
RequestCount: 0,
|
||||
TokenConsumed: 0,
|
||||
QPS: 0,
|
||||
TPS: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
where, args, _ := buildErrorWhere(filter, start, end, 1)
|
||||
bucketExpr := opsBucketExprForError(bucketSeconds)
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
` + bucketExpr + ` AS bucket,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
|
||||
FROM ops_error_logs
|
||||
` + where + `
|
||||
GROUP BY 1
|
||||
ORDER BY 1 ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
points := make([]*service.OpsErrorTrendPoint, 0, 256)
|
||||
for rows.Next() {
|
||||
var bucket time.Time
|
||||
var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
|
||||
if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points = append(points, &service.OpsErrorTrendPoint{
|
||||
BucketStart: bucket.UTC(),
|
||||
|
||||
ErrorCountTotal: total,
|
||||
BusinessLimitedCount: businessLimited,
|
||||
ErrorCountSLA: sla,
|
||||
|
||||
UpstreamErrorCountExcl429529: upstreamExcl,
|
||||
Upstream429Count: upstream429,
|
||||
Upstream529Count: upstream529,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)
|
||||
|
||||
return &service.OpsErrorTrendResponse{
|
||||
Bucket: opsBucketLabel(bucketSeconds),
|
||||
Points: points,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsErrorTrendPoint{
|
||||
BucketStart: cursor,
|
||||
|
||||
ErrorCountTotal: 0,
|
||||
BusinessLimitedCount: 0,
|
||||
ErrorCountSLA: 0,
|
||||
|
||||
UpstreamErrorCountExcl429529: 0,
|
||||
Upstream429Count: 0,
|
||||
Upstream529Count: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
where, args, _ := buildErrorWhere(filter, start, end, 1)
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(upstream_status_code, status_code, 0) AS status_code,
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
|
||||
COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
|
||||
FROM ops_error_logs
|
||||
` + where + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
ORDER BY total DESC
|
||||
LIMIT 20`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsErrorDistributionItem, 0, 16)
|
||||
var total int64
|
||||
for rows.Next() {
|
||||
var statusCode int
|
||||
var cntTotal, cntSLA, cntBiz int64
|
||||
if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
total += cntTotal
|
||||
items = append(items, &service.OpsErrorDistributionItem{
|
||||
StatusCode: statusCode,
|
||||
Total: cntTotal,
|
||||
SLA: cntSLA,
|
||||
BusinessLimited: cntBiz,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsErrorDistributionResponse{
|
||||
Total: total,
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
50
backend/internal/repository/ops_repo_window_stats.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
if start.After(end) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
// Bound excessively large windows to prevent accidental heavy queries.
|
||||
if end.Sub(start) > 24*time.Hour {
|
||||
return nil, fmt.Errorf("window too large")
|
||||
}
|
||||
|
||||
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsWindowStats{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
|
||||
SuccessCount: successCount,
|
||||
ErrorCountTotal: errorTotal,
|
||||
TokenConsumed: tokenConsumed,
|
||||
}, nil
|
||||
}
|
||||
273
backend/internal/repository/promo_code_repo.go
Normal file
273
backend/internal/repository/promo_code_repo.go
Normal file
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
74
backend/internal/repository/proxy_latency_cache.go
Normal file
74
backend/internal/repository/proxy_latency_cache.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const proxyLatencyKeyPrefix = "proxy:latency:"
|
||||
|
||||
func proxyLatencyKey(proxyID int64) string {
|
||||
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
|
||||
}
|
||||
|
||||
type proxyLatencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
|
||||
return &proxyLatencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
|
||||
results := make(map[int64]*service.ProxyLatencyInfo)
|
||||
if len(proxyIDs) == 0 {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(proxyIDs))
|
||||
for _, id := range proxyIDs {
|
||||
keys = append(keys, proxyLatencyKey(id))
|
||||
}
|
||||
|
||||
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
for i, raw := range values {
|
||||
if raw == nil {
|
||||
continue
|
||||
}
|
||||
var payload []byte
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
payload = []byte(v)
|
||||
case []byte:
|
||||
payload = v
|
||||
default:
|
||||
continue
|
||||
}
|
||||
var info service.ProxyLatencyInfo
|
||||
if err := json.Unmarshal(payload, &info); err != nil {
|
||||
continue
|
||||
}
|
||||
results[proxyIDs[i]] = &info
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
}
|
||||
}
|
||||
|
||||
const defaultIPInfoURL = "https://ipinfo.io/json"
|
||||
const (
|
||||
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
@@ -46,7 +50,7 @@ type proxyProbeService struct {
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 15 * time.Second,
|
||||
Timeout: defaultProxyProbeTimeout,
|
||||
InsecureSkipVerify: s.insecureSkipVerify,
|
||||
ProxyStrict: true,
|
||||
ValidateResolvedIP: s.validateResolvedIP,
|
||||
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
}
|
||||
|
||||
var ipInfo struct {
|
||||
IP string `json:"ip"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
Country string `json:"country"`
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Query string `json:"query"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
RegionName string `json:"regionName"`
|
||||
Country string `json:"country"`
|
||||
CountryCode string `json:"countryCode"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
if strings.ToLower(ipInfo.Status) != "success" {
|
||||
if ipInfo.Message == "" {
|
||||
ipInfo.Message = "ip-api request failed"
|
||||
}
|
||||
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
|
||||
}
|
||||
|
||||
region := ipInfo.RegionName
|
||||
if region == "" {
|
||||
region = ipInfo.Region
|
||||
}
|
||||
return &service.ProxyExitInfo{
|
||||
IP: ipInfo.IP,
|
||||
City: ipInfo.City,
|
||||
Region: ipInfo.Region,
|
||||
Country: ipInfo.Country,
|
||||
IP: ipInfo.Query,
|
||||
City: ipInfo.City,
|
||||
Region: region,
|
||||
Country: ipInfo.Country,
|
||||
CountryCode: ipInfo.CountryCode,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{
|
||||
ipInfoURL: "http://ipinfo.test/json",
|
||||
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
require.Equal(s.T(), "c", info.City)
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
require.Equal(s.T(), "CC", info.CountryCode)
|
||||
|
||||
// Verify proxy received the request
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
|
||||
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
|
||||
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, name, platform, type, notes
|
||||
FROM accounts
|
||||
WHERE proxy_id = $1 AND deleted_at IS NULL
|
||||
ORDER BY id DESC
|
||||
`, proxyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]service.ProxyAccountSummary, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
name string
|
||||
platform string
|
||||
accType string
|
||||
notes sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &platform, &accType, ¬es); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var notesPtr *string
|
||||
if notes.Valid {
|
||||
notesPtr = ¬es.String
|
||||
}
|
||||
out = append(out, service.ProxyAccountSummary{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Platform: platform,
|
||||
Type: accType,
|
||||
Notes: notesPtr,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
|
||||
|
||||
276
backend/internal/repository/scheduler_cache.go
Normal file
276
backend/internal/repository/scheduler_cache.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
schedulerBucketSetKey = "sched:buckets"
|
||||
schedulerOutboxWatermarkKey = "sched:outbox:watermark"
|
||||
schedulerAccountPrefix = "sched:acc:"
|
||||
schedulerActivePrefix = "sched:active:"
|
||||
schedulerReadyPrefix = "sched:ready:"
|
||||
schedulerVersionPrefix = "sched:ver:"
|
||||
schedulerSnapshotPrefix = "sched:"
|
||||
schedulerLockPrefix = "sched:lock:"
|
||||
)
|
||||
|
||||
type schedulerCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewSchedulerCache(rdb *redis.Client) service.SchedulerCache {
|
||||
return &schedulerCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *schedulerCache) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
readyKey := schedulerBucketKey(schedulerReadyPrefix, bucket)
|
||||
readyVal, err := c.rdb.Get(ctx, readyKey).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if readyVal != "1" {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
activeVal, err := c.rdb.Get(ctx, activeKey).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
snapshotKey := schedulerSnapshotKey(bucket, activeVal)
|
||||
ids, err := c.rdb.ZRange(ctx, snapshotKey, 0, -1).Result()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return []*service.Account{}, true, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(ids))
|
||||
for _, id := range ids {
|
||||
keys = append(keys, schedulerAccountKey(id))
|
||||
}
|
||||
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
accounts := make([]*service.Account, 0, len(values))
|
||||
for _, val := range values {
|
||||
if val == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
account, err := decodeCachedAccount(val)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
accounts = append(accounts, account)
|
||||
}
|
||||
|
||||
return accounts, true, nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
activeKey := schedulerBucketKey(schedulerActivePrefix, bucket)
|
||||
oldActive, _ := c.rdb.Get(ctx, activeKey).Result()
|
||||
|
||||
versionKey := schedulerBucketKey(schedulerVersionPrefix, bucket)
|
||||
version, err := c.rdb.Incr(ctx, versionKey).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
versionStr := strconv.FormatInt(version, 10)
|
||||
snapshotKey := schedulerSnapshotKey(bucket, versionStr)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, account := range accounts {
|
||||
payload, err := json.Marshal(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pipe.Set(ctx, schedulerAccountKey(strconv.FormatInt(account.ID, 10)), payload, 0)
|
||||
}
|
||||
if len(accounts) > 0 {
|
||||
// 使用序号作为 score,保持数据库返回的排序语义。
|
||||
members := make([]redis.Z, 0, len(accounts))
|
||||
for idx, account := range accounts {
|
||||
members = append(members, redis.Z{
|
||||
Score: float64(idx),
|
||||
Member: strconv.FormatInt(account.ID, 10),
|
||||
})
|
||||
}
|
||||
pipe.ZAdd(ctx, snapshotKey, members...)
|
||||
} else {
|
||||
pipe.Del(ctx, snapshotKey)
|
||||
}
|
||||
pipe.Set(ctx, activeKey, versionStr, 0)
|
||||
pipe.Set(ctx, schedulerBucketKey(schedulerReadyPrefix, bucket), "1", 0)
|
||||
pipe.SAdd(ctx, schedulerBucketSetKey, bucket.String())
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if oldActive != "" && oldActive != versionStr {
|
||||
_ = c.rdb.Del(ctx, schedulerSnapshotKey(bucket, oldActive)).Err()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err == redis.Nil {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decodeCachedAccount(val)
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
if account == nil || account.ID <= 0 {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := schedulerAccountKey(strconv.FormatInt(account.ID, 10))
|
||||
return c.rdb.Set(ctx, key, payload, 0).Err()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
if accountID <= 0 {
|
||||
return nil
|
||||
}
|
||||
key := schedulerAccountKey(strconv.FormatInt(accountID, 10))
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(updates))
|
||||
ids := make([]int64, 0, len(updates))
|
||||
for id := range updates {
|
||||
keys = append(keys, schedulerAccountKey(strconv.FormatInt(id, 10)))
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
for i, val := range values {
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
account, err := decodeCachedAccount(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
account.LastUsedAt = ptrTime(updates[ids[i]])
|
||||
updated, err := json.Marshal(account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pipe.Set(ctx, keys[i], updated, 0)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *schedulerCache) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
key := schedulerBucketKey(schedulerLockPrefix, bucket)
|
||||
return c.rdb.SetNX(ctx, key, time.Now().UnixNano(), ttl).Result()
|
||||
}
|
||||
|
||||
func (c *schedulerCache) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
raw, err := c.rdb.SMembers(ctx, schedulerBucketSetKey).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]service.SchedulerBucket, 0, len(raw))
|
||||
for _, entry := range raw {
|
||||
bucket, ok := service.ParseSchedulerBucket(entry)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
out = append(out, bucket)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
val, err := c.rdb.Get(ctx, schedulerOutboxWatermarkKey).Result()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
id, err := strconv.ParseInt(val, 10, 64)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
func (c *schedulerCache) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return c.rdb.Set(ctx, schedulerOutboxWatermarkKey, strconv.FormatInt(id, 10), 0).Err()
|
||||
}
|
||||
|
||||
func schedulerBucketKey(prefix string, bucket service.SchedulerBucket) string {
|
||||
return fmt.Sprintf("%s%d:%s:%s", prefix, bucket.GroupID, bucket.Platform, bucket.Mode)
|
||||
}
|
||||
|
||||
func schedulerSnapshotKey(bucket service.SchedulerBucket, version string) string {
|
||||
return fmt.Sprintf("%s%d:%s:%s:v%s", schedulerSnapshotPrefix, bucket.GroupID, bucket.Platform, bucket.Mode, version)
|
||||
}
|
||||
|
||||
func schedulerAccountKey(id string) string {
|
||||
return schedulerAccountPrefix + id
|
||||
}
|
||||
|
||||
func ptrTime(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
|
||||
func decodeCachedAccount(val any) (*service.Account, error) {
|
||||
var payload []byte
|
||||
switch raw := val.(type) {
|
||||
case string:
|
||||
payload = []byte(raw)
|
||||
case []byte:
|
||||
payload = raw
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected account cache type: %T", val)
|
||||
}
|
||||
var account service.Account
|
||||
if err := json.Unmarshal(payload, &account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &account, nil
|
||||
}
|
||||
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
96
backend/internal/repository/scheduler_outbox_repo.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type schedulerOutboxRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewSchedulerOutboxRepository(db *sql.DB) service.SchedulerOutboxRepository {
|
||||
return &schedulerOutboxRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *schedulerOutboxRepository) ListAfter(ctx context.Context, afterID int64, limit int) ([]service.SchedulerOutboxEvent, error) {
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, `
|
||||
SELECT id, event_type, account_id, group_id, payload, created_at
|
||||
FROM scheduler_outbox
|
||||
WHERE id > $1
|
||||
ORDER BY id ASC
|
||||
LIMIT $2
|
||||
`, afterID, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
events := make([]service.SchedulerOutboxEvent, 0, limit)
|
||||
for rows.Next() {
|
||||
var (
|
||||
payloadRaw []byte
|
||||
accountID sql.NullInt64
|
||||
groupID sql.NullInt64
|
||||
event service.SchedulerOutboxEvent
|
||||
)
|
||||
if err := rows.Scan(&event.ID, &event.EventType, &accountID, &groupID, &payloadRaw, &event.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if accountID.Valid {
|
||||
v := accountID.Int64
|
||||
event.AccountID = &v
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
event.GroupID = &v
|
||||
}
|
||||
if len(payloadRaw) > 0 {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(payloadRaw, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
event.Payload = payload
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (r *schedulerOutboxRepository) MaxID(ctx context.Context) (int64, error) {
|
||||
var maxID int64
|
||||
if err := r.db.QueryRowContext(ctx, "SELECT COALESCE(MAX(id), 0) FROM scheduler_outbox").Scan(&maxID); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return maxID, nil
|
||||
}
|
||||
|
||||
func enqueueSchedulerOutbox(ctx context.Context, exec sqlExecutor, eventType string, accountID *int64, groupID *int64, payload any) error {
|
||||
if exec == nil {
|
||||
return nil
|
||||
}
|
||||
var payloadArg any
|
||||
if payload != nil {
|
||||
encoded, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
payloadArg = encoded
|
||||
}
|
||||
_, err := exec.ExecContext(ctx, `
|
||||
INSERT INTO scheduler_outbox (event_type, account_id, group_id, payload)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
`, eventType, accountID, groupID, payloadArg)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := testRedis(t)
|
||||
client := testEntClient(t)
|
||||
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB)
|
||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||
cache := NewSchedulerCache(rdb)
|
||||
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeStandard,
|
||||
Gateway: config.GatewayConfig{
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
OutboxPollIntervalSeconds: 1,
|
||||
FullRebuildIntervalSeconds: 0,
|
||||
DbFallbackEnabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 3,
|
||||
Priority: 1,
|
||||
Credentials: map[string]any{},
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.NoError(t, accountRepo.Create(ctx, account))
|
||||
require.NoError(t, cache.SetAccount(ctx, account))
|
||||
|
||||
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
|
||||
svc.Start()
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
|
||||
updated, err := accountRepo.GetByID(ctx, account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.LastUsedAt)
|
||||
expectedUnix := updated.LastUsedAt.Unix()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cached, err := cache.GetAccount(ctx, account.ID)
|
||||
if err != nil || cached == nil || cached.LastUsedAt == nil {
|
||||
return false
|
||||
}
|
||||
return cached.LastUsedAt.Unix() == expectedUnix
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
321
backend/internal/repository/session_limit_cache.go
Normal file
321
backend/internal/repository/session_limit_cache.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 会话限制缓存常量定义
|
||||
//
|
||||
// 设计说明:
|
||||
// 使用 Redis 有序集合(Sorted Set)跟踪每个账号的活跃会话:
|
||||
// - Key: session_limit:account:{accountID}
|
||||
// - Member: sessionUUID(从 metadata.user_id 中提取)
|
||||
// - Score: Unix 时间戳(会话最后活跃时间)
|
||||
//
|
||||
// 通过 ZREMRANGEBYSCORE 自动清理过期会话,无需手动管理 TTL
|
||||
const (
|
||||
// 会话限制键前缀
|
||||
// 格式: session_limit:account:{accountID}
|
||||
sessionLimitKeyPrefix = "session_limit:account:"
|
||||
|
||||
// 窗口费用缓存键前缀
|
||||
// 格式: window_cost:account:{accountID}
|
||||
windowCostKeyPrefix = "window_cost:account:"
|
||||
|
||||
// 窗口费用缓存 TTL(30秒)
|
||||
windowCostCacheTTL = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
// registerSessionScript 注册会话活动
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步
|
||||
// KEYS[1] = session_limit:account:{accountID}
|
||||
// ARGV[1] = maxSessions
|
||||
// ARGV[2] = idleTimeout(秒)
|
||||
// ARGV[3] = sessionUUID
|
||||
// 返回: 1 = 允许, 0 = 拒绝
|
||||
registerSessionScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local maxSessions = tonumber(ARGV[1])
|
||||
local idleTimeout = tonumber(ARGV[2])
|
||||
local sessionUUID = ARGV[3]
|
||||
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - idleTimeout
|
||||
|
||||
-- 清理过期会话
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查会话是否已存在(支持刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, sessionUUID)
|
||||
if exists ~= false then
|
||||
-- 会话已存在,刷新时间戳
|
||||
redis.call('ZADD', key, now, sessionUUID)
|
||||
redis.call('EXPIRE', key, idleTimeout + 60)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到会话数量上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxSessions then
|
||||
-- 未达上限,添加新会话
|
||||
redis.call('ZADD', key, now, sessionUUID)
|
||||
redis.call('EXPIRE', key, idleTimeout + 60)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 达到上限,拒绝新会话
|
||||
return 0
|
||||
`)
|
||||
|
||||
// refreshSessionScript 刷新会话时间戳
|
||||
// KEYS[1] = session_limit:account:{accountID}
|
||||
// ARGV[1] = idleTimeout(秒)
|
||||
// ARGV[2] = sessionUUID
|
||||
refreshSessionScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local idleTimeout = tonumber(ARGV[1])
|
||||
local sessionUUID = ARGV[2]
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
|
||||
-- 检查会话是否存在
|
||||
local exists = redis.call('ZSCORE', key, sessionUUID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, sessionUUID)
|
||||
redis.call('EXPIRE', key, idleTimeout + 60)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// getActiveSessionCountScript 获取活跃会话数
|
||||
// KEYS[1] = session_limit:account:{accountID}
|
||||
// ARGV[1] = idleTimeout(秒)
|
||||
getActiveSessionCountScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local idleTimeout = tonumber(ARGV[1])
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - idleTimeout
|
||||
|
||||
-- 清理过期会话
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// isSessionActiveScript 检查会话是否活跃
|
||||
// KEYS[1] = session_limit:account:{accountID}
|
||||
// ARGV[1] = idleTimeout(秒)
|
||||
// ARGV[2] = sessionUUID
|
||||
isSessionActiveScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local idleTimeout = tonumber(ARGV[1])
|
||||
local sessionUUID = ARGV[2]
|
||||
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - idleTimeout
|
||||
|
||||
-- 获取会话的时间戳
|
||||
local score = redis.call('ZSCORE', key, sessionUUID)
|
||||
if score == false then
|
||||
return 0
|
||||
end
|
||||
|
||||
-- 检查是否过期
|
||||
if tonumber(score) <= expireBefore then
|
||||
return 0
|
||||
end
|
||||
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type sessionLimitCache struct {
|
||||
rdb *redis.Client
|
||||
defaultIdleTimeout time.Duration // 默认空闲超时(用于 GetActiveSessionCount)
|
||||
}
|
||||
|
||||
// NewSessionLimitCache 创建会话限制缓存
|
||||
// defaultIdleTimeoutMinutes: 默认空闲超时时间(分钟),用于无参数查询
|
||||
func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
|
||||
if defaultIdleTimeoutMinutes <= 0 {
|
||||
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
|
||||
}
|
||||
return &sessionLimitCache{
|
||||
rdb: rdb,
|
||||
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// sessionLimitKey 生成会话限制的 Redis 键
|
||||
func sessionLimitKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// windowCostKey 生成窗口费用缓存的 Redis 键
|
||||
func windowCostKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// RegisterSession 注册会话活动
|
||||
func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
|
||||
if sessionUUID == "" || maxSessions <= 0 {
|
||||
return true, nil // 无效参数,默认允许
|
||||
}
|
||||
|
||||
key := sessionLimitKey(accountID)
|
||||
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||
if idleTimeoutSeconds <= 0 {
|
||||
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
|
||||
}
|
||||
|
||||
result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
|
||||
if err != nil {
|
||||
return true, err // 失败开放:缓存错误时允许请求通过
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// RefreshSession 刷新会话时间戳
|
||||
func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
|
||||
if sessionUUID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := sessionLimitKey(accountID)
|
||||
idleTimeoutSeconds := int(idleTimeout.Seconds())
|
||||
if idleTimeoutSeconds <= 0 {
|
||||
idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
|
||||
}
|
||||
|
||||
_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// GetActiveSessionCount 获取活跃会话数
|
||||
func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := sessionLimitKey(accountID)
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]int), nil
|
||||
}
|
||||
|
||||
results := make(map[int64]int, len(accountIDs))
|
||||
|
||||
// 使用 pipeline 批量执行
|
||||
pipe := c.rdb.Pipeline()
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
cmds := make(map[int64]*redis.Cmd, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
key := sessionLimitKey(accountID)
|
||||
cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
|
||||
}
|
||||
|
||||
// 执行 pipeline,即使部分失败也尝试获取成功的结果
|
||||
_, _ = pipe.Exec(ctx)
|
||||
|
||||
for accountID, cmd := range cmds {
|
||||
if result, err := cmd.Int(); err == nil {
|
||||
results[accountID] = result
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// IsSessionActive 检查会话是否活跃
|
||||
func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
|
||||
if sessionUUID == "" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := sessionLimitKey(accountID)
|
||||
idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
|
||||
|
||||
result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
// ========== 5h窗口费用缓存实现 ==========
|
||||
|
||||
// GetWindowCost 获取缓存的窗口费用
|
||||
func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
|
||||
key := windowCostKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Float64()
|
||||
if err == redis.Nil {
|
||||
return 0, false, nil // 缓存未命中
|
||||
}
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
return val, true, nil
|
||||
}
|
||||
|
||||
// SetWindowCost 设置窗口费用缓存
|
||||
func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
|
||||
key := windowCostKey(accountID)
|
||||
return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
|
||||
}
|
||||
|
||||
// GetWindowCostBatch 批量获取窗口费用缓存
|
||||
func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return make(map[int64]float64), nil
|
||||
}
|
||||
|
||||
// 构建批量查询的 keys
|
||||
keys := make([]string, len(accountIDs))
|
||||
for i, accountID := range accountIDs {
|
||||
keys[i] = windowCostKey(accountID)
|
||||
}
|
||||
|
||||
// 使用 MGET 批量获取
|
||||
vals, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
results := make(map[int64]float64, len(accountIDs))
|
||||
for i, val := range vals {
|
||||
if val == nil {
|
||||
continue // 缓存未命中
|
||||
}
|
||||
// 尝试解析为 float64
|
||||
switch v := val.(type) {
|
||||
case string:
|
||||
if cost, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
results[accountIDs[i]] = cost
|
||||
}
|
||||
case float64:
|
||||
results[accountIDs[i]] = v
|
||||
}
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
80
backend/internal/repository/timeout_counter_cache.go
Normal file
80
backend/internal/repository/timeout_counter_cache.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const timeoutCounterPrefix = "timeout_count:account:"
|
||||
|
||||
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
|
||||
// 如果 key 不存在,则创建并设置过期时间
|
||||
var timeoutCounterIncrScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local count = redis.call('INCR', key)
|
||||
if count == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return count
|
||||
`)
|
||||
|
||||
type timeoutCounterCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewTimeoutCounterCache 创建超时计数器缓存实例
|
||||
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
|
||||
return &timeoutCounterCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
|
||||
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
|
||||
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
|
||||
ttlSeconds := windowMinutes * 60
|
||||
if ttlSeconds < 60 {
|
||||
ttlSeconds = 60 // 最小1分钟
|
||||
}
|
||||
|
||||
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("increment timeout count: %w", err)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTimeoutCount 获取账户当前的超时计数
|
||||
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
|
||||
val, err := c.rdb.Get(ctx, key).Int64()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get timeout count: %w", err)
|
||||
}
|
||||
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// ResetTimeoutCount 重置账户的超时计数
|
||||
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// GetTimeoutCountTTL 获取计数器剩余过期时间
|
||||
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
|
||||
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
|
||||
return c.rdb.TTL(ctx, key).Result()
|
||||
}
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, image_count, image_size, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -105,11 +105,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
total_cost,
|
||||
actual_cost,
|
||||
rate_multiplier,
|
||||
account_rate_multiplier,
|
||||
billing_type,
|
||||
stream,
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
@@ -119,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
@@ -130,6 +132,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration := nullInt(log.DurationMs)
|
||||
firstToken := nullInt(log.FirstTokenMs)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
|
||||
var requestIDArg any
|
||||
@@ -158,11 +161,13 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
log.TotalCost,
|
||||
log.ActualCost,
|
||||
rateMultiplier,
|
||||
log.AccountRateMultiplier,
|
||||
log.BillingType,
|
||||
log.Stream,
|
||||
duration,
|
||||
firstToken,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
createdAt,
|
||||
@@ -266,16 +271,60 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta
|
||||
type DashboardStats = usagestats.DashboardStats
|
||||
|
||||
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
var stats DashboardStats
|
||||
today := timezone.Today()
|
||||
now := time.Now()
|
||||
stats := &DashboardStats{}
|
||||
now := timezone.Now()
|
||||
todayStart := timezone.Today()
|
||||
|
||||
// 合并用户统计查询
|
||||
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) {
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return nil, errors.New("统计时间范围无效")
|
||||
}
|
||||
|
||||
stats := &DashboardStats{}
|
||||
now := timezone.Now()
|
||||
todayStart := timezone.Today()
|
||||
|
||||
if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||
userStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_users,
|
||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users,
|
||||
(SELECT COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users
|
||||
COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users
|
||||
FROM users
|
||||
WHERE deleted_at IS NULL
|
||||
`
|
||||
@@ -283,15 +332,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
ctx,
|
||||
r.sql,
|
||||
userStatsQuery,
|
||||
[]any{today, today},
|
||||
[]any{todayUTC},
|
||||
&stats.TotalUsers,
|
||||
&stats.TodayNewUsers,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 合并API Key统计查询
|
||||
apiKeyStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_api_keys,
|
||||
@@ -307,10 +354,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TotalAPIKeys,
|
||||
&stats.ActiveAPIKeys,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 合并账户统计查询
|
||||
accountStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_accounts,
|
||||
@@ -332,22 +378,26 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.RateLimitAccounts,
|
||||
&stats.OverloadAccounts,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// 累计 Token 统计
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error {
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(total_requests), 0) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
COALESCE(SUM(total_duration_ms), 0) as total_duration_ms
|
||||
FROM usage_dashboard_daily
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
@@ -360,13 +410,100 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
// 今日 Token 统计
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
total_requests as today_requests,
|
||||
input_tokens as today_input_tokens,
|
||||
output_tokens as today_output_tokens,
|
||||
cache_creation_tokens as today_cache_creation_tokens,
|
||||
cache_read_tokens as today_cache_read_tokens,
|
||||
total_cost as today_cost,
|
||||
actual_cost as today_actual_cost,
|
||||
active_users as active_users
|
||||
FROM usage_dashboard_daily
|
||||
WHERE bucket_date = $1::date
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{todayUTC},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
&stats.ActiveUsers,
|
||||
); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
hourlyActiveQuery := `
|
||||
SELECT active_users
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start = $1
|
||||
`
|
||||
hourStart := now.In(timezone.Location()).Truncate(time.Hour)
|
||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error {
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
var totalDurationMs int64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{startUTC, endUTC},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&totalDurationMs,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
if stats.TotalRequests > 0 {
|
||||
stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests)
|
||||
}
|
||||
|
||||
todayEnd := todayUTC.Add(24 * time.Hour)
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
@@ -377,13 +514,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{today},
|
||||
[]any{todayUTC, todayEnd},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
@@ -392,19 +529,31 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近1分钟,全局)
|
||||
rpm, tpm, err := r.getPerformanceStats(ctx, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
activeUsersQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return &stats, nil
|
||||
hourStart := now.UTC().Truncate(time.Hour)
|
||||
hourEnd := hourStart.Add(time.Hour)
|
||||
hourlyActiveQuery := `
|
||||
SELECT COUNT(DISTINCT user_id) as active_users
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) {
|
||||
@@ -688,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(actual_cost), 0) as cost
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2
|
||||
`
|
||||
@@ -702,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
&stats.StandardCost,
|
||||
&stats.UserCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -714,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
SELECT
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(actual_cost), 0) as cost
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
|
||||
COALESCE(SUM(total_cost), 0) as standard_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2
|
||||
`
|
||||
@@ -728,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
|
||||
&stats.Requests,
|
||||
&stats.Tokens,
|
||||
&stats.Cost,
|
||||
&stats.StandardCost,
|
||||
&stats.UserCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1253,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
|
||||
// GetUsageTrendWithFilters returns usage trend data with optional filters
|
||||
func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
if granularity == "hour" {
|
||||
dateFormat = "YYYY-MM-DD HH24:00"
|
||||
@@ -1283,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
|
||||
args = append(args, apiKeyID)
|
||||
}
|
||||
if accountID > 0 {
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if model != "" {
|
||||
query += fmt.Sprintf(" AND model = $%d", len(args)+1)
|
||||
args = append(args, model)
|
||||
}
|
||||
if stream != nil {
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
query += " GROUP BY date ORDER BY date ASC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
@@ -1305,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
|
||||
query := `
|
||||
// GetModelStatsWithFilters returns model statistics with optional filters
|
||||
func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
|
||||
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
|
||||
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
|
||||
if accountID > 0 && userID == 0 && apiKeyID == 0 {
|
||||
actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
model,
|
||||
COUNT(*) as requests,
|
||||
@@ -1315,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
COALESCE(SUM(output_tokens), 0) as output_tokens,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
%s
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
`
|
||||
`, actualCostExpr)
|
||||
|
||||
args := []any{startTime, endTime}
|
||||
if userID > 0 {
|
||||
@@ -1333,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
|
||||
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
|
||||
args = append(args, accountID)
|
||||
}
|
||||
if groupID > 0 {
|
||||
query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
|
||||
args = append(args, groupID)
|
||||
}
|
||||
if stream != nil {
|
||||
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
|
||||
args = append(args, *stream)
|
||||
}
|
||||
query += " GROUP BY model ORDER BY total_tokens DESC"
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, query, args...)
|
||||
@@ -1440,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
%s
|
||||
`, buildWhere(conditions))
|
||||
|
||||
stats := &UsageStats{}
|
||||
var totalAccountCost float64
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
@@ -1457,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
|
||||
&stats.TotalCacheTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&totalAccountCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if filters.AccountID > 0 {
|
||||
stats.TotalAccountCost = &totalAccountCost
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
|
||||
return stats, nil
|
||||
}
|
||||
@@ -1487,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
COUNT(*) as requests,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
|
||||
COALESCE(SUM(total_cost), 0) as cost,
|
||||
COALESCE(SUM(actual_cost), 0) as actual_cost
|
||||
COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as user_cost
|
||||
FROM usage_logs
|
||||
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
|
||||
GROUP BY date
|
||||
@@ -1514,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
var tokens int64
|
||||
var cost float64
|
||||
var actualCost float64
|
||||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
|
||||
var userCost float64
|
||||
if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t, _ := time.Parse("2006-01-02", date)
|
||||
@@ -1525,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
Tokens: tokens,
|
||||
Cost: cost,
|
||||
ActualCost: actualCost,
|
||||
UserCost: userCost,
|
||||
})
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var totalActualCost, totalStandardCost float64
|
||||
var totalAccountCost, totalUserCost, totalStandardCost float64
|
||||
var totalRequests, totalTokens int64
|
||||
var highestCostDay, highestRequestDay *AccountUsageHistory
|
||||
|
||||
for i := range history {
|
||||
h := &history[i]
|
||||
totalActualCost += h.ActualCost
|
||||
totalAccountCost += h.ActualCost
|
||||
totalUserCost += h.UserCost
|
||||
totalStandardCost += h.Cost
|
||||
totalRequests += h.Requests
|
||||
totalTokens += h.Tokens
|
||||
@@ -1564,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
summary := AccountUsageSummary{
|
||||
Days: daysCount,
|
||||
ActualDaysUsed: actualDaysUsed,
|
||||
TotalCost: totalActualCost,
|
||||
TotalCost: totalAccountCost,
|
||||
TotalUserCost: totalUserCost,
|
||||
TotalStandardCost: totalStandardCost,
|
||||
TotalRequests: totalRequests,
|
||||
TotalTokens: totalTokens,
|
||||
AvgDailyCost: totalActualCost / float64(actualDaysUsed),
|
||||
AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
|
||||
AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
|
||||
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
|
||||
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
|
||||
AvgDurationMs: avgDuration,
|
||||
@@ -1580,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
summary.Today = &struct {
|
||||
Date string `json:"date"`
|
||||
Cost float64 `json:"cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
Tokens int64 `json:"tokens"`
|
||||
}{
|
||||
Date: history[i].Date,
|
||||
Cost: history[i].ActualCost,
|
||||
UserCost: history[i].UserCost,
|
||||
Requests: history[i].Requests,
|
||||
Tokens: history[i].Tokens,
|
||||
}
|
||||
@@ -1597,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
Date string `json:"date"`
|
||||
Label string `json:"label"`
|
||||
Cost float64 `json:"cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
Requests int64 `json:"requests"`
|
||||
}{
|
||||
Date: highestCostDay.Date,
|
||||
Label: highestCostDay.Label,
|
||||
Cost: highestCostDay.ActualCost,
|
||||
UserCost: highestCostDay.UserCost,
|
||||
Requests: highestCostDay.Requests,
|
||||
}
|
||||
}
|
||||
@@ -1612,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
|
||||
Label string `json:"label"`
|
||||
Requests int64 `json:"requests"`
|
||||
Cost float64 `json:"cost"`
|
||||
UserCost float64 `json:"user_cost"`
|
||||
}{
|
||||
Date: highestRequestDay.Date,
|
||||
Label: highestRequestDay.Label,
|
||||
Requests: highestRequestDay.Requests,
|
||||
Cost: highestRequestDay.ActualCost,
|
||||
UserCost: highestRequestDay.UserCost,
|
||||
}
|
||||
}
|
||||
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
|
||||
models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
|
||||
if err != nil {
|
||||
models = []ModelStat{}
|
||||
}
|
||||
@@ -1847,35 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
|
||||
|
||||
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
|
||||
var (
|
||||
id int64
|
||||
userID int64
|
||||
apiKeyID int64
|
||||
accountID int64
|
||||
requestID sql.NullString
|
||||
model string
|
||||
groupID sql.NullInt64
|
||||
subscriptionID sql.NullInt64
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheCreationTokens int
|
||||
cacheReadTokens int
|
||||
cacheCreation5m int
|
||||
cacheCreation1h int
|
||||
inputCost float64
|
||||
outputCost float64
|
||||
cacheCreationCost float64
|
||||
cacheReadCost float64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
rateMultiplier float64
|
||||
billingType int16
|
||||
stream bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
id int64
|
||||
userID int64
|
||||
apiKeyID int64
|
||||
accountID int64
|
||||
requestID sql.NullString
|
||||
model string
|
||||
groupID sql.NullInt64
|
||||
subscriptionID sql.NullInt64
|
||||
inputTokens int
|
||||
outputTokens int
|
||||
cacheCreationTokens int
|
||||
cacheReadTokens int
|
||||
cacheCreation5m int
|
||||
cacheCreation1h int
|
||||
inputCost float64
|
||||
outputCost float64
|
||||
cacheCreationCost float64
|
||||
cacheReadCost float64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
rateMultiplier float64
|
||||
accountRateMultiplier sql.NullFloat64
|
||||
billingType int16
|
||||
stream bool
|
||||
durationMs sql.NullInt64
|
||||
firstTokenMs sql.NullInt64
|
||||
userAgent sql.NullString
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
if err := scanner.Scan(
|
||||
@@ -1900,11 +2107,13 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&totalCost,
|
||||
&actualCost,
|
||||
&rateMultiplier,
|
||||
&accountRateMultiplier,
|
||||
&billingType,
|
||||
&stream,
|
||||
&durationMs,
|
||||
&firstTokenMs,
|
||||
&userAgent,
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&createdAt,
|
||||
@@ -1931,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
RateMultiplier: rateMultiplier,
|
||||
AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
|
||||
BillingType: int8(billingType),
|
||||
Stream: stream,
|
||||
ImageCount: imageCount,
|
||||
@@ -1959,6 +2169,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if userAgent.Valid {
|
||||
log.UserAgent = &userAgent.String
|
||||
}
|
||||
if ipAddress.Valid {
|
||||
log.IPAddress = &ipAddress.String
|
||||
}
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
@@ -2034,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
|
||||
return sql.NullInt64{Int64: int64(*v), Valid: true}
|
||||
}
|
||||
|
||||
func nullFloat64Ptr(v sql.NullFloat64) *float64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
out := v.Float64
|
||||
return &out
|
||||
}
|
||||
|
||||
func nullString(v *string) sql.NullString {
|
||||
if v == nil || *v == "" {
|
||||
return sql.NullString{}
|
||||
|
||||
@@ -37,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UsageLogRepoSuite))
|
||||
}
|
||||
|
||||
// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
|
||||
func truncateToDayUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
@@ -96,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
|
||||
user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
|
||||
|
||||
m := 0.5
|
||||
log := &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 2.0,
|
||||
AccountRateMultiplier: &m,
|
||||
CreatedAt: timezone.Today().Add(2 * time.Hour),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, log)
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, log.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.AccountRateMultiplier)
|
||||
s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDelete() {
|
||||
@@ -198,14 +232,14 @@ func (s *UsageLogRepoSuite) TestListWithFilters() {
|
||||
// --- GetDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
now := time.Now()
|
||||
todayStart := timezone.Today()
|
||||
now := time.Now().UTC()
|
||||
todayStart := truncateToDayUTC(now)
|
||||
baseStats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats base")
|
||||
|
||||
userToday := mustCreateUser(s.T(), s.client, &service.User{
|
||||
Email: "today@example.com",
|
||||
CreatedAt: maxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||
CreatedAt: testMaxTime(todayStart.Add(10*time.Second), now.Add(-10*time.Second)),
|
||||
UpdatedAt: now,
|
||||
})
|
||||
userOld := mustCreateUser(s.T(), s.client, &service.User{
|
||||
@@ -238,7 +272,7 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
TotalCost: 1.5,
|
||||
ActualCost: 1.2,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: maxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||
CreatedAt: testMaxTime(todayStart.Add(2*time.Minute), now.Add(-2*time.Minute)),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, logToday)
|
||||
s.Require().NoError(err, "Create logToday")
|
||||
@@ -273,6 +307,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
_, err = s.repo.Create(s.ctx, logPerf)
|
||||
s.Require().NoError(err, "Create logPerf")
|
||||
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||
aggStart := todayStart.Add(-2 * time.Hour)
|
||||
aggEnd := now.Add(2 * time.Minute)
|
||||
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange")
|
||||
|
||||
stats, err := s.repo.GetDashboardStats(s.ctx)
|
||||
s.Require().NoError(err, "GetDashboardStats")
|
||||
|
||||
@@ -303,6 +342,80 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() {
|
||||
s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch")
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() {
|
||||
now := time.Now().UTC()
|
||||
todayStart := truncateToDayUTC(now)
|
||||
rangeStart := todayStart.Add(-24 * time.Hour)
|
||||
rangeEnd := now.Add(1 * time.Second)
|
||||
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"})
|
||||
|
||||
d1, d2, d3 := 100, 200, 300
|
||||
logOutside := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 7,
|
||||
OutputTokens: 8,
|
||||
TotalCost: 0.8,
|
||||
ActualCost: 0.7,
|
||||
DurationMs: &d3,
|
||||
CreatedAt: rangeStart.Add(-1 * time.Hour),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, logOutside)
|
||||
s.Require().NoError(err)
|
||||
|
||||
logRange := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 1,
|
||||
CacheReadTokens: 2,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 0.9,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: rangeStart.Add(2 * time.Hour),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, logRange)
|
||||
s.Require().NoError(err)
|
||||
|
||||
logToday := &service.UsageLog{
|
||||
UserID: user2.ID,
|
||||
APIKeyID: apiKey2.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 6,
|
||||
CacheReadTokens: 1,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
DurationMs: &d2,
|
||||
CreatedAt: now,
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, logToday)
|
||||
s.Require().NoError(err)
|
||||
|
||||
stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), stats.TotalRequests)
|
||||
s.Require().Equal(int64(15), stats.TotalInputTokens)
|
||||
s.Require().Equal(int64(26), stats.TotalOutputTokens)
|
||||
s.Require().Equal(int64(1), stats.TotalCacheCreationTokens)
|
||||
s.Require().Equal(int64(3), stats.TotalCacheReadTokens)
|
||||
s.Require().Equal(int64(45), stats.TotalTokens)
|
||||
s.Require().Equal(1.5, stats.TotalCost)
|
||||
s.Require().Equal(1.4, stats.TotalActualCost)
|
||||
s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001)
|
||||
}
|
||||
|
||||
// --- GetUserDashboardStats ---
|
||||
|
||||
func (s *UsageLogRepoSuite) TestGetUserDashboardStats() {
|
||||
@@ -325,12 +438,202 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
|
||||
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
|
||||
|
||||
s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
|
||||
createdAt := timezone.Today().Add(1 * time.Hour)
|
||||
|
||||
m1 := 1.5
|
||||
m2 := 0.0
|
||||
_, err := s.repo.Create(s.ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 2.0,
|
||||
AccountRateMultiplier: &m1,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
_, err = s.repo.Create(s.ctx, &service.UsageLog{
|
||||
UserID: user.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: uuid.New().String(),
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 1.0,
|
||||
AccountRateMultiplier: &m2,
|
||||
CreatedAt: createdAt,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetAccountTodayStats")
|
||||
s.Require().Equal(int64(1), stats.Requests)
|
||||
s.Require().Equal(int64(30), stats.Tokens)
|
||||
s.Require().Equal(int64(2), stats.Requests)
|
||||
s.Require().Equal(int64(40), stats.Tokens)
|
||||
// account cost = SUM(total_cost * account_rate_multiplier)
|
||||
s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
|
||||
// standard cost = SUM(total_cost)
|
||||
s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
|
||||
// user cost = SUM(actual_cost)
|
||||
s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
|
||||
}
|
||||
|
||||
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
|
||||
now := time.Now().UTC().Truncate(time.Second)
|
||||
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
|
||||
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
|
||||
dayStart := truncateToDayUTC(now)
|
||||
hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
|
||||
hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
|
||||
// 如果当前时间早于 hour2,则使用昨天的时间
|
||||
if now.Before(hour2.Add(time.Hour)) {
|
||||
dayStart = dayStart.Add(-24 * time.Hour)
|
||||
hour1 = dayStart.Add(2 * time.Hour)
|
||||
hour2 = dayStart.Add(3 * time.Hour)
|
||||
}
|
||||
|
||||
user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"})
|
||||
user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"})
|
||||
apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"})
|
||||
apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"})
|
||||
|
||||
d1, d2, d3 := 100, 200, 150
|
||||
log1 := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 10,
|
||||
OutputTokens: 20,
|
||||
CacheCreationTokens: 2,
|
||||
CacheReadTokens: 1,
|
||||
TotalCost: 1.0,
|
||||
ActualCost: 0.9,
|
||||
DurationMs: &d1,
|
||||
CreatedAt: hour1.Add(5 * time.Minute),
|
||||
}
|
||||
_, err := s.repo.Create(s.ctx, log1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log2 := &service.UsageLog{
|
||||
UserID: user1.ID,
|
||||
APIKeyID: apiKey1.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 5,
|
||||
OutputTokens: 5,
|
||||
TotalCost: 0.5,
|
||||
ActualCost: 0.5,
|
||||
DurationMs: &d2,
|
||||
CreatedAt: hour1.Add(20 * time.Minute),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, log2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
log3 := &service.UsageLog{
|
||||
UserID: user2.ID,
|
||||
APIKeyID: apiKey2.ID,
|
||||
AccountID: account.ID,
|
||||
Model: "claude-3",
|
||||
InputTokens: 7,
|
||||
OutputTokens: 8,
|
||||
TotalCost: 0.7,
|
||||
ActualCost: 0.7,
|
||||
DurationMs: &d3,
|
||||
CreatedAt: hour2.Add(10 * time.Minute),
|
||||
}
|
||||
_, err = s.repo.Create(s.ctx, log3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx)
|
||||
aggStart := hour1.Add(-5 * time.Minute)
|
||||
aggEnd := hour2.Add(time.Hour) // 确保覆盖 hour2 的所有数据
|
||||
s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd))
|
||||
|
||||
type hourlyRow struct {
|
||||
totalRequests int64
|
||||
inputTokens int64
|
||||
outputTokens int64
|
||||
cacheCreationTokens int64
|
||||
cacheReadTokens int64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
totalDurationMs int64
|
||||
activeUsers int64
|
||||
}
|
||||
fetchHourly := func(bucketStart time.Time) hourlyRow {
|
||||
var row hourlyRow
|
||||
err := scanSingleRow(s.ctx, s.tx, `
|
||||
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||
total_cost, actual_cost, total_duration_ms, active_users
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start = $1
|
||||
`, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens,
|
||||
&row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost,
|
||||
&row.totalDurationMs, &row.activeUsers,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
return row
|
||||
}
|
||||
|
||||
hour1Row := fetchHourly(hour1)
|
||||
s.Require().Equal(int64(2), hour1Row.totalRequests)
|
||||
s.Require().Equal(int64(15), hour1Row.inputTokens)
|
||||
s.Require().Equal(int64(25), hour1Row.outputTokens)
|
||||
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens)
|
||||
s.Require().Equal(int64(1), hour1Row.cacheReadTokens)
|
||||
s.Require().Equal(1.5, hour1Row.totalCost)
|
||||
s.Require().Equal(1.4, hour1Row.actualCost)
|
||||
s.Require().Equal(int64(300), hour1Row.totalDurationMs)
|
||||
s.Require().Equal(int64(1), hour1Row.activeUsers)
|
||||
|
||||
hour2Row := fetchHourly(hour2)
|
||||
s.Require().Equal(int64(1), hour2Row.totalRequests)
|
||||
s.Require().Equal(int64(7), hour2Row.inputTokens)
|
||||
s.Require().Equal(int64(8), hour2Row.outputTokens)
|
||||
s.Require().Equal(int64(0), hour2Row.cacheCreationTokens)
|
||||
s.Require().Equal(int64(0), hour2Row.cacheReadTokens)
|
||||
s.Require().Equal(0.7, hour2Row.totalCost)
|
||||
s.Require().Equal(0.7, hour2Row.actualCost)
|
||||
s.Require().Equal(int64(150), hour2Row.totalDurationMs)
|
||||
s.Require().Equal(int64(1), hour2Row.activeUsers)
|
||||
|
||||
var daily struct {
|
||||
totalRequests int64
|
||||
inputTokens int64
|
||||
outputTokens int64
|
||||
cacheCreationTokens int64
|
||||
cacheReadTokens int64
|
||||
totalCost float64
|
||||
actualCost float64
|
||||
totalDurationMs int64
|
||||
activeUsers int64
|
||||
}
|
||||
err = scanSingleRow(s.ctx, s.tx, `
|
||||
SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens,
|
||||
total_cost, actual_cost, total_duration_ms, active_users
|
||||
FROM usage_dashboard_daily
|
||||
WHERE bucket_date = $1::date
|
||||
`, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens,
|
||||
&daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost,
|
||||
&daily.totalDurationMs, &daily.activeUsers,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), daily.totalRequests)
|
||||
s.Require().Equal(int64(22), daily.inputTokens)
|
||||
s.Require().Equal(int64(33), daily.outputTokens)
|
||||
s.Require().Equal(int64(2), daily.cacheCreationTokens)
|
||||
s.Require().Equal(int64(1), daily.cacheReadTokens)
|
||||
s.Require().Equal(2.2, daily.totalCost)
|
||||
s.Require().Equal(2.1, daily.actualCost)
|
||||
s.Require().Equal(int64(450), daily.totalDurationMs)
|
||||
s.Require().Equal(int64(2), daily.activeUsers)
|
||||
}
|
||||
|
||||
// --- GetBatchUserUsageStats ---
|
||||
@@ -398,7 +701,7 @@ func (s *UsageLogRepoSuite) TestGetGlobalStats() {
|
||||
s.Require().Equal(int64(45), stats.TotalOutputTokens)
|
||||
}
|
||||
|
||||
func maxTime(a, b time.Time) time.Time {
|
||||
func testMaxTime(a, b time.Time) time.Time {
|
||||
if a.After(b) {
|
||||
return a
|
||||
}
|
||||
@@ -641,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
|
||||
endTime := base.Add(48 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
|
||||
s.Require().Len(trend, 2)
|
||||
|
||||
// Test with both filters
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
|
||||
trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -668,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
|
||||
startTime := base.Add(-1 * time.Hour)
|
||||
endTime := base.Add(3 * time.Hour)
|
||||
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
|
||||
trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
|
||||
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
|
||||
s.Require().Len(trend, 2)
|
||||
}
|
||||
@@ -714,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
|
||||
endTime := base.Add(2 * time.Hour)
|
||||
|
||||
// Test with user filter
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
|
||||
stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with apiKey filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
|
||||
s.Require().Len(stats, 2)
|
||||
|
||||
// Test with account filter
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
|
||||
stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
|
||||
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
|
||||
s.Require().Len(stats, 2)
|
||||
}
|
||||
|
||||
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
|
||||
return NewPricingRemoteClient(cfg.Update.ProxyURL)
|
||||
}
|
||||
|
||||
// ProvideSessionLimitCache 创建会话限制缓存
|
||||
// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
|
||||
func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
|
||||
defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
|
||||
if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
|
||||
defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
|
||||
}
|
||||
return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
|
||||
}
|
||||
|
||||
// ProviderSet is the Wire provider set for all repositories
|
||||
var ProviderSet = wire.NewSet(
|
||||
NewUserRepository,
|
||||
@@ -45,8 +55,11 @@ var ProviderSet = wire.NewSet(
|
||||
NewAccountRepository,
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewUsageLogRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
NewSettingRepository,
|
||||
NewOpsRepository,
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
@@ -56,12 +69,18 @@ var ProviderSet = wire.NewSet(
|
||||
NewBillingCache,
|
||||
NewAPIKeyCache,
|
||||
NewTempUnschedCache,
|
||||
NewTimeoutCounterCache,
|
||||
ProvideConcurrencyCache,
|
||||
ProvideSessionLimitCache,
|
||||
NewDashboardCache,
|
||||
NewEmailCache,
|
||||
NewIdentityCache,
|
||||
NewRedeemCache,
|
||||
NewUpdateCache,
|
||||
NewGeminiTokenCache,
|
||||
NewSchedulerCache,
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
|
||||
// HTTP service ports (DI Strategy A: return interface directly)
|
||||
NewTurnstileVerifier,
|
||||
|
||||
Reference in New Issue
Block a user