Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -595,6 +595,25 @@ func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
// 使上游认为请求来自同一个会话
|
||||
func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
@@ -671,6 +690,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
|
||||
// 逻辑:
|
||||
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
|
||||
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
|
||||
func (a *Account) GetCurrentWindowStartTime() time.Time {
|
||||
now := time.Now()
|
||||
|
||||
// 窗口未过期,使用记录的窗口开始时间
|
||||
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
|
||||
return *a.SessionWindowStart
|
||||
}
|
||||
|
||||
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
|
||||
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
|
||||
@@ -37,6 +37,7 @@ type AccountRepository interface {
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
ClearError(ctx context.Context, id int64) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
|
||||
panic("unexpected SetError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
|
||||
type ClaudeUsageFetchOptions struct {
|
||||
AccessToken string // OAuth access token
|
||||
ProxyURL string // 代理 URL(可选)
|
||||
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
|
||||
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
|
||||
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
|
||||
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
@@ -170,6 +181,7 @@ type AccountUsageService struct {
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -180,6 +192,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
identityCache IdentityCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -188,6 +201,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
identityCache: identityCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
|
||||
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
// 构建完整的选项
|
||||
opts := &ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
AccountID: account.ID,
|
||||
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
|
||||
}
|
||||
|
||||
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
|
||||
if s.identityCache != nil {
|
||||
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
|
||||
opts.Fingerprint = fp
|
||||
}
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return s.accountRepo.SetError(ctx, id, errorMsg)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
// 兜底:随机生成 project_id
|
||||
if result.ProjectID == "" {
|
||||
result.ProjectID = antigravity.GenerateMockProjectID()
|
||||
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 project_id 和 email
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
// 保留原有的 email
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
|
||||
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubAntigravityUpstream struct {
|
||||
firstBase string
|
||||
secondBase string
|
||||
calls []string
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
url := req.URL.String()
|
||||
s.calls = append(s.calls, url)
|
||||
if strings.HasPrefix(url, s.firstBase) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
|
||||
}, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
type scopeLimitCall struct {
|
||||
accountID int64
|
||||
scope AntigravityQuotaScope
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type rateLimitCall struct {
|
||||
accountID int64
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
scopeCalls []scopeLimitCall
|
||||
rateCalls []rateLimitCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
base1 := "https://ag-1.test"
|
||||
base2 := "https://ag-2.test"
|
||||
antigravity.BaseURLs = []string{base1, base2}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCalled bool
|
||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
handleErrorCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.False(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
require.NotEmpty(t, available)
|
||||
require.Equal(t, base2, available[0])
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "true")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("3s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
call := repo.scopeCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "false")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("2s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.rateCalls, 1)
|
||||
require.Empty(t, repo.scopeCalls)
|
||||
call := repo.rateCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(10 * time.Minute)
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
account.RateLimitResetAt = &future
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
|
||||
account.RateLimitResetAt = nil
|
||||
account.Extra = map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
}
|
||||
|
||||
func buildGeminiRateLimitBody(delay string) []byte {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
|
||||
}
|
||||
@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -20,12 +20,16 @@ var (
|
||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
|
||||
)
|
||||
|
||||
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||
type DashboardAggregationRepository interface {
|
||||
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
|
||||
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
|
||||
RecomputeRange(ctx context.Context, start, end time.Time) error
|
||||
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
|
||||
// 与 TriggerBackfill 不同:
|
||||
// - 不依赖 backfill_enabled(这是内部一致性修复)
|
||||
// - 不更新 watermark(避免影响正常增量聚合游标)
|
||||
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
return errors.New("聚合服务已禁用")
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("重新计算时间范围无效")
|
||||
}
|
||||
|
||||
go func() {
|
||||
const maxRetries = 3
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
err := s.recomputeRange(ctx, start, end)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errDashboardAggregationRunning) {
|
||||
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
|
||||
start.UTC().Format(time.RFC3339),
|
||||
end.UTC().Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errors.New("聚合作业正在运行")
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
return s.aggregateErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
|
||||
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
if s.err != nil {
|
||||
return time.Time{}, s.err
|
||||
|
||||
@@ -93,13 +93,14 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
|
||||
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -44,13 +46,6 @@ func (s *GatewayService) debugModelRoutingEnabled() bool {
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
// debugLog prints log only in non-release mode.
|
||||
func debugLog(format string, v ...any) {
|
||||
if gin.Mode() != gin.ReleaseMode {
|
||||
log.Printf(format, v...)
|
||||
}
|
||||
}
|
||||
|
||||
func shortSessionHash(sessionHash string) string {
|
||||
if sessionHash == "" {
|
||||
return ""
|
||||
@@ -424,8 +419,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for id := range excludedIDs {
|
||||
excludedIDsList = append(excludedIDsList, id)
|
||||
}
|
||||
debugLog("[AccountScheduling] Starting account selection: groupID=%v model=%s session=%s excludedIDs=%v",
|
||||
derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), excludedIDsList)
|
||||
slog.Debug("account_scheduling_starting",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"excluded_ids", excludedIDsList)
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
|
||||
@@ -918,7 +916,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
|
||||
for _, acc := range candidates {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
@@ -1104,11 +1102,19 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
if s.schedulerSnapshot != nil {
|
||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err == nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (snapshot): groupID=%v platform=%s useMixed=%v count=%d",
|
||||
derefGroupID(groupID), platform, useMixed, len(accounts))
|
||||
slog.Debug("account_scheduling_list_snapshot",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"use_mixed", useMixed,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
}
|
||||
return accounts, useMixed, err
|
||||
@@ -1124,7 +1130,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
@@ -1134,11 +1143,19 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (mixed): groupID=%v platform=%s rawCount=%d filteredCount=%d",
|
||||
derefGroupID(groupID), platform, len(accounts), len(filtered))
|
||||
slog.Debug("account_scheduling_list_mixed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"raw_count", len(accounts),
|
||||
"filtered_count", len(filtered))
|
||||
for _, acc := range filtered {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
@@ -1154,14 +1171,24 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts FAILED: groupID=%v platform=%s err=%v", derefGroupID(groupID), platform, err)
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
debugLog("[AccountScheduling] listSchedulableAccounts (single): groupID=%v platform=%s count=%d",
|
||||
derefGroupID(groupID), platform, len(accounts))
|
||||
slog.Debug("account_scheduling_list_single",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
debugLog("[AccountScheduling] - Account ID=%d Name=%s Platform=%s Type=%s Status=%s TLSFingerprint=%v",
|
||||
acc.ID, acc.Name, acc.Platform, acc.Type, acc.Status, acc.IsTLSFingerprintEnabled())
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
@@ -1228,12 +1255,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
{
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -1322,6 +1345,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
})
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
|
||||
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
|
||||
if mode == "random" {
|
||||
// 先按优先级排序,然后在同优先级内随机打乱
|
||||
sortAccountsByPriorityOnly(accounts, preferOAuth)
|
||||
shuffleWithinPriority(accounts)
|
||||
} else {
|
||||
// 默认按最后使用时间排序
|
||||
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// sortAccountsByPriorityOnly 仅按优先级排序
|
||||
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
if a.Priority != b.Priority {
|
||||
return a.Priority < b.Priority
|
||||
}
|
||||
if preferOAuth && a.Type != b.Type {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// shuffleWithinPriority 在同优先级内随机打乱顺序
|
||||
func shuffleWithinPriority(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
start := 0
|
||||
for start < len(accounts) {
|
||||
priority := accounts[start].Priority
|
||||
end := start + 1
|
||||
for end < len(accounts) && accounts[end].Priority == priority {
|
||||
end++
|
||||
}
|
||||
// 对 [start, end) 范围内的账户随机打乱
|
||||
if end-start > 1 {
|
||||
r.Shuffle(end-start, func(i, j int) {
|
||||
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
|
||||
})
|
||||
}
|
||||
start = end
|
||||
}
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
@@ -2558,9 +2631,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
fingerprint = fp
|
||||
|
||||
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
@@ -3591,12 +3665,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil {
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -49,6 +51,13 @@ type Fingerprint struct {
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
|
||||
// 返回的 sessionID 是一个 UUID 格式的字符串
|
||||
// 如果不存在或已过期(15分钟无请求),返回空字符串
|
||||
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
|
||||
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
|
||||
// 每次调用都会刷新 TTL
|
||||
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
||||
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
// 先执行常规的 RewriteUserID 逻辑
|
||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
||||
if err != nil {
|
||||
return newBody, err
|
||||
}
|
||||
|
||||
// 检查是否启用会话ID伪装
|
||||
if !account.IsSessionIDMaskingEnabled() {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 解析重写后的 body,提取 user_id
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 查找 _session_ 的位置,替换其后的内容
|
||||
const sessionMarker = "_session_"
|
||||
idx := strings.LastIndex(userID, sessionMarker)
|
||||
if idx == -1 {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 获取或生成固定的伪装 session ID
|
||||
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
if maskedSessionID == "" {
|
||||
// 首次或已过期,生成新的伪装 session ID
|
||||
maskedSessionID = generateRandomUUID()
|
||||
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
|
||||
}
|
||||
|
||||
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
|
||||
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
|
||||
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
||||
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
|
||||
|
||||
slog.Debug("session_id_masking_applied",
|
||||
"account_id", account.ID,
|
||||
"before", userID,
|
||||
"after", newUserID,
|
||||
)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||
func generateRandomUUID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// fallback: 使用时间戳生成
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
b = h[:16]
|
||||
}
|
||||
|
||||
// 设置 UUID v4 版本和变体位
|
||||
b[6] = (b[6] & 0x0f) | 0x40
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
|
||||
func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
|
||||
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
|
||||
|
||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||
@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
||||
}
|
||||
|
||||
// Determine scope and if this is a setup token
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
// Internal API call uses ScopeAPI (org:create_api_key not supported)
|
||||
scope := oauth.ScopeAPI
|
||||
isSetupToken := false
|
||||
if input.Scope == "inference" {
|
||||
scope = oauth.ScopeInference
|
||||
|
||||
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := false
|
||||
// 先尝试临时不可调度规则(401除外)
|
||||
// 如果匹配成功,直接返回,不执行后续禁用逻辑
|
||||
if statusCode != 401 {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 400:
|
||||
// 只有当错误信息包含 "organization has been disabled" 时才禁用
|
||||
if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
|
||||
msg := "Organization disabled (400): " + upstreamMsg
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return true
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
@@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyHomeContent,
|
||||
SettingKeyHideCcsImportButton,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
@@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
@@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
@@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
@@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
updates[SettingKeyHomeContent] = settings.HomeContent
|
||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
|
||||
@@ -25,13 +25,14 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool
|
||||
LinuxDoConnectRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -66,6 +67,7 @@ type PublicSettings struct {
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ var (
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
|
||||
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
|
||||
)
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
// ExtendSubscription 调整订阅时长(正数延长,负数缩短)
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 限制延长天数
|
||||
// 限制调整天数范围
|
||||
if days > MaxValidityDays {
|
||||
days = MaxValidityDays
|
||||
}
|
||||
if days < -MaxValidityDays {
|
||||
days = -MaxValidityDays
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
newExpiresAt = MaxExpiresAt
|
||||
}
|
||||
|
||||
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
|
||||
if days < 0 {
|
||||
now := time.Now()
|
||||
if !newExpiresAt.After(now) {
|
||||
return nil, ErrAdjustWouldExpire
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
account.Credentials = newCredentials
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
|
||||
74
backend/internal/service/usage_cleanup.go
Normal file
74
backend/internal/service/usage_cleanup.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
UsageCleanupStatusPending = "pending"
|
||||
UsageCleanupStatusRunning = "running"
|
||||
UsageCleanupStatusSucceeded = "succeeded"
|
||||
UsageCleanupStatusFailed = "failed"
|
||||
UsageCleanupStatusCanceled = "canceled"
|
||||
)
|
||||
|
||||
// UsageCleanupFilters 定义清理任务过滤条件
|
||||
// 时间范围为必填,其他字段可选
|
||||
// JSON 序列化用于存储任务参数
|
||||
//
|
||||
// start_time/end_time 使用 RFC3339 时间格式
|
||||
// 以 UTC 或用户时区解析后的时间为准
|
||||
//
|
||||
// 说明:
|
||||
// - nil 表示未设置该过滤条件
|
||||
// - 过滤条件均为精确匹配
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
// UsageCleanupTask 表示使用记录清理任务
|
||||
// 状态包含 pending/running/succeeded/failed/canceled
|
||||
type UsageCleanupTask struct {
|
||||
ID int64
|
||||
Status string
|
||||
Filters UsageCleanupFilters
|
||||
CreatedBy int64
|
||||
DeletedRows int64
|
||||
ErrorMsg *string
|
||||
CanceledBy *int64
|
||||
CanceledAt *time.Time
|
||||
StartedAt *time.Time
|
||||
FinishedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// UsageCleanupRepository 定义清理任务持久层接口
|
||||
type UsageCleanupRepository interface {
|
||||
CreateTask(ctx context.Context, task *UsageCleanupTask) error
|
||||
ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
|
||||
// ClaimNextPendingTask 抢占下一条可执行任务:
|
||||
// - 优先 pending
|
||||
// - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行
|
||||
ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
|
||||
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
|
||||
GetTaskStatus(ctx context.Context, taskID int64) (string, error)
|
||||
// UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示
|
||||
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
|
||||
// CancelTask 将任务标记为 canceled(仅允许 pending/running)
|
||||
CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
|
||||
MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
|
||||
MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
|
||||
DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
|
||||
}
|
||||
404
backend/internal/service/usage_cleanup_service.go
Normal file
404
backend/internal/service/usage_cleanup_service.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
usageCleanupWorkerName = "usage_cleanup_worker"
|
||||
)
|
||||
|
||||
// UsageCleanupService 负责创建与执行使用记录清理任务
|
||||
type UsageCleanupService struct {
|
||||
repo UsageCleanupRepository
|
||||
timingWheel *TimingWheelService
|
||||
dashboard *DashboardAggregationService
|
||||
cfg *config.Config
|
||||
|
||||
running int32
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
workerCtx context.Context
|
||||
workerCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
|
||||
workerCtx, workerCancel := context.WithCancel(context.Background())
|
||||
return &UsageCleanupService{
|
||||
repo: repo,
|
||||
timingWheel: timingWheel,
|
||||
dashboard: dashboard,
|
||||
cfg: cfg,
|
||||
workerCtx: workerCtx,
|
||||
workerCancel: workerCancel,
|
||||
}
|
||||
}
|
||||
|
||||
func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
|
||||
var parts []string
|
||||
parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
|
||||
parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
|
||||
if filters.UserID != nil {
|
||||
parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
|
||||
}
|
||||
if filters.APIKeyID != nil {
|
||||
parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
|
||||
}
|
||||
if filters.AccountID != nil {
|
||||
parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
|
||||
}
|
||||
if filters.GroupID != nil {
|
||||
parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
|
||||
}
|
||||
if filters.Model != nil {
|
||||
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
log.Printf("[UsageCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.repo == nil || s.timingWheel == nil {
|
||||
log.Printf("[UsageCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
interval := s.workerInterval()
|
||||
s.startOnce.Do(func() {
|
||||
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
|
||||
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.workerCancel != nil {
|
||||
s.workerCancel()
|
||||
}
|
||||
if s.timingWheel != nil {
|
||||
s.timingWheel.Cancel(usageCleanupWorkerName)
|
||||
}
|
||||
log.Printf("[UsageCleanup] stopped")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, nil, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
return s.repo.ListTasks(ctx, params)
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
|
||||
}
|
||||
if createdBy <= 0 {
|
||||
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
|
||||
sanitizeUsageCleanupFilters(&filters)
|
||||
if err := s.validateFilters(filters); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task := &UsageCleanupTask{
|
||||
Status: UsageCleanupStatusPending,
|
||||
Filters: filters,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
if err := s.repo.CreateTask(ctx, task); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, fmt.Errorf("create cleanup task: %w", err)
|
||||
}
|
||||
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
|
||||
go s.runOnce()
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) runOnce() {
|
||||
svc := s
|
||||
if svc == nil {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) {
|
||||
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&svc.running, 0)
|
||||
|
||||
parent := context.Background()
|
||||
if svc.workerCtx != nil {
|
||||
parent = svc.workerCtx
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parent, svc.taskTimeout())
|
||||
defer cancel()
|
||||
|
||||
task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds()))
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
log.Printf("[UsageCleanup] run_once done: no_task=true")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
|
||||
svc.executeTask(ctx, task)
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
|
||||
batchSize := s.batchSize()
|
||||
deletedTotal := task.DeletedRows
|
||||
start := time.Now()
|
||||
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
|
||||
var batchNum int
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
|
||||
return
|
||||
}
|
||||
canceled, err := s.isTaskCanceled(ctx, task.ID)
|
||||
if err != nil {
|
||||
s.markTaskFailed(task.ID, deletedTotal, err)
|
||||
return
|
||||
}
|
||||
if canceled {
|
||||
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
return
|
||||
}
|
||||
|
||||
batchNum++
|
||||
deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
|
||||
return
|
||||
}
|
||||
s.markTaskFailed(task.ID, deletedTotal, err)
|
||||
return
|
||||
}
|
||||
deletedTotal += deleted
|
||||
if deleted > 0 {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
|
||||
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
|
||||
}
|
||||
if deleted == 0 || deleted < int64(batchSize) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
}
|
||||
|
||||
if s.dashboard != nil {
|
||||
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
|
||||
msg := strings.TrimSpace(err.Error())
|
||||
if len(msg) > 500 {
|
||||
msg = msg[:500]
|
||||
}
|
||||
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
|
||||
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return false, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
status, err := s.repo.GetTaskStatus(checkCtx, taskID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if status == UsageCleanupStatusCanceled {
|
||||
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
|
||||
}
|
||||
return status == UsageCleanupStatusCanceled, nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
|
||||
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
|
||||
}
|
||||
if filters.EndTime.Before(filters.StartTime) {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
|
||||
}
|
||||
maxDays := s.maxRangeDays()
|
||||
if maxDays > 0 {
|
||||
delta := filters.EndTime.Sub(filters.StartTime)
|
||||
if delta > time.Duration(maxDays)*24*time.Hour {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
|
||||
}
|
||||
if canceledBy <= 0 {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
|
||||
}
|
||||
status, err := s.repo.GetTaskStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// 状态可能并发改变
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
|
||||
if filters == nil {
|
||||
return
|
||||
}
|
||||
if filters.UserID != nil && *filters.UserID <= 0 {
|
||||
filters.UserID = nil
|
||||
}
|
||||
if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
|
||||
filters.APIKeyID = nil
|
||||
}
|
||||
if filters.AccountID != nil && *filters.AccountID <= 0 {
|
||||
filters.AccountID = nil
|
||||
}
|
||||
if filters.GroupID != nil && *filters.GroupID <= 0 {
|
||||
filters.GroupID = nil
|
||||
}
|
||||
if filters.Model != nil {
|
||||
model := strings.TrimSpace(*filters.Model)
|
||||
if model == "" {
|
||||
filters.Model = nil
|
||||
} else {
|
||||
filters.Model = &model
|
||||
}
|
||||
}
|
||||
if filters.BillingType != nil && *filters.BillingType < 0 {
|
||||
filters.BillingType = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) maxRangeDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 31
|
||||
}
|
||||
if s.cfg.UsageCleanup.MaxRangeDays > 0 {
|
||||
return s.cfg.UsageCleanup.MaxRangeDays
|
||||
}
|
||||
return 31
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) batchSize() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 5000
|
||||
}
|
||||
if s.cfg.UsageCleanup.BatchSize > 0 {
|
||||
return s.cfg.UsageCleanup.BatchSize
|
||||
}
|
||||
return 5000
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) workerInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 10 * time.Second
|
||||
}
|
||||
if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
|
||||
return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
|
||||
}
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) taskTimeout() time.Duration {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 30 * time.Minute
|
||||
}
|
||||
if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
|
||||
return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
|
||||
}
|
||||
return 30 * time.Minute
|
||||
}
|
||||
815
backend/internal/service/usage_cleanup_service_test.go
Normal file
815
backend/internal/service/usage_cleanup_service_test.go
Normal file
@@ -0,0 +1,815 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type cleanupDeleteResponse struct {
|
||||
deleted int64
|
||||
err error
|
||||
}
|
||||
|
||||
type cleanupDeleteCall struct {
|
||||
filters UsageCleanupFilters
|
||||
limit int
|
||||
}
|
||||
|
||||
type cleanupMarkCall struct {
|
||||
taskID int64
|
||||
deletedRows int64
|
||||
errMsg string
|
||||
}
|
||||
|
||||
type cleanupRepoStub struct {
|
||||
mu sync.Mutex
|
||||
created []*UsageCleanupTask
|
||||
createErr error
|
||||
listTasks []UsageCleanupTask
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
claimQueue []*UsageCleanupTask
|
||||
claimErr error
|
||||
deleteQueue []cleanupDeleteResponse
|
||||
deleteCalls []cleanupDeleteCall
|
||||
markSucceeded []cleanupMarkCall
|
||||
markFailed []cleanupMarkCall
|
||||
statusByID map[int64]string
|
||||
statusErr error
|
||||
progressCalls []cleanupMarkCall
|
||||
updateErr error
|
||||
cancelCalls []int64
|
||||
cancelErr error
|
||||
cancelResult *bool
|
||||
markFailedErr error
|
||||
}
|
||||
|
||||
type dashboardRepoStub struct {
|
||||
recomputeErr error
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return s.recomputeErr
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.createErr != nil {
|
||||
return s.createErr
|
||||
}
|
||||
if task.ID == 0 {
|
||||
task.ID = int64(len(s.created) + 1)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if task.UpdatedAt.IsZero() {
|
||||
task.UpdatedAt = task.CreatedAt
|
||||
}
|
||||
clone := *task
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.listTasks, s.listResult, s.listErr
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.claimErr != nil {
|
||||
return nil, s.claimErr
|
||||
}
|
||||
if len(s.claimQueue) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
task := s.claimQueue[0]
|
||||
s.claimQueue = s.claimQueue[1:]
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[task.ID] = UsageCleanupStatusRunning
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusErr != nil {
|
||||
return "", s.statusErr
|
||||
}
|
||||
if s.statusByID == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
status, ok := s.statusByID[taskID]
|
||||
if !ok {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cancelCalls = append(s.cancelCalls, taskID)
|
||||
if s.cancelErr != nil {
|
||||
return false, s.cancelErr
|
||||
}
|
||||
if s.cancelResult != nil {
|
||||
ok := *s.cancelResult
|
||||
if ok {
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusCanceled
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
status := s.statusByID[taskID]
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return false, nil
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusCanceled
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusSucceeded
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusFailed
|
||||
if s.markFailedErr != nil {
|
||||
return s.markFailedErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
|
||||
if len(s.deleteQueue) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
resp := s.deleteQueue[0]
|
||||
s.deleteQueue = s.deleteQueue[1:]
|
||||
return resp.deleted, resp.err
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
userID := int64(-1)
|
||||
apiKeyID := int64(10)
|
||||
model := " gpt-4 "
|
||||
billingType := int8(-2)
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
APIKeyID: &apiKeyID,
|
||||
Model: &model,
|
||||
BillingType: &billingType,
|
||||
}
|
||||
|
||||
task, err := svc.CreateTask(context.Background(), filters, 9)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, UsageCleanupStatusPending, task.Status)
|
||||
require.Nil(t, task.Filters.UserID)
|
||||
require.NotNil(t, task.Filters.APIKeyID)
|
||||
require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
|
||||
require.NotNil(t, task.Filters.Model)
|
||||
require.Equal(t, "gpt-4", *task.Filters.Model)
|
||||
require.Nil(t, task.Filters.BillingType)
|
||||
require.Equal(t, int64(9), task.CreatedBy)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 0)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(48 * time.Hour)
|
||||
filters := UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
_, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{createErr: errors.New("db down")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "create cleanup task")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(2 * time.Hour)
|
||||
repo := &cleanupRepoStub{
|
||||
claimQueue: []*UsageCleanupTask{
|
||||
{ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
|
||||
},
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 2},
|
||||
{deleted: 2},
|
||||
{deleted: 1},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
svc.runOnce()
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.deleteCalls, 3)
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
|
||||
require.Equal(t, 2, repo.deleteCalls[0].limit)
|
||||
require.Equal(t, start, repo.deleteCalls[0].filters.StartTime)
|
||||
require.Equal(t, end, repo.deleteCalls[0].filters.EndTime)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
svc.runOnce()
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
svc.running = 1
|
||||
svc.runOnce()
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
|
||||
longMsg := strings.Repeat("x", 600)
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: errors.New(longMsg)},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 11,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markFailed, 1)
|
||||
require.Equal(t, int64(11), repo.markFailed[0].taskID)
|
||||
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskProgressError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 2},
|
||||
{deleted: 0},
|
||||
},
|
||||
updateErr: errors.New("update failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 8,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Len(t, repo.progressCalls, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDeleteCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: context.Canceled},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 12,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskContextCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 9,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
svc.executeTask(ctx, task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Empty(t, repo.deleteCalls)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: errors.New("boom")},
|
||||
},
|
||||
markFailedErr: errors.New("update failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 13,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markFailed, 1)
|
||||
require.Equal(t, int64(13), repo.markFailed[0].taskID)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 14,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 15,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
3: UsageCleanupStatusCanceled,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 3,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.deleteCalls)
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
5: UsageCleanupStatusPending,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 5, 9)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Equal(t, UsageCleanupStatusCanceled, repo.statusByID[5])
|
||||
require.Len(t, repo.cancelCalls, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskDisabled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 1, 2)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusNotFound, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_TASK_NOT_FOUND", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskStatusError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusErr: errors.New("status broken")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "status broken")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusSucceeded,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
|
||||
shouldCancel := false
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusPending,
|
||||
},
|
||||
cancelResult: &shouldCancel,
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskRepoError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusPending,
|
||||
},
|
||||
cancelErr: errors.New("cancel failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cancel failed")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskInvalidCanceller(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusRunning,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 0)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_INVALID_CANCELLER", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceListTasks(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
|
||||
listResult: &pagination.PaginationResult{
|
||||
Total: 2,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tasks, 2)
|
||||
require.Equal(t, int64(2), result.Total)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
|
||||
var nilSvc *UsageCleanupService
|
||||
_, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
|
||||
svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
_, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
|
||||
var nilSvc *UsageCleanupService
|
||||
require.Equal(t, 31, nilSvc.maxRangeDays())
|
||||
require.Equal(t, 5000, nilSvc.batchSize())
|
||||
require.Equal(t, 10*time.Second, nilSvc.workerInterval())
|
||||
require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
|
||||
nilSvc.Start()
|
||||
nilSvc.Stop()
|
||||
|
||||
repo := &cleanupRepoStub{}
|
||||
cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
|
||||
svcDisabled.Start()
|
||||
svcDisabled.Stop()
|
||||
|
||||
timingWheel, err := NewTimingWheelService()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
|
||||
svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
|
||||
require.Equal(t, 5*time.Second, svc.workerInterval())
|
||||
svc.Start()
|
||||
svc.Stop()
|
||||
|
||||
cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
|
||||
require.Equal(t, 31, svcFallback.maxRangeDays())
|
||||
require.Equal(t, 5000, svcFallback.batchSize())
|
||||
require.Equal(t, 10*time.Second, svcFallback.workerInterval())
|
||||
|
||||
svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
|
||||
svcMissingDeps.Start()
|
||||
}
|
||||
|
||||
func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
|
||||
model := " "
|
||||
apiKeyID := int64(-5)
|
||||
accountID := int64(-1)
|
||||
groupID := int64(-2)
|
||||
filters := UsageCleanupFilters{
|
||||
UserID: &apiKeyID,
|
||||
APIKeyID: &apiKeyID,
|
||||
AccountID: &accountID,
|
||||
GroupID: &groupID,
|
||||
Model: &model,
|
||||
}
|
||||
|
||||
sanitizeUsageCleanupFilters(&filters)
|
||||
require.Nil(t, filters.UserID)
|
||||
require.Nil(t, filters.APIKeyID)
|
||||
require.Nil(t, filters.AccountID)
|
||||
require.Nil(t, filters.GroupID)
|
||||
require.Nil(t, filters.Model)
|
||||
}
|
||||
|
||||
func TestDescribeUsageCleanupFiltersAllFields(t *testing.T) {
|
||||
start := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC)
|
||||
end := start.Add(2 * time.Hour)
|
||||
userID := int64(1)
|
||||
apiKeyID := int64(2)
|
||||
accountID := int64(3)
|
||||
groupID := int64(4)
|
||||
model := " gpt-4 "
|
||||
stream := true
|
||||
billingType := int8(2)
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
APIKeyID: &apiKeyID,
|
||||
AccountID: &accountID,
|
||||
GroupID: &groupID,
|
||||
Model: &model,
|
||||
Stream: &stream,
|
||||
BillingType: &billingType,
|
||||
}
|
||||
|
||||
desc := describeUsageCleanupFilters(filters)
|
||||
require.Equal(t, "start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2", desc)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceIsTaskCanceledNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
canceled, err := svc.isTaskCanceled(context.Background(), 9)
|
||||
require.NoError(t, err)
|
||||
require.False(t, canceled)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceIsTaskCanceledError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusErr: errors.New("status err")}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
_, err := svc.isTaskCanceled(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "status err")
|
||||
}
|
||||
@@ -58,6 +58,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
|
||||
func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
|
||||
svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAccountExpiryService creates and starts AccountExpiryService.
|
||||
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
|
||||
svc := NewAccountExpiryService(accountRepo, time.Minute)
|
||||
@@ -251,6 +258,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideAccountExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDashboardAggregationService,
|
||||
ProvideUsageCleanupService,
|
||||
ProvideDeferredService,
|
||||
NewAntigravityQuotaFetcher,
|
||||
NewUserAttributeService,
|
||||
|
||||
Reference in New Issue
Block a user