Merge upstream/main: v0.1.65-v0.1.75 updates
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID
|
||||
return &accounts[0], nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, extra->>'crs_account_id'
|
||||
FROM accounts
|
||||
WHERE deleted_at IS NULL
|
||||
AND extra->>'crs_account_id' IS NOT NULL
|
||||
AND extra->>'crs_account_id' != ''
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make(map[string]int64)
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
var crsID string
|
||||
if err := rows.Scan(&id, &crsID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[crsID] = id
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) Update(ctx context.Context, account *service.Account) error {
|
||||
if account == nil {
|
||||
return nil
|
||||
@@ -798,43 +826,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error {
|
||||
now := time.Now().UTC()
|
||||
payload := map[string]string{
|
||||
"rate_limited_at": now.Format(time.RFC3339),
|
||||
"rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
|
||||
}
|
||||
raw, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{antigravity_quota_scopes," + string(scope) + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAccountNotFound
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
if scope == "" {
|
||||
return nil
|
||||
@@ -849,12 +840,19 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco
|
||||
return err
|
||||
}
|
||||
|
||||
path := "{model_rate_limits," + scope + "}"
|
||||
client := clientFromContext(ctx, r.client)
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
|
||||
path,
|
||||
`UPDATE accounts SET
|
||||
extra = jsonb_set(
|
||||
jsonb_set(COALESCE(extra, '{}'::jsonb), '{model_rate_limits}'::text[], COALESCE(extra->'model_rate_limits', '{}'::jsonb), true),
|
||||
ARRAY['model_rate_limits', $1]::text[],
|
||||
$2::jsonb,
|
||||
true
|
||||
),
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL`,
|
||||
scope,
|
||||
raw,
|
||||
id,
|
||||
)
|
||||
@@ -1072,8 +1070,9 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id int64, updates m
|
||||
result, err := client.ExecContext(
|
||||
ctx,
|
||||
"UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) || $1::jsonb, updated_at = NOW() WHERE id = $2 AND deleted_at IS NULL",
|
||||
payload, id,
|
||||
string(payload), id,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
83
backend/internal/repository/announcement_read_repo.go
Normal file
83
backend/internal/repository/announcement_read_repo.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementReadRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
||||
return &announcementReadRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(announcementIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.UserIDEQ(userID),
|
||||
announcementread.AnnouncementIDIn(announcementIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.AnnouncementIDEQ(announcementID),
|
||||
announcementread.UserIDIn(userIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
||||
count, err := r.client.AnnouncementRead.Query().
|
||||
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
194
backend/internal/repository/announcement_repo.go
Normal file
194
backend/internal/repository/announcement_repo.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
|
||||
return &announcementRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.Create().
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyAnnouncementEntityToService(a, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
|
||||
m, err := r.client.Announcement.Query().
|
||||
Where(announcement.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
return announcementEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.UpdateOneID(a.ID).
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
} else {
|
||||
builder.ClearStartsAt()
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
} else {
|
||||
builder.ClearEndsAt()
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
} else {
|
||||
builder.ClearCreatedBy()
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
} else {
|
||||
builder.ClearUpdatedBy()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
|
||||
a.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementRepository) List(
|
||||
ctx context.Context,
|
||||
params pagination.PaginationParams,
|
||||
filters service.AnnouncementListFilters,
|
||||
) ([]service.Announcement, *pagination.PaginationResult, error) {
|
||||
q := r.client.Announcement.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(announcement.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
announcement.Or(
|
||||
announcement.TitleContainsFold(filters.Search),
|
||||
announcement.ContentContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
out := announcementEntitiesToService(items)
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
announcement.StatusEQ(service.AnnouncementStatusActive),
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return announcementEntitiesToService(items), nil
|
||||
}
|
||||
|
||||
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
|
||||
out := make([]service.Announcement, 0, len(models))
|
||||
for i := range models {
|
||||
if s := announcementEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -33,7 +33,10 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID)
|
||||
SetNillableGroupID(key.GroupID).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -110,6 +113,9 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
apikey.FieldQuota,
|
||||
apikey.FieldQuotaUsed,
|
||||
apikey.FieldExpiresAt,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
@@ -136,8 +142,11 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -161,6 +170,8 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
@@ -168,6 +179,13 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
// Expiration time
|
||||
if key.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*key.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
@@ -357,6 +375,38 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64)
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed atomically increments the quota_used field and returns the new value
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
// Use raw SQL for atomic increment to avoid race conditions
|
||||
// First get current value
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldQuotaUsed).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newValue := m.QuotaUsed + amount
|
||||
|
||||
// Update with new value
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetQuotaUsed(newValue).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if affected == 0 {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
return newValue, nil
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
@@ -372,6 +422,9 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
@@ -409,28 +462,32 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -194,6 +194,53 @@ var (
|
||||
return result
|
||||
`)
|
||||
|
||||
// getUsersLoadBatchScript - batch load query for users with expired slot cleanup
|
||||
// ARGV[1] = slot TTL (seconds)
|
||||
// ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ...
|
||||
getUsersLoadBatchScript = redis.NewScript(`
|
||||
local result = {}
|
||||
local slotTTL = tonumber(ARGV[1])
|
||||
|
||||
-- Get current server time
|
||||
local timeResult = redis.call('TIME')
|
||||
local nowSeconds = tonumber(timeResult[1])
|
||||
local cutoffTime = nowSeconds - slotTTL
|
||||
|
||||
local i = 2
|
||||
while i <= #ARGV do
|
||||
local userID = ARGV[i]
|
||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||
|
||||
local slotKey = 'concurrency:user:' .. userID
|
||||
|
||||
-- Clean up expired slots before counting
|
||||
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||
|
||||
local waitKey = 'concurrency:wait:' .. userID
|
||||
local waitingCount = redis.call('GET', waitKey)
|
||||
if waitingCount == false then
|
||||
waitingCount = 0
|
||||
else
|
||||
waitingCount = tonumber(waitingCount)
|
||||
end
|
||||
|
||||
local loadRate = 0
|
||||
if maxConcurrency > 0 then
|
||||
loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency)
|
||||
end
|
||||
|
||||
table.insert(result, userID)
|
||||
table.insert(result, currentConcurrency)
|
||||
table.insert(result, waitingCount)
|
||||
table.insert(result, loadRate)
|
||||
|
||||
i = i + 2
|
||||
end
|
||||
|
||||
return result
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript - remove expired slots
|
||||
// KEYS[1] = concurrency:account:{accountID}
|
||||
// ARGV[1] = TTL (seconds)
|
||||
@@ -384,6 +431,43 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
if len(users) == 0 {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
args := []any{c.slotTTLSeconds}
|
||||
for _, u := range users {
|
||||
args = append(args, u.ID, u.MaxConcurrency)
|
||||
}
|
||||
|
||||
result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.UserLoadInfo)
|
||||
for i := 0; i < len(result); i += 4 {
|
||||
if i+3 >= len(result) {
|
||||
break
|
||||
}
|
||||
|
||||
userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64)
|
||||
currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1]))
|
||||
waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2]))
|
||||
loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3]))
|
||||
|
||||
loadMap[userID] = &service.UserLoadInfo{
|
||||
UserID: userID,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
|
||||
128
backend/internal/repository/error_passthrough_cache.go
Normal file
128
backend/internal/repository/error_passthrough_cache.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPassthroughCacheKey = "error_passthrough_rules"
|
||||
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
|
||||
errorPassthroughCacheTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type errorPassthroughCache struct {
|
||||
rdb *redis.Client
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughCache 创建错误透传规则缓存
|
||||
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
|
||||
return &errorPassthroughCache{
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// Get 从缓存获取规则列表
|
||||
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
|
||||
// 先检查本地缓存
|
||||
c.localMu.RLock()
|
||||
if c.localCache != nil {
|
||||
rules := c.localCache
|
||||
c.localMu.RUnlock()
|
||||
return rules, true
|
||||
}
|
||||
c.localMu.RUnlock()
|
||||
|
||||
// 从 Redis 获取
|
||||
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
|
||||
if err != nil {
|
||||
if err != redis.Nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var rules []*model.ErrorPassthroughRule
|
||||
if err := json.Unmarshal(data, &rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return rules, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
|
||||
data, err := json.Marshal(rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate 使缓存失效
|
||||
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
|
||||
// 清除本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 清除 Redis 缓存
|
||||
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
|
||||
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
|
||||
}
|
||||
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
|
||||
go func() {
|
||||
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
|
||||
defer func() { _ = sub.Close() }()
|
||||
|
||||
ch := sub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 调用处理函数
|
||||
handler()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
178
backend/internal/repository/error_passthrough_repo.go
Normal file
178
backend/internal/repository/error_passthrough_repo.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type errorPassthroughRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
// NewErrorPassthroughRepository 创建错误透传规则仓库
|
||||
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
|
||||
return &errorPassthroughRepository{client: client}
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
rules, err := r.client.ErrorPassthroughRule.Query().
|
||||
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
for i, rule := range rules {
|
||||
result[i] = r.toModel(rule)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(rule), nil
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.Create().
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(created), nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody)
|
||||
|
||||
// 处理可选字段
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
} else {
|
||||
builder.ClearErrorCodes()
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
} else {
|
||||
builder.ClearKeywords()
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
} else {
|
||||
builder.ClearPlatforms()
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
} else {
|
||||
builder.ClearResponseCode()
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
} else {
|
||||
builder.ClearCustomMessage()
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
} else {
|
||||
builder.ClearDescription()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(updated), nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// toModel 将 Ent 实体转换为服务模型
|
||||
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: int64(e.ID),
|
||||
Name: e.Name,
|
||||
Enabled: e.Enabled,
|
||||
Priority: e.Priority,
|
||||
ErrorCodes: e.ErrorCodes,
|
||||
Keywords: e.Keywords,
|
||||
MatchMode: e.MatchMode,
|
||||
Platforms: e.Platforms,
|
||||
PassthroughCode: e.PassthroughCode,
|
||||
PassthroughBody: e.PassthroughBody,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
}
|
||||
|
||||
if e.ResponseCode != nil {
|
||||
rule.ResponseCode = e.ResponseCode
|
||||
}
|
||||
if e.CustomMessage != nil {
|
||||
rule.CustomMessage = e.CustomMessage
|
||||
}
|
||||
if e.Description != nil {
|
||||
rule.Description = e.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
||||
@@ -104,6 +104,7 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
@@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
@@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := geminicli.SanitizeBodyForLogs(resp.String())
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body)
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body)
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
|
||||
@@ -98,12 +98,16 @@ func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = out.Close() }()
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
|
||||
// Close file before attempting to remove (required on Windows)
|
||||
_ = out.Close()
|
||||
|
||||
if err != nil {
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -50,13 +50,18 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
}
|
||||
|
||||
// 设置支持的模型系列(始终设置,空数组表示不限制)
|
||||
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
groupIn.ID = created.ID
|
||||
@@ -87,7 +92,6 @@ func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.G
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
@@ -108,7 +112,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
@@ -116,6 +121,12 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
|
||||
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
|
||||
}
|
||||
|
||||
// 处理 ModelRouting:nil 时清除,否则设置
|
||||
if groupIn.ModelRouting != nil {
|
||||
@@ -124,6 +135,9 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
builder = builder.ClearModelRouting()
|
||||
}
|
||||
|
||||
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
|
||||
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
@@ -177,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
groups, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@@ -204,7 +218,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -231,7 +245,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
|
||||
Order(dbent.Asc(group.FieldID)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -425,3 +439,87 @@ func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int6
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var accountIDs []int64
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
if err := rows.Scan(&accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accountIDs = append(accountIDs, accountID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
}
|
||||
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
|
||||
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
if len(accountIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
|
||||
_, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
|
||||
SELECT unnest($1::bigint[]), $2, 50, NOW()
|
||||
ON CONFLICT (account_id, group_id) DO NOTHING`,
|
||||
pq.Array(accountIDs),
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSortOrders 批量更新分组排序
|
||||
func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用事务批量更新
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
for _, u := range updates {
|
||||
if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -43,6 +43,7 @@ INSERT INTO ops_system_metrics (
|
||||
upstream_529_count,
|
||||
|
||||
token_consumed,
|
||||
account_switch_count,
|
||||
qps,
|
||||
tps,
|
||||
|
||||
@@ -81,14 +82,14 @@ INSERT INTO ops_system_metrics (
|
||||
$1,$2,$3,$4,
|
||||
$5,$6,$7,$8,
|
||||
$9,$10,$11,
|
||||
$12,$13,$14,
|
||||
$15,$16,$17,$18,$19,$20,
|
||||
$21,$22,$23,$24,$25,$26,
|
||||
$27,$28,$29,$30,
|
||||
$31,$32,
|
||||
$33,$34,
|
||||
$35,$36,$37,
|
||||
$38,$39
|
||||
$12,$13,$14,$15,
|
||||
$16,$17,$18,$19,$20,$21,
|
||||
$22,$23,$24,$25,$26,$27,
|
||||
$28,$29,$30,$31,
|
||||
$32,$33,
|
||||
$34,$35,
|
||||
$36,$37,$38,
|
||||
$39,$40
|
||||
)`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
@@ -109,6 +110,7 @@ INSERT INTO ops_system_metrics (
|
||||
input.Upstream529Count,
|
||||
|
||||
input.TokenConsumed,
|
||||
input.AccountSwitchCount,
|
||||
opsNullFloat64(input.QPS),
|
||||
opsNullFloat64(input.TPS),
|
||||
|
||||
@@ -177,7 +179,8 @@ SELECT
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth
|
||||
concurrency_queue_depth,
|
||||
account_switch_count
|
||||
FROM ops_system_metrics
|
||||
WHERE window_minutes = $1
|
||||
AND platform IS NULL
|
||||
@@ -199,6 +202,7 @@ LIMIT 1`
|
||||
var dbWaiting sql.NullInt64
|
||||
var goroutines sql.NullInt64
|
||||
var queueDepth sql.NullInt64
|
||||
var accountSwitchCount sql.NullInt64
|
||||
|
||||
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||
&out.ID,
|
||||
@@ -217,6 +221,7 @@ LIMIT 1`
|
||||
&dbWaiting,
|
||||
&goroutines,
|
||||
&queueDepth,
|
||||
&accountSwitchCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -273,6 +278,10 @@ LIMIT 1`
|
||||
v := int(queueDepth.Int64)
|
||||
out.ConcurrencyQueueDepth = &v
|
||||
}
|
||||
if accountSwitchCount.Valid {
|
||||
v := accountSwitchCount.Int64
|
||||
out.AccountSwitchCount = &v
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
@@ -56,18 +56,44 @@ error_buckets AS (
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
),
|
||||
switch_buckets AS (
|
||||
SELECT ` + errorBucketExpr + ` AS bucket,
|
||||
COALESCE(SUM(CASE
|
||||
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
|
||||
ELSE 0
|
||||
END), 0) AS switch_count
|
||||
FROM ops_error_logs
|
||||
CROSS JOIN LATERAL jsonb_array_elements(
|
||||
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
|
||||
) AS ev
|
||||
` + errorWhere + `
|
||||
AND upstream_errors IS NOT NULL
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.bucket, e.bucket) AS bucket,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_buckets u
|
||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||
SELECT
|
||||
bucket,
|
||||
SUM(success_count) AS success_count,
|
||||
SUM(error_count) AS error_count,
|
||||
SUM(token_consumed) AS token_consumed,
|
||||
SUM(switch_count) AS switch_count
|
||||
FROM (
|
||||
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
|
||||
FROM usage_buckets
|
||||
UNION ALL
|
||||
SELECT bucket, 0, error_count, 0, 0
|
||||
FROM error_buckets
|
||||
UNION ALL
|
||||
SELECT bucket, 0, 0, 0, switch_count
|
||||
FROM switch_buckets
|
||||
) t
|
||||
GROUP BY bucket
|
||||
)
|
||||
SELECT
|
||||
bucket,
|
||||
(success_count + error_count) AS request_count,
|
||||
token_consumed
|
||||
token_consumed,
|
||||
switch_count
|
||||
FROM combined
|
||||
ORDER BY bucket ASC`
|
||||
|
||||
@@ -84,13 +110,18 @@ ORDER BY bucket ASC`
|
||||
var bucket time.Time
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens); err != nil {
|
||||
var switches sql.NullInt64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
switchCount := int64(0)
|
||||
if switches.Valid {
|
||||
switchCount = switches.Int64
|
||||
}
|
||||
|
||||
denom := float64(bucketSeconds)
|
||||
if denom <= 0 {
|
||||
@@ -103,6 +134,7 @@ ORDER BY bucket ASC`
|
||||
BucketStart: bucket.UTC(),
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
SwitchCount: switchCount,
|
||||
QPS: qps,
|
||||
TPS: tps,
|
||||
})
|
||||
@@ -385,6 +417,7 @@ func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []
|
||||
BucketStart: cursor,
|
||||
RequestCount: 0,
|
||||
TokenConsumed: 0,
|
||||
SwitchCount: 0,
|
||||
QPS: 0,
|
||||
TPS: 0,
|
||||
})
|
||||
|
||||
@@ -28,7 +28,6 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
ipInfoURL: defaultIPInfoURL,
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
@@ -36,12 +35,20 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
}
|
||||
|
||||
const (
|
||||
defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
|
||||
var probeURLs = []struct {
|
||||
url string
|
||||
parser string // "ip-api" or "httpbin"
|
||||
}{
|
||||
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
|
||||
{"http://httpbin.org/ip", "httpbin"},
|
||||
}
|
||||
|
||||
type proxyProbeService struct {
|
||||
ipInfoURL string
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
@@ -60,8 +67,21 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, probe := range probeURLs {
|
||||
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
|
||||
if err == nil {
|
||||
return exitInfo, latencyMs, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.ipInfoURL, nil)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
@@ -78,6 +98,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
return s.parseIPAPI(body, latencyMs)
|
||||
case "httpbin":
|
||||
return s.parseHTTPBin(body, latencyMs)
|
||||
default:
|
||||
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
var ipInfo struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
@@ -89,13 +125,12 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
CountryCode string `json:"countryCode"`
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
|
||||
preview := string(body)
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
|
||||
}
|
||||
if strings.ToLower(ipInfo.Status) != "success" {
|
||||
if ipInfo.Message == "" {
|
||||
@@ -116,3 +151,19 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
|
||||
CountryCode: ipInfo.CountryCode,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
|
||||
var result struct {
|
||||
Origin string `json:"origin"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
|
||||
}
|
||||
if result.Origin == "" {
|
||||
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
|
||||
}
|
||||
return &service.ProxyExitInfo{
|
||||
IP: result.Origin,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -21,7 +22,6 @@ type ProxyProbeServiceSuite struct {
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{
|
||||
ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
@@ -49,12 +49,16 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
seen := make(chan string, 1)
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
// 检查是否是 ip-api 请求
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
return
|
||||
}
|
||||
// 其他请求返回错误
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
@@ -65,45 +69,59 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
require.Equal(s.T(), "CC", info.CountryCode)
|
||||
|
||||
// Verify proxy received the request
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_NonOKStatus() {
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// ip-api 失败
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
// httpbin 成功
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "5.6.7.8", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status: 503")
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
// httpbin 也返回无效响应
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to parse response")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidIPInfoURL() {
|
||||
s.prober.ipInfoURL = "://invalid-url"
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error for invalid ipInfoURL")
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
@@ -114,6 +132,40 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
|
||||
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
|
||||
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(100), latencyMs)
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "Beijing", info.City)
|
||||
require.Equal(s.T(), "Beijing", info.Region)
|
||||
require.Equal(s.T(), "China", info.Country)
|
||||
require.Equal(s.T(), "CN", info.CountryCode)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
|
||||
body := []byte(`{"status":"fail","message":"rate limited"}`)
|
||||
_, _, err := s.prober.parseIPAPI(body, 100)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "rate limited")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
|
||||
body := []byte(`{"origin": "9.8.7.6"}`)
|
||||
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(50), latencyMs)
|
||||
require.Equal(s.T(), "9.8.7.6", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
|
||||
body := []byte(`{"origin": ""}`)
|
||||
_, _, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "no IP found")
|
||||
}
|
||||
|
||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||
}
|
||||
|
||||
@@ -60,6 +60,25 @@ func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy
|
||||
return proxyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.IDIn(ids...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
|
||||
SetName(proxyIn.Name).
|
||||
|
||||
@@ -202,6 +202,57 @@ func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, lim
|
||||
return redeemCodeEntitiesToService(codes), nil
|
||||
}
|
||||
|
||||
// ListByUserPaginated returns paginated balance/concurrency history for a user.
|
||||
// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
|
||||
func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID))
|
||||
|
||||
// Optional type filter
|
||||
if codeType != "" {
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance).
|
||||
func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
var result []struct {
|
||||
Sum float64 `json:"sum"`
|
||||
}
|
||||
err := r.client.RedeemCode.Query().
|
||||
Where(
|
||||
redeemcode.UsedByEQ(userID),
|
||||
redeemcode.ValueGT(0),
|
||||
redeemcode.TypeIn("balance", "admin_balance"),
|
||||
).
|
||||
Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
|
||||
Scan(ctx, &result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return result[0].Sum, nil
|
||||
}
|
||||
|
||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
@@ -26,7 +27,7 @@ func InitRedis(cfg *config.Config) *redis.Client {
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
return &redis.Options{
|
||||
opts := &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
@@ -36,4 +37,13 @@ func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
|
||||
if cfg.Redis.EnableTLS {
|
||||
opts.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: cfg.Redis.Host,
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
@@ -32,4 +32,16 @@ func TestBuildRedisOptions(t *testing.T) {
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
require.Nil(t, opts.TLSConfig)
|
||||
|
||||
// Test case with TLS enabled
|
||||
cfgTLS := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
EnableTLS: true,
|
||||
},
|
||||
}
|
||||
optsTLS := buildRedisOptions(cfgTLS)
|
||||
require.NotNil(t, optsTLS.TLSConfig)
|
||||
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
|
||||
}
|
||||
|
||||
158
backend/internal/repository/refresh_token_cache.go
Normal file
158
backend/internal/repository/refresh_token_cache.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
refreshTokenKeyPrefix = "refresh_token:"
|
||||
userRefreshTokensPrefix = "user_refresh_tokens:"
|
||||
tokenFamilyPrefix = "token_family:"
|
||||
)
|
||||
|
||||
// refreshTokenKey generates the Redis key for a refresh token.
|
||||
func refreshTokenKey(tokenHash string) string {
|
||||
return refreshTokenKeyPrefix + tokenHash
|
||||
}
|
||||
|
||||
// userRefreshTokensKey generates the Redis key for user's token set.
|
||||
func userRefreshTokensKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID)
|
||||
}
|
||||
|
||||
// tokenFamilyKey generates the Redis key for token family set.
|
||||
func tokenFamilyKey(familyID string) string {
|
||||
return tokenFamilyPrefix + familyID
|
||||
}
|
||||
|
||||
type refreshTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
// NewRefreshTokenCache creates a new RefreshTokenCache implementation.
|
||||
func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache {
|
||||
return &refreshTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal refresh token data: %w", err)
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return nil, service.ErrRefreshTokenNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var data service.RefreshTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal refresh token data: %w", err)
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error {
|
||||
key := refreshTokenKey(tokenHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error {
|
||||
// Get all token hashes for this user
|
||||
tokenHashes, err := c.GetUserTokenHashes(ctx, userID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get user token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, userRefreshTokensKey(userID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error {
|
||||
// Get all token hashes in this family
|
||||
tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID)
|
||||
if err != nil && err != redis.Nil {
|
||||
return fmt.Errorf("get family token hashes: %w", err)
|
||||
}
|
||||
|
||||
if len(tokenHashes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Build keys to delete
|
||||
keys := make([]string, 0, len(tokenHashes)+1)
|
||||
for _, hash := range tokenHashes {
|
||||
keys = append(keys, refreshTokenKey(hash))
|
||||
}
|
||||
keys = append(keys, tokenFamilyKey(familyID))
|
||||
|
||||
// Delete all keys in a pipeline
|
||||
pipe := c.rdb.Pipeline()
|
||||
for _, key := range keys {
|
||||
pipe.Del(ctx, key)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error {
|
||||
key := userRefreshTokensKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error {
|
||||
key := tokenFamilyKey(familyID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.SAdd(ctx, key, tokenHash)
|
||||
pipe.Expire(ctx, key, ttl)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) {
|
||||
key := userRefreshTokensKey(userID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SMembers(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) {
|
||||
key := tokenFamilyKey(familyID)
|
||||
return c.rdb.SIsMember(ctx, key, tokenHash).Result()
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -153,6 +154,21 @@ func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) serv
|
||||
if defaultIdleTimeoutMinutes <= 0 {
|
||||
defaultIdleTimeoutMinutes = 5 // 默认 5 分钟
|
||||
}
|
||||
|
||||
// 预加载 Lua 脚本到 Redis,避免 Pipeline 中出现 NOSCRIPT 错误
|
||||
ctx := context.Background()
|
||||
scripts := []*redis.Script{
|
||||
registerSessionScript,
|
||||
refreshSessionScript,
|
||||
getActiveSessionCountScript,
|
||||
isSessionActiveScript,
|
||||
}
|
||||
for _, script := range scripts {
|
||||
if err := script.Load(ctx, rdb).Err(); err != nil {
|
||||
log.Printf("[SessionLimitCache] Failed to preload Lua script: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &sessionLimitCache{
|
||||
rdb: rdb,
|
||||
defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
|
||||
|
||||
@@ -22,7 +22,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
|
||||
const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, created_at"
|
||||
|
||||
type usageLogRepository struct {
|
||||
client *dbent.Client
|
||||
@@ -111,21 +111,22 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
duration_ms,
|
||||
first_token_ms,
|
||||
user_agent,
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
ip_address,
|
||||
image_count,
|
||||
image_size,
|
||||
reasoning_effort,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5,
|
||||
$6, $7,
|
||||
$8, $9, $10, $11,
|
||||
$12, $13,
|
||||
$14, $15, $16, $17, $18, $19,
|
||||
$20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31
|
||||
)
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
RETURNING id, created_at
|
||||
`
|
||||
|
||||
groupID := nullInt64(log.GroupID)
|
||||
subscriptionID := nullInt64(log.SubscriptionID)
|
||||
@@ -134,6 +135,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
userAgent := nullString(log.UserAgent)
|
||||
ipAddress := nullString(log.IPAddress)
|
||||
imageSize := nullString(log.ImageSize)
|
||||
reasoningEffort := nullString(log.ReasoningEffort)
|
||||
|
||||
var requestIDArg any
|
||||
if requestID != "" {
|
||||
@@ -170,6 +172,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
|
||||
ipAddress,
|
||||
log.ImageCount,
|
||||
imageSize,
|
||||
reasoningEffort,
|
||||
createdAt,
|
||||
}
|
||||
if err := scanSingleRow(ctx, sqlq, query, args, &log.ID, &log.CreatedAt); err != nil {
|
||||
@@ -1122,6 +1125,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值)
|
||||
func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) {
|
||||
fiveMinutesAgo := time.Now().Add(-5 * time.Minute)
|
||||
query := `
|
||||
SELECT
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND api_key_id = $2`
|
||||
args := []any{fiveMinutesAgo, apiKeyID}
|
||||
|
||||
var requestCount int64
|
||||
var tokenCount int64
|
||||
if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return requestCount / 5, tokenCount / 5, nil
|
||||
}
|
||||
|
||||
// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤)
|
||||
func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) {
|
||||
stats := &UserDashboardStats{}
|
||||
today := timezone.Today()
|
||||
|
||||
// API Key 维度不需要统计 key 数量,设为 1
|
||||
stats.TotalAPIKeys = 1
|
||||
stats.ActiveAPIKeys = 1
|
||||
|
||||
// 累计 Token 统计
|
||||
totalStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as total_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as total_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
|
||||
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = $1
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
totalStatsQuery,
|
||||
[]any{apiKeyID},
|
||||
&stats.TotalRequests,
|
||||
&stats.TotalInputTokens,
|
||||
&stats.TotalOutputTokens,
|
||||
&stats.TotalCacheCreationTokens,
|
||||
&stats.TotalCacheReadTokens,
|
||||
&stats.TotalCost,
|
||||
&stats.TotalActualCost,
|
||||
&stats.AverageDurationMs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens
|
||||
|
||||
// 今日 Token 统计
|
||||
todayStatsQuery := `
|
||||
SELECT
|
||||
COUNT(*) as today_requests,
|
||||
COALESCE(SUM(input_tokens), 0) as today_input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) as today_output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) as today_cost,
|
||||
COALESCE(SUM(actual_cost), 0) as today_actual_cost
|
||||
FROM usage_logs
|
||||
WHERE api_key_id = $1 AND created_at >= $2
|
||||
`
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
todayStatsQuery,
|
||||
[]any{apiKeyID, today},
|
||||
&stats.TodayRequests,
|
||||
&stats.TodayInputTokens,
|
||||
&stats.TodayOutputTokens,
|
||||
&stats.TodayCacheCreationTokens,
|
||||
&stats.TodayCacheReadTokens,
|
||||
&stats.TodayCost,
|
||||
&stats.TodayActualCost,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens
|
||||
|
||||
// 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤)
|
||||
rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.Rpm = rpm
|
||||
stats.Tpm = tpm
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetUserUsageTrendByUserID 获取指定用户的使用趋势
|
||||
func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) {
|
||||
dateFormat := "YYYY-MM-DD"
|
||||
@@ -2090,6 +2194,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
ipAddress sql.NullString
|
||||
imageCount int
|
||||
imageSize sql.NullString
|
||||
reasoningEffort sql.NullString
|
||||
createdAt time.Time
|
||||
)
|
||||
|
||||
@@ -2124,6 +2229,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
&ipAddress,
|
||||
&imageCount,
|
||||
&imageSize,
|
||||
&reasoningEffort,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
@@ -2183,6 +2289,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
|
||||
if imageSize.Valid {
|
||||
log.ImageSize = &imageSize.String
|
||||
}
|
||||
if reasoningEffort.Valid {
|
||||
log.ReasoningEffort = &reasoningEffort.String
|
||||
}
|
||||
|
||||
return log, nil
|
||||
}
|
||||
|
||||
113
backend/internal/repository/user_group_rate_repo.go
Normal file
113
backend/internal/repository/user_group_rate_repo.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type userGroupRateRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
// NewUserGroupRateRepository 创建用户专属分组倍率仓储
|
||||
func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository {
|
||||
return &userGroupRateRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有专属分组倍率
|
||||
func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1`
|
||||
rows, err := r.sql.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
result := make(map[int64]float64)
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var rate float64
|
||||
if err := rows.Scan(&groupID, &rate); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[groupID] = rate
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByUserAndGroup 获取用户在特定分组的专属倍率
|
||||
func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) {
|
||||
query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`
|
||||
var rate float64
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &rate, nil
|
||||
}
|
||||
|
||||
// SyncUserGroupRates 同步用户的分组专属倍率
|
||||
func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error {
|
||||
if len(rates) == 0 {
|
||||
// 如果传入空 map,删除该用户的所有专属倍率
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// 分离需要删除和需要 upsert 的记录
|
||||
var toDelete []int64
|
||||
toUpsert := make(map[int64]float64)
|
||||
for groupID, rate := range rates {
|
||||
if rate == nil {
|
||||
toDelete = append(toDelete, groupID)
|
||||
} else {
|
||||
toUpsert[groupID] = *rate
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的记录
|
||||
for _, groupID := range toDelete {
|
||||
_, err := r.sql.ExecContext(ctx,
|
||||
`DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`,
|
||||
userID, groupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert 记录
|
||||
now := time.Now()
|
||||
for groupID, rate := range toUpsert {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $4)
|
||||
ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4
|
||||
`, userID, groupID, rate, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByGroupID 删除指定分组的所有用户专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID)
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除指定用户的所有专属倍率
|
||||
func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID)
|
||||
return err
|
||||
}
|
||||
@@ -190,6 +190,7 @@ func (r *userRepository) ListWithFilters(ctx context.Context, params pagination.
|
||||
dbuser.Or(
|
||||
dbuser.EmailContainsFold(filters.Search),
|
||||
dbuser.UsernameContainsFold(filters.Search),
|
||||
dbuser.NotesContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -56,6 +56,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewProxyRepository,
|
||||
NewRedeemCodeRepository,
|
||||
NewPromoCodeRepository,
|
||||
NewAnnouncementRepository,
|
||||
NewAnnouncementReadRepository,
|
||||
NewUsageLogRepository,
|
||||
NewUsageCleanupRepository,
|
||||
NewDashboardAggregationRepository,
|
||||
@@ -64,6 +66,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserSubscriptionRepository,
|
||||
NewUserAttributeDefinitionRepository,
|
||||
NewUserAttributeValueRepository,
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
@@ -83,6 +87,8 @@ var ProviderSet = wire.NewSet(
|
||||
NewSchedulerOutboxRepository,
|
||||
NewProxyLatencyCache,
|
||||
NewTotpCache,
|
||||
NewRefreshTokenCache,
|
||||
NewErrorPassthroughCache,
|
||||
|
||||
// Encryptors
|
||||
NewAESEncryptor,
|
||||
|
||||
Reference in New Issue
Block a user