Merge branch 'main' of github.com:Wei-Shaw/sub2api
This commit is contained in:
@@ -592,6 +592,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
|
||||
return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
|
||||
}
|
||||
|
||||
// IsTLSFingerprintEnabled 检查是否启用 TLS 指纹伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将模拟 Claude Code (Node.js) 客户端的 TLS 握手特征
|
||||
func (a *Account) IsTLSFingerprintEnabled() bool {
|
||||
// 仅支持 Anthropic OAuth/SetupToken 账号
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSessionIDMaskingEnabled 检查是否启用会话ID伪装
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 类型账号
|
||||
// 启用后将在一段时间内(15分钟)固定 metadata.user_id 中的 session ID,
|
||||
// 使上游认为请求来自同一个会话
|
||||
func (a *Account) IsSessionIDMaskingEnabled() bool {
|
||||
if !a.IsAnthropicOAuthOrSetupToken() {
|
||||
return false
|
||||
}
|
||||
if a.Extra == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Extra["session_id_masking_enabled"]; ok {
|
||||
if enabled, ok := v.(bool); ok {
|
||||
return enabled
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
|
||||
// 返回 0 表示未启用
|
||||
func (a *Account) GetWindowCostLimit() float64 {
|
||||
@@ -668,6 +706,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
|
||||
return WindowCostNotSchedulable
|
||||
}
|
||||
|
||||
// GetCurrentWindowStartTime 获取当前有效的窗口开始时间
|
||||
// 逻辑:
|
||||
// 1. 如果窗口未过期(SessionWindowEnd 存在且在当前时间之后),使用记录的 SessionWindowStart
|
||||
// 2. 否则(窗口过期或未设置),使用新的预测窗口开始时间(从当前整点开始)
|
||||
func (a *Account) GetCurrentWindowStartTime() time.Time {
|
||||
now := time.Now()
|
||||
|
||||
// 窗口未过期,使用记录的窗口开始时间
|
||||
if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
|
||||
return *a.SessionWindowStart
|
||||
}
|
||||
|
||||
// 窗口已过期或未设置,预测新的窗口开始时间(从当前整点开始)
|
||||
// 与 ratelimit_service.go 中 UpdateSessionWindow 的预测逻辑保持一致
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
|
||||
}
|
||||
|
||||
// parseExtraFloat64 从 extra 字段解析 float64 值
|
||||
func parseExtraFloat64(value any) float64 {
|
||||
switch v := value.(type) {
|
||||
|
||||
@@ -37,6 +37,7 @@ type AccountRepository interface {
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
ClearError(ctx context.Context, id int64) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
@@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
|
||||
panic("unexpected SetError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
|
||||
}
|
||||
|
||||
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
|
||||
|
||||
// Admin dashboard stats
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
|
||||
} `json:"seven_day_sonnet"`
|
||||
}
|
||||
|
||||
// ClaudeUsageFetchOptions 包含获取 Claude 用量数据所需的所有选项
|
||||
type ClaudeUsageFetchOptions struct {
|
||||
AccessToken string // OAuth access token
|
||||
ProxyURL string // 代理 URL(可选)
|
||||
AccountID int64 // 账号 ID(用于 TLS 指纹选择)
|
||||
EnableTLSFingerprint bool // 是否启用 TLS 指纹伪装
|
||||
Fingerprint *Fingerprint // 缓存的指纹信息(User-Agent 等)
|
||||
}
|
||||
|
||||
// ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
|
||||
type ClaudeUsageFetcher interface {
|
||||
FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
|
||||
// FetchUsageWithOptions 使用完整选项获取用量数据,支持 TLS 指纹和自定义 User-Agent
|
||||
FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
|
||||
}
|
||||
|
||||
// AccountUsageService 账号使用量查询服务
|
||||
@@ -170,6 +181,7 @@ type AccountUsageService struct {
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher
|
||||
cache *UsageCache
|
||||
identityCache IdentityCache
|
||||
}
|
||||
|
||||
// NewAccountUsageService 创建AccountUsageService实例
|
||||
@@ -180,6 +192,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService *GeminiQuotaService,
|
||||
antigravityQuotaFetcher *AntigravityQuotaFetcher,
|
||||
cache *UsageCache,
|
||||
identityCache IdentityCache,
|
||||
) *AccountUsageService {
|
||||
return &AccountUsageService{
|
||||
accountRepo: accountRepo,
|
||||
@@ -188,6 +201,7 @@ func NewAccountUsageService(
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
antigravityQuotaFetcher: antigravityQuotaFetcher,
|
||||
cache: cache,
|
||||
identityCache: identityCache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
}
|
||||
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
|
||||
|
||||
// 如果没有缓存,从数据库查询
|
||||
if windowStats == nil {
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
// fetchOAuthUsageRaw 从 Anthropic API 获取原始响应(不构建 UsageInfo)
|
||||
// 如果账号开启了 TLS 指纹,则使用 TLS 指纹伪装
|
||||
// 如果有缓存的 Fingerprint,则使用缓存的 User-Agent 等信息
|
||||
func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
|
||||
// 构建完整的选项
|
||||
opts := &ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
AccountID: account.ID,
|
||||
EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
|
||||
}
|
||||
|
||||
// 尝试获取缓存的 Fingerprint(包含 User-Agent 等信息)
|
||||
if s.identityCache != nil {
|
||||
if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
|
||||
opts.Fingerprint = fp
|
||||
}
|
||||
}
|
||||
|
||||
return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
|
||||
}
|
||||
|
||||
// parseTime 尝试多种格式解析时间
|
||||
|
||||
@@ -42,6 +42,7 @@ type AdminService interface {
|
||||
DeleteAccount(ctx context.Context, id int64) error
|
||||
RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
|
||||
ClearAccountError(ctx context.Context, id int64) (*Account, error)
|
||||
SetAccountError(ctx context.Context, id int64, errorMsg string) error
|
||||
SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
|
||||
BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
|
||||
|
||||
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return s.accountRepo.SetError(ctx, id, errorMsg)
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
|
||||
if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
|
||||
return nil, err
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
|
||||
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
|
||||
|
||||
// Gemini 前缀透传
|
||||
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
|
||||
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
|
||||
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
|
||||
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
|
||||
|
||||
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-1.5-pro",
|
||||
requestedModel: "gemini-1.5-pro",
|
||||
name: "Gemini透传 - gemini-2.5-pro",
|
||||
requestedModel: "gemini-2.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-1.5-pro",
|
||||
expected: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "Gemini透传 - gemini-future-model",
|
||||
|
||||
@@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct {
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
@@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
|
||||
result.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
// 兜底:随机生成 project_id
|
||||
if result.ProjectID == "" {
|
||||
result.ProjectID = antigravity.GenerateMockProjectID()
|
||||
fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 project_id 和 email
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if existingProjectID != "" {
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
}
|
||||
// 保留原有的 email
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken)
|
||||
if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" {
|
||||
// LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
} else {
|
||||
tokenInfo.ProjectID = loadResp.CloudAICompanionProject
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
// 如果没有 project_id,生成一个随机的
|
||||
if projectID == "" {
|
||||
projectID = antigravity.GenerateMockProjectID()
|
||||
}
|
||||
|
||||
client := antigravity.NewClient(proxyURL)
|
||||
|
||||
// 调用 API 获取配额
|
||||
|
||||
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
190
backend/internal/service/antigravity_rate_limit_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type stubAntigravityUpstream struct {
|
||||
firstBase string
|
||||
secondBase string
|
||||
calls []string
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
url := req.URL.String()
|
||||
s.calls = append(s.calls, url)
|
||||
if strings.HasPrefix(url, s.firstBase) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)),
|
||||
}, nil
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader("ok")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
type scopeLimitCall struct {
|
||||
accountID int64
|
||||
scope AntigravityQuotaScope
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type rateLimitCall struct {
|
||||
accountID int64
|
||||
resetAt time.Time
|
||||
}
|
||||
|
||||
type stubAntigravityAccountRepo struct {
|
||||
AccountRepository
|
||||
scopeCalls []scopeLimitCall
|
||||
rateCalls []rateLimitCall
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
|
||||
s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt})
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
base1 := "https://ag-1.test"
|
||||
base2 := "https://ag-2.test"
|
||||
antigravity.BaseURLs = []string{base1, base2}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-1",
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
var handleErrorCalled bool
|
||||
result, err := antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
prefix: "[test]",
|
||||
ctx: context.Background(),
|
||||
account: account,
|
||||
proxyURL: "",
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
quotaScope: AntigravityQuotaScopeClaude,
|
||||
httpUpstream: upstream,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
|
||||
handleErrorCalled = true
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.resp)
|
||||
defer func() { _ = result.resp.Body.Close() }()
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.False(t, handleErrorCalled)
|
||||
require.Len(t, upstream.calls, 2)
|
||||
require.True(t, strings.HasPrefix(upstream.calls[0], base1))
|
||||
require.True(t, strings.HasPrefix(upstream.calls[1], base2))
|
||||
|
||||
available := antigravity.DefaultURLAvailability.GetAvailableURLs()
|
||||
require.NotEmpty(t, available)
|
||||
require.Equal(t, base2, available[0])
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "true")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("3s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.scopeCalls, 1)
|
||||
require.Empty(t, repo.rateCalls)
|
||||
call := repo.scopeCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.Equal(t, AntigravityQuotaScopeClaude, call.scope)
|
||||
require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) {
|
||||
t.Setenv(antigravityScopeRateLimitEnv, "false")
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity}
|
||||
|
||||
body := buildGeminiRateLimitBody("2s")
|
||||
svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude)
|
||||
|
||||
require.Len(t, repo.rateCalls, 1)
|
||||
require.Empty(t, repo.scopeCalls)
|
||||
call := repo.rateCalls[0]
|
||||
require.Equal(t, account.ID, call.accountID)
|
||||
require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) {
|
||||
now := time.Now()
|
||||
future := now.Add(10 * time.Minute)
|
||||
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
account.RateLimitResetAt = &future
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.False(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
|
||||
account.RateLimitResetAt = nil
|
||||
account.Extra = map[string]any{
|
||||
antigravityQuotaScopesKey: map[string]any{
|
||||
"claude": map[string]any{
|
||||
"rate_limit_reset_at": future.Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5"))
|
||||
require.True(t, account.IsSchedulableForModel("gemini-3-flash"))
|
||||
}
|
||||
|
||||
func buildGeminiRateLimitBody(delay string) []byte {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay))
|
||||
}
|
||||
@@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity")
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
|
||||
@@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
|
||||
// This should be called after the service is fully initialized.
|
||||
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
|
||||
if s.cache == nil || s.authCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}); err != nil {
|
||||
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
||||
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
@@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
// Publish invalidation message to other instances
|
||||
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
|
||||
@@ -65,6 +65,10 @@ type APIKeyCache interface {
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
|
||||
// Pub/Sub for L1 cache invalidation across instances
|
||||
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
|
||||
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
|
||||
@@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
|
||||
@@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
|
||||
@@ -20,12 +20,16 @@ var (
|
||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
|
||||
)
|
||||
|
||||
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||
type DashboardAggregationRepository interface {
|
||||
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
|
||||
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
|
||||
RecomputeRange(ctx context.Context, start, end time.Time) error
|
||||
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
|
||||
// 与 TriggerBackfill 不同:
|
||||
// - 不依赖 backfill_enabled(这是内部一致性修复)
|
||||
// - 不更新 watermark(避免影响正常增量聚合游标)
|
||||
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
return errors.New("聚合服务已禁用")
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("重新计算时间范围无效")
|
||||
}
|
||||
|
||||
go func() {
|
||||
const maxRetries = 3
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
err := s.recomputeRange(ctx, start, end)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errDashboardAggregationRunning) {
|
||||
log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
|
||||
return err
|
||||
}
|
||||
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
|
||||
start.UTC().Format(time.RFC3339),
|
||||
end.UTC().Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errors.New("聚合作业正在运行")
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
|
||||
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
|
||||
return s.aggregateErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
|
||||
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
if s.err != nil {
|
||||
return time.Time{}, s.err
|
||||
|
||||
@@ -93,13 +93,14 @@ const (
|
||||
SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url"
|
||||
|
||||
// OEM设置
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src)
|
||||
SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
|
||||
@@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up
|
||||
func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
mathrand "math/rand"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
@@ -819,11 +821,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
||||
}
|
||||
|
||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||
// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
|
||||
// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash
|
||||
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
|
||||
// 调试日志:记录调度入口参数
|
||||
excludedIDsList := make([]int64, 0, len(excludedIDs))
|
||||
for id := range excludedIDs {
|
||||
excludedIDsList = append(excludedIDsList, id)
|
||||
}
|
||||
slog.Debug("account_scheduling_starting",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"model", requestedModel,
|
||||
"session", shortSessionHash(sessionHash),
|
||||
"excluded_ids", excludedIDsList)
|
||||
|
||||
cfg := s.schedulingConfig()
|
||||
// 提取会话 UUID(用于会话数量限制)
|
||||
sessionUUID := extractSessionUUID(metadataUserID)
|
||||
|
||||
var stickyAccountID int64
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
@@ -849,41 +860,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// 复制排除列表,用于会话限制拒绝时的重试
|
||||
localExcluded := make(map[int64]struct{})
|
||||
for k, v := range excludedIDs {
|
||||
localExcluded[k] = v
|
||||
}
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
|
||||
for {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
localExcluded[account.ID] = struct{}{} // 排除此账号
|
||||
continue // 重新选择
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
Account: account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 对于等待计划的情况,也需要先检查会话限制
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
localExcluded[account.ID] = struct{}{}
|
||||
continue
|
||||
}
|
||||
|
||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: account.ID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.FallbackWaitTimeout,
|
||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group)
|
||||
@@ -999,7 +1032,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
@@ -1017,15 +1050,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
@@ -1086,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -1104,20 +1142,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
|
||||
acc := routingAvailable[0].account
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
|
||||
// 5. 所有路由账号槽位满,尝试返回等待计划(选择负载最低的)
|
||||
// 遍历找到第一个满足会话限制的账号
|
||||
for _, item := range routingAvailable {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: item.account.ID,
|
||||
MaxConcurrency: item.account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
@@ -1137,7 +1181,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
@@ -1151,15 +1195,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: accountID,
|
||||
MaxConcurrency: account.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1208,7 +1257,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
|
||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||
if err != nil {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
|
||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
||||
return result, nil
|
||||
}
|
||||
} else {
|
||||
@@ -1258,7 +1307,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, item.account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -1276,8 +1325,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
|
||||
// ============ Layer 3: 兜底排队 ============
|
||||
sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
|
||||
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode)
|
||||
for _, acc := range candidates {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
continue // 会话限制已满,尝试下一个账号
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
@@ -1291,7 +1344,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, errors.New("no available accounts")
|
||||
}
|
||||
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
|
||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
||||
ordered := append([]*Account(nil), candidates...)
|
||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||
|
||||
@@ -1299,7 +1352,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
|
||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
|
||||
continue
|
||||
}
|
||||
@@ -1456,7 +1509,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr
|
||||
|
||||
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err == nil {
|
||||
slog.Debug("account_scheduling_list_snapshot",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"use_mixed", useMixed,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
}
|
||||
return accounts, useMixed, err
|
||||
}
|
||||
useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform
|
||||
if useMixed {
|
||||
@@ -1469,6 +1539,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
filtered := make([]Account, 0, len(accounts))
|
||||
@@ -1478,6 +1552,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
}
|
||||
filtered = append(filtered, acc)
|
||||
}
|
||||
slog.Debug("account_scheduling_list_mixed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"raw_count", len(accounts),
|
||||
"filtered_count", len(filtered))
|
||||
for _, acc := range filtered {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return filtered, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1492,8 +1580,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i
|
||||
accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
|
||||
}
|
||||
if err != nil {
|
||||
slog.Debug("account_scheduling_list_failed",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"error", err)
|
||||
return nil, useMixed, err
|
||||
}
|
||||
slog.Debug("account_scheduling_list_single",
|
||||
"group_id", derefGroupID(groupID),
|
||||
"platform", platform,
|
||||
"count", len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slog.Debug("account_scheduling_account_detail",
|
||||
"account_id", acc.ID,
|
||||
"name", acc.Name,
|
||||
"platform", acc.Platform,
|
||||
"type", acc.Type,
|
||||
"status", acc.Status,
|
||||
"tls_fingerprint", acc.IsTLSFingerprintEnabled())
|
||||
}
|
||||
return accounts, useMixed, nil
|
||||
}
|
||||
|
||||
@@ -1559,12 +1664,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context,
|
||||
|
||||
// 缓存未命中,从数据库查询
|
||||
{
|
||||
var startTime time.Time
|
||||
if account.SessionWindowStart != nil {
|
||||
startTime = *account.SessionWindowStart
|
||||
} else {
|
||||
startTime = time.Now().Add(-5 * time.Hour)
|
||||
}
|
||||
// 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况)
|
||||
startTime := account.GetCurrentWindowStartTime()
|
||||
|
||||
stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
|
||||
if err != nil {
|
||||
@@ -1597,15 +1698,16 @@ checkSchedulability:
|
||||
|
||||
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
|
||||
// 仅适用于 Anthropic OAuth/SetupToken 账号
|
||||
// sessionID: 会话标识符(使用粘性会话的 hash)
|
||||
// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
|
||||
func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool {
|
||||
// 只检查 Anthropic OAuth/SetupToken 账号
|
||||
if !account.IsAnthropicOAuthOrSetupToken() {
|
||||
return true
|
||||
}
|
||||
|
||||
maxSessions := account.GetMaxSessions()
|
||||
if maxSessions <= 0 || sessionUUID == "" {
|
||||
if maxSessions <= 0 || sessionID == "" {
|
||||
return true // 未启用会话限制或无会话ID
|
||||
}
|
||||
|
||||
@@ -1615,7 +1717,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
|
||||
idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
|
||||
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
|
||||
allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout)
|
||||
if err != nil {
|
||||
// 失败开放:缓存错误时允许通过
|
||||
return true
|
||||
@@ -1623,18 +1725,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A
|
||||
return allowed
|
||||
}
|
||||
|
||||
// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
|
||||
// 格式: user_{64位hex}_account__session_{uuid}
|
||||
func extractSessionUUID(metadataUserID string) string {
|
||||
if metadataUserID == "" {
|
||||
return ""
|
||||
}
|
||||
if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
|
||||
return match[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
|
||||
if s.schedulerSnapshot != nil {
|
||||
return s.schedulerSnapshot.GetAccount(ctx, accountID)
|
||||
@@ -1664,6 +1754,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
})
|
||||
}
|
||||
|
||||
// sortCandidatesForFallback 根据配置选择排序策略
|
||||
// mode: "last_used"(按最后使用时间) 或 "random"(随机)
|
||||
func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) {
|
||||
if mode == "random" {
|
||||
// 先按优先级排序,然后在同优先级内随机打乱
|
||||
sortAccountsByPriorityOnly(accounts, preferOAuth)
|
||||
shuffleWithinPriority(accounts)
|
||||
} else {
|
||||
// 默认按最后使用时间排序
|
||||
sortAccountsByPriorityAndLastUsed(accounts, preferOAuth)
|
||||
}
|
||||
}
|
||||
|
||||
// sortAccountsByPriorityOnly 仅按优先级排序
|
||||
func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) {
|
||||
sort.SliceStable(accounts, func(i, j int) bool {
|
||||
a, b := accounts[i], accounts[j]
|
||||
if a.Priority != b.Priority {
|
||||
return a.Priority < b.Priority
|
||||
}
|
||||
if preferOAuth && a.Type != b.Type {
|
||||
return a.Type == AccountTypeOAuth
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
// shuffleWithinPriority 在同优先级内随机打乱顺序
|
||||
func shuffleWithinPriority(accounts []*Account) {
|
||||
if len(accounts) <= 1 {
|
||||
return
|
||||
}
|
||||
r := mathrand.New(mathrand.NewSource(time.Now().UnixNano()))
|
||||
start := 0
|
||||
for start < len(accounts) {
|
||||
priority := accounts[start].Priority
|
||||
end := start + 1
|
||||
for end < len(accounts) && accounts[end].Priority == priority {
|
||||
end++
|
||||
}
|
||||
// 对 [start, end) 范围内的账户随机打乱
|
||||
if end-start > 1 {
|
||||
r.Shuffle(end-start, func(i, j int) {
|
||||
accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i]
|
||||
})
|
||||
}
|
||||
start = end
|
||||
}
|
||||
}
|
||||
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
@@ -2524,6 +2664,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
// 调试日志:记录即将转发的账号信息
|
||||
log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s",
|
||||
account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL)
|
||||
|
||||
// 重试循环
|
||||
var resp *http.Response
|
||||
retryStart := time.Now()
|
||||
@@ -2537,7 +2681,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
if resp != nil && resp.Body != nil {
|
||||
_ = resp.Body.Close()
|
||||
@@ -2611,7 +2755,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID)
|
||||
@@ -2643,7 +2787,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
|
||||
retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
|
||||
if buildErr2 == nil {
|
||||
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr2 == nil {
|
||||
resp = retryResp2
|
||||
break
|
||||
@@ -2758,6 +2902,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印重试耗尽后的错误响应
|
||||
log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleRetryExhaustedSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2785,6 +2933,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
_ = resp.Body.Close()
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000))
|
||||
|
||||
s.handleFailoverSideEffects(ctx, resp, account)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Platform: account.Platform,
|
||||
@@ -2914,9 +3066,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
fingerprint = fp
|
||||
|
||||
// 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid)
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
@@ -3183,6 +3336,10 @@ func extractUpstreamErrorMessage(body []byte) string {
|
||||
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 调试日志:打印上游错误响应
|
||||
log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s",
|
||||
account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000))
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
|
||||
@@ -4171,7 +4328,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if err != nil {
|
||||
setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
@@ -4193,7 +4350,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
@@ -4271,12 +4428,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
}
|
||||
|
||||
// OAuth 账号:应用统一指纹和重写 userID
|
||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||
if account.IsOAuth() && s.identityService != nil {
|
||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||
if err == nil {
|
||||
accountUUID := account.GetExtraString("account_uuid")
|
||||
if accountUUID != "" && fp.ClientID != "" {
|
||||
if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 {
|
||||
body = newBody
|
||||
}
|
||||
}
|
||||
|
||||
@@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda
|
||||
func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
return nil
|
||||
}
|
||||
@@ -599,7 +602,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
|
||||
name: "Gemini平台-有映射配置-只支持配置的模型",
|
||||
account: &Account{
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
|
||||
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
|
||||
},
|
||||
model: "gemini-2.5-flash",
|
||||
expected: false,
|
||||
|
||||
@@ -10,6 +10,7 @@ import "net/http"
|
||||
// - 支持可选代理配置
|
||||
// - 支持账户级连接池隔离
|
||||
// - 实现类负责连接池管理和复用
|
||||
// - 支持可选的 TLS 指纹伪装
|
||||
type HTTPUpstream interface {
|
||||
// Do 执行 HTTP 请求
|
||||
//
|
||||
@@ -27,4 +28,28 @@ type HTTPUpstream interface {
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - 响应体可能已被包装以跟踪请求生命周期
|
||||
Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error)
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象,由调用方构建
|
||||
// - proxyURL: 代理服务器地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应,调用方必须关闭 Body
|
||||
// - error: 请求错误(网络错误、超时等)
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - TLS 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
// - 如果 enableTLSFingerprint=false,行为与 Do 方法相同
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致连接泄漏
|
||||
// - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响
|
||||
DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error)
|
||||
}
|
||||
|
||||
@@ -8,9 +8,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -49,6 +51,13 @@ type Fingerprint struct {
|
||||
type IdentityCache interface {
|
||||
GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
|
||||
SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error
|
||||
// GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能)
|
||||
// 返回的 sessionID 是一个 UUID 格式的字符串
|
||||
// 如果不存在或已过期(15分钟无请求),返回空字符串
|
||||
GetMaskedSessionID(ctx context.Context, accountID int64) (string, error)
|
||||
// SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟
|
||||
// 每次调用都会刷新 TTL
|
||||
SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error
|
||||
}
|
||||
|
||||
// IdentityService 管理OAuth账号的请求身份指纹
|
||||
@@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装
|
||||
// 如果账号启用了会话ID伪装(session_id_masking_enabled),
|
||||
// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变)
|
||||
func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) {
|
||||
// 先执行常规的 RewriteUserID 逻辑
|
||||
newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID)
|
||||
if err != nil {
|
||||
return newBody, err
|
||||
}
|
||||
|
||||
// 检查是否启用会话ID伪装
|
||||
if !account.IsSessionIDMaskingEnabled() {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 解析重写后的 body,提取 user_id
|
||||
var reqMap map[string]any
|
||||
if err := json.Unmarshal(newBody, &reqMap); err != nil {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
metadata, ok := reqMap["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 查找 _session_ 的位置,替换其后的内容
|
||||
const sessionMarker = "_session_"
|
||||
idx := strings.LastIndex(userID, sessionMarker)
|
||||
if idx == -1 {
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
// 获取或生成固定的伪装 session ID
|
||||
maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID)
|
||||
if err != nil {
|
||||
log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err)
|
||||
return newBody, nil
|
||||
}
|
||||
|
||||
if maskedSessionID == "" {
|
||||
// 首次或已过期,生成新的伪装 session ID
|
||||
maskedSessionID = generateRandomUUID()
|
||||
log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID)
|
||||
}
|
||||
|
||||
// 刷新 TTL(每次请求都刷新,保持 15 分钟有效期)
|
||||
if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil {
|
||||
log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err)
|
||||
}
|
||||
|
||||
// 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容
|
||||
newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID
|
||||
|
||||
slog.Debug("session_id_masking_applied",
|
||||
"account_id", account.ID,
|
||||
"before", userID,
|
||||
"after", newUserID,
|
||||
)
|
||||
|
||||
metadata["user_id"] = newUserID
|
||||
reqMap["metadata"] = metadata
|
||||
|
||||
return json.Marshal(reqMap)
|
||||
}
|
||||
|
||||
// generateRandomUUID 生成随机 UUID v4 格式字符串
|
||||
func generateRandomUUID() string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// fallback: 使用时间戳生成
|
||||
h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano())))
|
||||
b = h[:16]
|
||||
}
|
||||
|
||||
// 设置 UUID v4 版本和变体位
|
||||
b[6] = (b[6] & 0x0f) | 0x40
|
||||
b[8] = (b[8] & 0x3f) | 0x80
|
||||
|
||||
return fmt.Sprintf("%x-%x-%x-%x-%x",
|
||||
b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
|
||||
}
|
||||
|
||||
// generateClientID 生成64位十六进制客户端ID(32字节随机数)
|
||||
func generateClientID() string {
|
||||
b := make([]byte, 32)
|
||||
|
||||
@@ -48,8 +48,7 @@ type GenerateAuthURLResult struct {
|
||||
|
||||
// GenerateAuthURL generates an OAuth authorization URL with full scope
|
||||
func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) {
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
return s.generateAuthURLWithScope(ctx, scope, proxyID)
|
||||
return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID)
|
||||
}
|
||||
|
||||
// GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only)
|
||||
@@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) (
|
||||
}
|
||||
|
||||
// Determine scope and if this is a setup token
|
||||
scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference)
|
||||
// Internal API call uses ScopeAPI (org:create_api_key not supported)
|
||||
scope := oauth.ScopeAPI
|
||||
isSetupToken := false
|
||||
if input.Scope == "inference" {
|
||||
scope = oauth.ScopeInference
|
||||
|
||||
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
}
|
||||
|
||||
modified := false
|
||||
for idx, tool := range tools {
|
||||
validTools := make([]any, 0, len(tools))
|
||||
|
||||
for _, tool := range tools {
|
||||
toolMap, ok := tool.(map[string]any)
|
||||
if !ok {
|
||||
// Keep unknown structure as-is to avoid breaking upstream behavior.
|
||||
validTools = append(validTools, tool)
|
||||
continue
|
||||
}
|
||||
|
||||
toolType, _ := toolMap["type"].(string)
|
||||
if strings.TrimSpace(toolType) != "function" {
|
||||
toolType = strings.TrimSpace(toolType)
|
||||
if toolType != "function" {
|
||||
validTools = append(validTools, toolMap)
|
||||
continue
|
||||
}
|
||||
|
||||
function, ok := toolMap["function"].(map[string]any)
|
||||
if !ok {
|
||||
// OpenAI Responses-style tools use top-level name/parameters.
|
||||
if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" {
|
||||
validTools = append(validTools, toolMap)
|
||||
continue
|
||||
}
|
||||
|
||||
// ChatCompletions-style tools use {type:"function", function:{...}}.
|
||||
functionValue, hasFunction := toolMap["function"]
|
||||
function, ok := functionValue.(map[string]any)
|
||||
if !hasFunction || functionValue == nil || !ok || function == nil {
|
||||
// Drop invalid function tools.
|
||||
modified = true
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
tools[idx] = toolMap
|
||||
validTools = append(validTools, toolMap)
|
||||
}
|
||||
|
||||
if modified {
|
||||
reqBody["tools"] = tools
|
||||
reqBody["tools"] = validTools
|
||||
}
|
||||
|
||||
return modified
|
||||
|
||||
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
|
||||
require.False(t, hasID)
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
|
||||
setupCodexCache(t)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"model": "gpt-5.1",
|
||||
"tools": []any{
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"name": "bash",
|
||||
"description": "desc",
|
||||
"parameters": map[string]any{"type": "object"},
|
||||
},
|
||||
map[string]any{
|
||||
"type": "function",
|
||||
"function": nil,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
applyCodexOAuthTransform(reqBody)
|
||||
|
||||
tools, ok := reqBody["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tools, 1)
|
||||
|
||||
first, ok := tools[0].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "function", first["type"])
|
||||
require.Equal(t, "bash", first["name"])
|
||||
}
|
||||
|
||||
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
|
||||
// 空 input 应保持为空且不触发异常。
|
||||
setupCodexCache(t)
|
||||
|
||||
@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
|
||||
sessionID := c.GetHeader("session_id")
|
||||
// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
|
||||
//
|
||||
// Priority:
|
||||
// 1. Header: session_id
|
||||
// 2. Header: conversation_id
|
||||
// 3. Body: prompt_cache_key (opencode)
|
||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
sessionID := strings.TrimSpace(c.GetHeader("session_id"))
|
||||
if sessionID == "" {
|
||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||
}
|
||||
if sessionID == "" && reqBody != nil {
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
sessionID = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if sessionID == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
hash := sha256.Sum256([]byte(sessionID))
|
||||
return hex.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
@@ -68,6 +68,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{}
|
||||
|
||||
// 1) session_id header wins
|
||||
c.Request.Header.Set("session_id", "sess-123")
|
||||
c.Request.Header.Set("conversation_id", "conv-456")
|
||||
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h1 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
|
||||
// 2) conversation_id used when session_id absent
|
||||
c.Request.Header.Del("session_id")
|
||||
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h2 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
if h1 == h2 {
|
||||
t.Fatalf("expected different hashes for different keys")
|
||||
}
|
||||
|
||||
// 3) prompt_cache_key used when both headers absent
|
||||
c.Request.Header.Del("conversation_id")
|
||||
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
||||
if h3 == "" {
|
||||
t.Fatalf("expected non-empty hash")
|
||||
}
|
||||
if h2 == h3 {
|
||||
t.Fatalf("expected different hashes for different keys")
|
||||
}
|
||||
|
||||
// 4) empty when no signals
|
||||
h4 := svc.GenerateSessionHash(c, map[string]any{})
|
||||
if h4 != "" {
|
||||
t.Fatalf("expected empty hash when no signals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
|
||||
now := time.Now()
|
||||
resetAt := now.Add(10 * time.Minute)
|
||||
|
||||
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
|
||||
"executeBash": "bash",
|
||||
"exec_bash": "bash",
|
||||
"execBash": "bash",
|
||||
|
||||
// Some clients output generic fetch names.
|
||||
"fetch": "webfetch",
|
||||
"web_fetch": "webfetch",
|
||||
"webFetch": "webfetch",
|
||||
}
|
||||
|
||||
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
|
||||
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
|
||||
// 根据工具名称应用特定的参数修正规则
|
||||
switch toolName {
|
||||
case "bash":
|
||||
// 移除 workdir 参数(OpenCode 不支持)
|
||||
if _, exists := argsMap["workdir"]; exists {
|
||||
delete(argsMap, "workdir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
|
||||
}
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
|
||||
// OpenCode bash 支持 workdir;有些来源会输出 work_dir。
|
||||
if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
|
||||
if workDir, exists := argsMap["work_dir"]; exists {
|
||||
argsMap["workdir"] = workDir
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
|
||||
}
|
||||
} else {
|
||||
if _, exists := argsMap["work_dir"]; exists {
|
||||
delete(argsMap, "work_dir")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
|
||||
}
|
||||
}
|
||||
|
||||
case "edit":
|
||||
// OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
|
||||
// 这里可以添加参数名称的映射逻辑
|
||||
if _, exists := argsMap["file_path"]; !exists {
|
||||
if path, exists := argsMap["path"]; exists {
|
||||
argsMap["file_path"] = path
|
||||
// OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
|
||||
if _, exists := argsMap["filePath"]; !exists {
|
||||
if filePath, exists := argsMap["file_path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file_path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["path"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "path")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
|
||||
log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
|
||||
} else if filePath, exists := argsMap["file"]; exists {
|
||||
argsMap["filePath"] = filePath
|
||||
delete(argsMap, "file")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["oldString"]; !exists {
|
||||
if oldString, exists := argsMap["old_string"]; exists {
|
||||
argsMap["oldString"] = oldString
|
||||
delete(argsMap, "old_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["newString"]; !exists {
|
||||
if newString, exists := argsMap["new_string"]; exists {
|
||||
argsMap["newString"] = newString
|
||||
delete(argsMap, "new_string")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
|
||||
}
|
||||
}
|
||||
|
||||
if _, exists := argsMap["replaceAll"]; !exists {
|
||||
if replaceAll, exists := argsMap["replace_all"]; exists {
|
||||
argsMap["replaceAll"] = replaceAll
|
||||
delete(argsMap, "replace_all")
|
||||
corrected = true
|
||||
log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
|
||||
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
|
||||
}{
|
||||
{
|
||||
name: "remove workdir from bash tool",
|
||||
name: "rename work_dir to workdir in bash tool",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
"name": "bash",
|
||||
"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
|
||||
"arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}"
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"command": true,
|
||||
"workdir": false,
|
||||
"command": true,
|
||||
"workdir": true,
|
||||
"work_dir": false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "rename path to file_path in edit tool",
|
||||
name: "rename snake_case edit params to camelCase",
|
||||
input: `{
|
||||
"tool_calls": [{
|
||||
"function": {
|
||||
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
|
||||
}]
|
||||
}`,
|
||||
expected: map[string]bool{
|
||||
"file_path": true,
|
||||
"filePath": true,
|
||||
"path": false,
|
||||
"old_string": true,
|
||||
"new_string": true,
|
||||
"oldString": true,
|
||||
"old_string": false,
|
||||
"newString": true,
|
||||
"new_string": false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
|
||||
func normalizeModelNameForPricing(model string) string {
|
||||
// Common Gemini/VertexAI forms:
|
||||
// - models/gemini-2.0-flash-exp
|
||||
// - publishers/google/models/gemini-1.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
|
||||
// - publishers/google/models/gemini-2.5-pro
|
||||
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
|
||||
model = strings.TrimSpace(model)
|
||||
model = strings.TrimLeft(model, "/")
|
||||
model = strings.TrimPrefix(model, "models/")
|
||||
|
||||
@@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := false
|
||||
// 先尝试临时不可调度规则(401除外)
|
||||
// 如果匹配成功,直接返回,不执行后续禁用逻辑
|
||||
if statusCode != 401 {
|
||||
tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
|
||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||
if upstreamMsg != "" {
|
||||
@@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
|
||||
switch statusCode {
|
||||
case 400:
|
||||
// 只有当错误信息包含 "organization has been disabled" 时才禁用
|
||||
if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") {
|
||||
msg := "Organization disabled (400): " + upstreamMsg
|
||||
s.handleAuthError(ctx, account, msg)
|
||||
shouldDisable = true
|
||||
}
|
||||
// 其他 400 错误(如参数问题)不处理,不禁用账号
|
||||
case 401:
|
||||
// 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
|
||||
if account.Type == AccountTypeOAuth {
|
||||
@@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
}
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return true
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
@@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
@@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
||||
@@ -38,8 +38,9 @@ type SessionLimitCache interface {
|
||||
GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
|
||||
// idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时
|
||||
// 返回 map[accountID]count,查询失败的账号不在 map 中
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error)
|
||||
|
||||
// IsSessionActive 检查特定会话是否活跃(未过期)
|
||||
IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
|
||||
|
||||
@@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocURL,
|
||||
SettingKeyHomeContent,
|
||||
SettingKeyHideCcsImportButton,
|
||||
SettingKeyLinuxDoConnectEnabled,
|
||||
}
|
||||
|
||||
@@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
LinuxDoOAuthEnabled: linuxDoEnabled,
|
||||
}, nil
|
||||
}
|
||||
@@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ContactInfo string `json:"contact_info,omitempty"`
|
||||
DocURL string `json:"doc_url,omitempty"`
|
||||
HomeContent string `json:"home_content,omitempty"`
|
||||
HideCcsImportButton bool `json:"hide_ccs_import_button"`
|
||||
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
|
||||
Version string `json:"version,omitempty"`
|
||||
}{
|
||||
@@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
|
||||
ContactInfo: settings.ContactInfo,
|
||||
DocURL: settings.DocURL,
|
||||
HomeContent: settings.HomeContent,
|
||||
HideCcsImportButton: settings.HideCcsImportButton,
|
||||
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
|
||||
Version: s.version,
|
||||
}, nil
|
||||
@@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
updates[SettingKeyHomeContent] = settings.HomeContent
|
||||
updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton)
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
@@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
HomeContent: settings[SettingKeyHomeContent],
|
||||
HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true",
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
|
||||
@@ -25,13 +25,14 @@ type SystemSettings struct {
|
||||
LinuxDoConnectClientSecretConfigured bool
|
||||
LinuxDoConnectRedirectURL string
|
||||
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
@@ -66,6 +67,7 @@ type PublicSettings struct {
|
||||
ContactInfo string
|
||||
DocURL string
|
||||
HomeContent string
|
||||
HideCcsImportButton bool
|
||||
LinuxDoOAuthEnabled bool
|
||||
Version string
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ var (
|
||||
ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded")
|
||||
ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded")
|
||||
ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil")
|
||||
ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)")
|
||||
)
|
||||
|
||||
// SubscriptionService 订阅服务
|
||||
@@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExtendSubscription 延长订阅
|
||||
// ExtendSubscription 调整订阅时长(正数延长,负数缩短)
|
||||
func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) {
|
||||
sub, err := s.userSubRepo.GetByID(ctx, subscriptionID)
|
||||
if err != nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
|
||||
// 限制延长天数
|
||||
// 限制调整天数范围
|
||||
if days > MaxValidityDays {
|
||||
days = MaxValidityDays
|
||||
}
|
||||
if days < -MaxValidityDays {
|
||||
days = -MaxValidityDays
|
||||
}
|
||||
|
||||
// 计算新的过期时间
|
||||
newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days)
|
||||
@@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti
|
||||
newExpiresAt = MaxExpiresAt
|
||||
}
|
||||
|
||||
// 如果是缩短(负数),检查新的过期时间必须大于当前时间
|
||||
if days < 0 {
|
||||
now := time.Now()
|
||||
if !newExpiresAt.After(now) {
|
||||
return nil, ErrAdjustWouldExpire
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
|
||||
|
||||
for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ {
|
||||
newCredentials, err := refresher.Refresh(ctx, account)
|
||||
if err == nil {
|
||||
// 刷新成功,更新账号credentials
|
||||
|
||||
// 如果有新凭证,先更新(即使有错误也要保存 token)
|
||||
if newCredentials != nil {
|
||||
account.Credentials = newCredentials
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", err)
|
||||
if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil {
|
||||
return fmt.Errorf("failed to save credentials: %w", saveErr)
|
||||
}
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
// Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态
|
||||
if account.Platform == PlatformAntigravity &&
|
||||
account.Status == StatusError &&
|
||||
strings.Contains(account.ErrorMessage, "missing_project_id:") {
|
||||
if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil {
|
||||
log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr)
|
||||
} else {
|
||||
log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID)
|
||||
}
|
||||
}
|
||||
// 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理)
|
||||
if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
|
||||
@@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool {
|
||||
"invalid_client", // 客户端配置错误
|
||||
"unauthorized_client", // 客户端未授权
|
||||
"access_denied", // 访问被拒绝
|
||||
"missing_project_id", // 缺少 project_id
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
|
||||
74
backend/internal/service/usage_cleanup.go
Normal file
74
backend/internal/service/usage_cleanup.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
UsageCleanupStatusPending = "pending"
|
||||
UsageCleanupStatusRunning = "running"
|
||||
UsageCleanupStatusSucceeded = "succeeded"
|
||||
UsageCleanupStatusFailed = "failed"
|
||||
UsageCleanupStatusCanceled = "canceled"
|
||||
)
|
||||
|
||||
// UsageCleanupFilters 定义清理任务过滤条件
|
||||
// 时间范围为必填,其他字段可选
|
||||
// JSON 序列化用于存储任务参数
|
||||
//
|
||||
// start_time/end_time 使用 RFC3339 时间格式
|
||||
// 以 UTC 或用户时区解析后的时间为准
|
||||
//
|
||||
// 说明:
|
||||
// - nil 表示未设置该过滤条件
|
||||
// - 过滤条件均为精确匹配
|
||||
type UsageCleanupFilters struct {
|
||||
StartTime time.Time `json:"start_time"`
|
||||
EndTime time.Time `json:"end_time"`
|
||||
UserID *int64 `json:"user_id,omitempty"`
|
||||
APIKeyID *int64 `json:"api_key_id,omitempty"`
|
||||
AccountID *int64 `json:"account_id,omitempty"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Model *string `json:"model,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
BillingType *int8 `json:"billing_type,omitempty"`
|
||||
}
|
||||
|
||||
// UsageCleanupTask 表示使用记录清理任务
|
||||
// 状态包含 pending/running/succeeded/failed/canceled
|
||||
type UsageCleanupTask struct {
|
||||
ID int64
|
||||
Status string
|
||||
Filters UsageCleanupFilters
|
||||
CreatedBy int64
|
||||
DeletedRows int64
|
||||
ErrorMsg *string
|
||||
CanceledBy *int64
|
||||
CanceledAt *time.Time
|
||||
StartedAt *time.Time
|
||||
FinishedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// UsageCleanupRepository 定义清理任务持久层接口
|
||||
type UsageCleanupRepository interface {
|
||||
CreateTask(ctx context.Context, task *UsageCleanupTask) error
|
||||
ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
|
||||
// ClaimNextPendingTask 抢占下一条可执行任务:
|
||||
// - 优先 pending
|
||||
// - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行
|
||||
ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
|
||||
// GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows
|
||||
GetTaskStatus(ctx context.Context, taskID int64) (string, error)
|
||||
// UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示
|
||||
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
|
||||
// CancelTask 将任务标记为 canceled(仅允许 pending/running)
|
||||
CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
|
||||
MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
|
||||
MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
|
||||
DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
|
||||
}
|
||||
404
backend/internal/service/usage_cleanup_service.go
Normal file
404
backend/internal/service/usage_cleanup_service.go
Normal file
@@ -0,0 +1,404 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
usageCleanupWorkerName = "usage_cleanup_worker"
|
||||
)
|
||||
|
||||
// UsageCleanupService 负责创建与执行使用记录清理任务
|
||||
type UsageCleanupService struct {
|
||||
repo UsageCleanupRepository
|
||||
timingWheel *TimingWheelService
|
||||
dashboard *DashboardAggregationService
|
||||
cfg *config.Config
|
||||
|
||||
running int32
|
||||
startOnce sync.Once
|
||||
stopOnce sync.Once
|
||||
|
||||
workerCtx context.Context
|
||||
workerCancel context.CancelFunc
|
||||
}
|
||||
|
||||
func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
|
||||
workerCtx, workerCancel := context.WithCancel(context.Background())
|
||||
return &UsageCleanupService{
|
||||
repo: repo,
|
||||
timingWheel: timingWheel,
|
||||
dashboard: dashboard,
|
||||
cfg: cfg,
|
||||
workerCtx: workerCtx,
|
||||
workerCancel: workerCancel,
|
||||
}
|
||||
}
|
||||
|
||||
func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
|
||||
var parts []string
|
||||
parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
|
||||
parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
|
||||
if filters.UserID != nil {
|
||||
parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
|
||||
}
|
||||
if filters.APIKeyID != nil {
|
||||
parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
|
||||
}
|
||||
if filters.AccountID != nil {
|
||||
parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
|
||||
}
|
||||
if filters.GroupID != nil {
|
||||
parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
|
||||
}
|
||||
if filters.Model != nil {
|
||||
parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
|
||||
}
|
||||
if filters.Stream != nil {
|
||||
parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
|
||||
}
|
||||
if filters.BillingType != nil {
|
||||
parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) Start() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
log.Printf("[UsageCleanup] not started (disabled)")
|
||||
return
|
||||
}
|
||||
if s.repo == nil || s.timingWheel == nil {
|
||||
log.Printf("[UsageCleanup] not started (missing deps)")
|
||||
return
|
||||
}
|
||||
|
||||
interval := s.workerInterval()
|
||||
s.startOnce.Do(func() {
|
||||
s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
|
||||
log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
if s.workerCancel != nil {
|
||||
s.workerCancel()
|
||||
}
|
||||
if s.timingWheel != nil {
|
||||
s.timingWheel.Cancel(usageCleanupWorkerName)
|
||||
}
|
||||
log.Printf("[UsageCleanup] stopped")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, nil, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
return s.repo.ListTasks(ctx, params)
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return nil, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
|
||||
}
|
||||
if createdBy <= 0 {
|
||||
return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
|
||||
sanitizeUsageCleanupFilters(&filters)
|
||||
if err := s.validateFilters(filters); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task := &UsageCleanupTask{
|
||||
Status: UsageCleanupStatusPending,
|
||||
Filters: filters,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
if err := s.repo.CreateTask(ctx, task); err != nil {
|
||||
log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
|
||||
return nil, fmt.Errorf("create cleanup task: %w", err)
|
||||
}
|
||||
log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
|
||||
go s.runOnce()
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) runOnce() {
|
||||
svc := s
|
||||
if svc == nil {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) {
|
||||
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&svc.running, 0)
|
||||
|
||||
parent := context.Background()
|
||||
if svc.workerCtx != nil {
|
||||
parent = svc.workerCtx
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(parent, svc.taskTimeout())
|
||||
defer cancel()
|
||||
|
||||
task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds()))
|
||||
if err != nil {
|
||||
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
log.Printf("[UsageCleanup] run_once done: no_task=true")
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
|
||||
svc.executeTask(ctx, task)
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
|
||||
if task == nil {
|
||||
return
|
||||
}
|
||||
|
||||
batchSize := s.batchSize()
|
||||
deletedTotal := task.DeletedRows
|
||||
start := time.Now()
|
||||
log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
|
||||
var batchNum int
|
||||
|
||||
for {
|
||||
if ctx != nil && ctx.Err() != nil {
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
|
||||
return
|
||||
}
|
||||
canceled, err := s.isTaskCanceled(ctx, task.ID)
|
||||
if err != nil {
|
||||
s.markTaskFailed(task.ID, deletedTotal, err)
|
||||
return
|
||||
}
|
||||
if canceled {
|
||||
log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
return
|
||||
}
|
||||
|
||||
batchNum++
|
||||
deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
// 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。
|
||||
log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
|
||||
return
|
||||
}
|
||||
s.markTaskFailed(task.ID, deletedTotal, err)
|
||||
return
|
||||
}
|
||||
deletedTotal += deleted
|
||||
if deleted > 0 {
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
|
||||
log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
|
||||
}
|
||||
if deleted == 0 || deleted < int64(batchSize) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
|
||||
log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
|
||||
}
|
||||
|
||||
if s.dashboard != nil {
|
||||
if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
|
||||
} else {
|
||||
log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
|
||||
msg := strings.TrimSpace(err.Error())
|
||||
if len(msg) > 500 {
|
||||
msg = msg[:500]
|
||||
}
|
||||
log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
|
||||
log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
|
||||
if s == nil || s.repo == nil {
|
||||
return false, fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
status, err := s.repo.GetTaskStatus(checkCtx, taskID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if status == UsageCleanupStatusCanceled {
|
||||
log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
|
||||
}
|
||||
return status == UsageCleanupStatusCanceled, nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
|
||||
if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
|
||||
}
|
||||
if filters.EndTime.Before(filters.StartTime) {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
|
||||
}
|
||||
maxDays := s.maxRangeDays()
|
||||
if maxDays > 0 {
|
||||
delta := filters.EndTime.Sub(filters.StartTime)
|
||||
if delta > time.Duration(maxDays)*24*time.Hour {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return fmt.Errorf("cleanup service not ready")
|
||||
}
|
||||
if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
|
||||
return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
|
||||
}
|
||||
if canceledBy <= 0 {
|
||||
return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
|
||||
}
|
||||
status, err := s.repo.GetTaskStatus(ctx, taskID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
// 状态可能并发改变
|
||||
return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
|
||||
}
|
||||
log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
|
||||
if filters == nil {
|
||||
return
|
||||
}
|
||||
if filters.UserID != nil && *filters.UserID <= 0 {
|
||||
filters.UserID = nil
|
||||
}
|
||||
if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
|
||||
filters.APIKeyID = nil
|
||||
}
|
||||
if filters.AccountID != nil && *filters.AccountID <= 0 {
|
||||
filters.AccountID = nil
|
||||
}
|
||||
if filters.GroupID != nil && *filters.GroupID <= 0 {
|
||||
filters.GroupID = nil
|
||||
}
|
||||
if filters.Model != nil {
|
||||
model := strings.TrimSpace(*filters.Model)
|
||||
if model == "" {
|
||||
filters.Model = nil
|
||||
} else {
|
||||
filters.Model = &model
|
||||
}
|
||||
}
|
||||
if filters.BillingType != nil && *filters.BillingType < 0 {
|
||||
filters.BillingType = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) maxRangeDays() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 31
|
||||
}
|
||||
if s.cfg.UsageCleanup.MaxRangeDays > 0 {
|
||||
return s.cfg.UsageCleanup.MaxRangeDays
|
||||
}
|
||||
return 31
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) batchSize() int {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 5000
|
||||
}
|
||||
if s.cfg.UsageCleanup.BatchSize > 0 {
|
||||
return s.cfg.UsageCleanup.BatchSize
|
||||
}
|
||||
return 5000
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) workerInterval() time.Duration {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 10 * time.Second
|
||||
}
|
||||
if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
|
||||
return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
|
||||
}
|
||||
return 10 * time.Second
|
||||
}
|
||||
|
||||
func (s *UsageCleanupService) taskTimeout() time.Duration {
|
||||
if s == nil || s.cfg == nil {
|
||||
return 30 * time.Minute
|
||||
}
|
||||
if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
|
||||
return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
|
||||
}
|
||||
return 30 * time.Minute
|
||||
}
|
||||
815
backend/internal/service/usage_cleanup_service_test.go
Normal file
815
backend/internal/service/usage_cleanup_service_test.go
Normal file
@@ -0,0 +1,815 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type cleanupDeleteResponse struct {
|
||||
deleted int64
|
||||
err error
|
||||
}
|
||||
|
||||
type cleanupDeleteCall struct {
|
||||
filters UsageCleanupFilters
|
||||
limit int
|
||||
}
|
||||
|
||||
type cleanupMarkCall struct {
|
||||
taskID int64
|
||||
deletedRows int64
|
||||
errMsg string
|
||||
}
|
||||
|
||||
type cleanupRepoStub struct {
|
||||
mu sync.Mutex
|
||||
created []*UsageCleanupTask
|
||||
createErr error
|
||||
listTasks []UsageCleanupTask
|
||||
listResult *pagination.PaginationResult
|
||||
listErr error
|
||||
claimQueue []*UsageCleanupTask
|
||||
claimErr error
|
||||
deleteQueue []cleanupDeleteResponse
|
||||
deleteCalls []cleanupDeleteCall
|
||||
markSucceeded []cleanupMarkCall
|
||||
markFailed []cleanupMarkCall
|
||||
statusByID map[int64]string
|
||||
statusErr error
|
||||
progressCalls []cleanupMarkCall
|
||||
updateErr error
|
||||
cancelCalls []int64
|
||||
cancelErr error
|
||||
cancelResult *bool
|
||||
markFailedErr error
|
||||
}
|
||||
|
||||
type dashboardRepoStub struct {
|
||||
recomputeErr error
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return s.recomputeErr
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
|
||||
if task == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.createErr != nil {
|
||||
return s.createErr
|
||||
}
|
||||
if task.ID == 0 {
|
||||
task.ID = int64(len(s.created) + 1)
|
||||
}
|
||||
if task.CreatedAt.IsZero() {
|
||||
task.CreatedAt = time.Now().UTC()
|
||||
}
|
||||
if task.UpdatedAt.IsZero() {
|
||||
task.UpdatedAt = task.CreatedAt
|
||||
}
|
||||
clone := *task
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.listTasks, s.listResult, s.listErr
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.claimErr != nil {
|
||||
return nil, s.claimErr
|
||||
}
|
||||
if len(s.claimQueue) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
task := s.claimQueue[0]
|
||||
s.claimQueue = s.claimQueue[1:]
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[task.ID] = UsageCleanupStatusRunning
|
||||
return task, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.statusErr != nil {
|
||||
return "", s.statusErr
|
||||
}
|
||||
if s.statusByID == nil {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
status, ok := s.statusByID[taskID]
|
||||
if !ok {
|
||||
return "", sql.ErrNoRows
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cancelCalls = append(s.cancelCalls, taskID)
|
||||
if s.cancelErr != nil {
|
||||
return false, s.cancelErr
|
||||
}
|
||||
if s.cancelResult != nil {
|
||||
ok := *s.cancelResult
|
||||
if ok {
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusCanceled
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
status := s.statusByID[taskID]
|
||||
if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
|
||||
return false, nil
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusCanceled
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusSucceeded
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
|
||||
if s.statusByID == nil {
|
||||
s.statusByID = map[int64]string{}
|
||||
}
|
||||
s.statusByID[taskID] = UsageCleanupStatusFailed
|
||||
if s.markFailedErr != nil {
|
||||
return s.markFailedErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
|
||||
if len(s.deleteQueue) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
resp := s.deleteQueue[0]
|
||||
s.deleteQueue = s.deleteQueue[1:]
|
||||
return resp.deleted, resp.err
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
userID := int64(-1)
|
||||
apiKeyID := int64(10)
|
||||
model := " gpt-4 "
|
||||
billingType := int8(-2)
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
APIKeyID: &apiKeyID,
|
||||
Model: &model,
|
||||
BillingType: &billingType,
|
||||
}
|
||||
|
||||
task, err := svc.CreateTask(context.Background(), filters, 9)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, UsageCleanupStatusPending, task.Status)
|
||||
require.Nil(t, task.Filters.UserID)
|
||||
require.NotNil(t, task.Filters.APIKeyID)
|
||||
require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
|
||||
require.NotNil(t, task.Filters.Model)
|
||||
require.Equal(t, "gpt-4", *task.Filters.Model)
|
||||
require.Nil(t, task.Filters.BillingType)
|
||||
require.Equal(t, int64(9), task.CreatedBy)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 0)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(48 * time.Hour)
|
||||
filters := UsageCleanupFilters{StartTime: start, EndTime: end}
|
||||
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
_, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{createErr: errors.New("db down")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
_, err := svc.CreateTask(context.Background(), filters, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "create cleanup task")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
|
||||
start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(2 * time.Hour)
|
||||
repo := &cleanupRepoStub{
|
||||
claimQueue: []*UsageCleanupTask{
|
||||
{ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
|
||||
},
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 2},
|
||||
{deleted: 2},
|
||||
{deleted: 1},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
svc.runOnce()
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.deleteCalls, 3)
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
|
||||
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
|
||||
require.Equal(t, 2, repo.deleteCalls[0].limit)
|
||||
require.Equal(t, start, repo.deleteCalls[0].filters.StartTime)
|
||||
require.Equal(t, end, repo.deleteCalls[0].filters.EndTime)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
svc.runOnce()
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
svc.running = 1
|
||||
svc.runOnce()
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
|
||||
longMsg := strings.Repeat("x", 600)
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: errors.New(longMsg)},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 11,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now(),
|
||||
EndTime: time.Now().Add(24 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markFailed, 1)
|
||||
require.Equal(t, int64(11), repo.markFailed[0].taskID)
|
||||
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskProgressError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 2},
|
||||
{deleted: 0},
|
||||
},
|
||||
updateErr: errors.New("update failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 8,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Len(t, repo.progressCalls, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDeleteCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: context.Canceled},
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 12,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskContextCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 9,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
svc.executeTask(ctx, task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
require.Empty(t, repo.deleteCalls)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{err: errors.New("boom")},
|
||||
},
|
||||
markFailedErr: errors.New("update failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 13,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markFailed, 1)
|
||||
require.Equal(t, int64(13), repo.markFailed[0].taskID)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 14,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
deleteQueue: []cleanupDeleteResponse{
|
||||
{deleted: 0},
|
||||
},
|
||||
}
|
||||
dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
|
||||
DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
|
||||
})
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 15,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Len(t, repo.markSucceeded, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
3: UsageCleanupStatusCanceled,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
task := &UsageCleanupTask{
|
||||
ID: 3,
|
||||
Filters: UsageCleanupFilters{
|
||||
StartTime: time.Now().UTC(),
|
||||
EndTime: time.Now().UTC().Add(time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
svc.executeTask(context.Background(), task)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Empty(t, repo.deleteCalls)
|
||||
require.Empty(t, repo.markSucceeded)
|
||||
require.Empty(t, repo.markFailed)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskSuccess(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
5: UsageCleanupStatusPending,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 5, 9)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Equal(t, UsageCleanupStatusCanceled, repo.statusByID[5])
|
||||
require.Len(t, repo.cancelCalls, 1)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskDisabled(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 1, 2)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 999, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusNotFound, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_TASK_NOT_FOUND", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskStatusError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusErr: errors.New("status broken")}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "status broken")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusSucceeded,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
|
||||
shouldCancel := false
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusPending,
|
||||
},
|
||||
cancelResult: &shouldCancel,
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, http.StatusConflict, infraerrors.Code(err))
|
||||
require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskRepoError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusPending,
|
||||
},
|
||||
cancelErr: errors.New("cancel failed"),
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 1)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "cancel failed")
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceCancelTaskInvalidCanceller(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
statusByID: map[int64]string{
|
||||
7: UsageCleanupStatusRunning,
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, cfg)
|
||||
|
||||
err := svc.CancelTask(context.Background(), 7, 0)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "USAGE_CLEANUP_INVALID_CANCELLER", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceListTasks(t *testing.T) {
|
||||
repo := &cleanupRepoStub{
|
||||
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
|
||||
listResult: &pagination.PaginationResult{
|
||||
Total: 2,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
Pages: 1,
|
||||
},
|
||||
}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tasks, 2)
|
||||
require.Equal(t, int64(2), result.Total)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
|
||||
var nilSvc *UsageCleanupService
|
||||
_, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
|
||||
svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
_, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
|
||||
var nilSvc *UsageCleanupService
|
||||
require.Equal(t, 31, nilSvc.maxRangeDays())
|
||||
require.Equal(t, 5000, nilSvc.batchSize())
|
||||
require.Equal(t, 10*time.Second, nilSvc.workerInterval())
|
||||
require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
|
||||
nilSvc.Start()
|
||||
nilSvc.Stop()
|
||||
|
||||
repo := &cleanupRepoStub{}
|
||||
cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
|
||||
svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
|
||||
svcDisabled.Start()
|
||||
svcDisabled.Stop()
|
||||
|
||||
timingWheel, err := NewTimingWheelService()
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
|
||||
svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
|
||||
require.Equal(t, 5*time.Second, svc.workerInterval())
|
||||
svc.Start()
|
||||
svc.Stop()
|
||||
|
||||
cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
|
||||
svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
|
||||
require.Equal(t, 31, svcFallback.maxRangeDays())
|
||||
require.Equal(t, 5000, svcFallback.batchSize())
|
||||
require.Equal(t, 10*time.Second, svcFallback.workerInterval())
|
||||
|
||||
svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
|
||||
svcMissingDeps.Start()
|
||||
}
|
||||
|
||||
func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
|
||||
model := " "
|
||||
apiKeyID := int64(-5)
|
||||
accountID := int64(-1)
|
||||
groupID := int64(-2)
|
||||
filters := UsageCleanupFilters{
|
||||
UserID: &apiKeyID,
|
||||
APIKeyID: &apiKeyID,
|
||||
AccountID: &accountID,
|
||||
GroupID: &groupID,
|
||||
Model: &model,
|
||||
}
|
||||
|
||||
sanitizeUsageCleanupFilters(&filters)
|
||||
require.Nil(t, filters.UserID)
|
||||
require.Nil(t, filters.APIKeyID)
|
||||
require.Nil(t, filters.AccountID)
|
||||
require.Nil(t, filters.GroupID)
|
||||
require.Nil(t, filters.Model)
|
||||
}
|
||||
|
||||
func TestDescribeUsageCleanupFiltersAllFields(t *testing.T) {
|
||||
start := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC)
|
||||
end := start.Add(2 * time.Hour)
|
||||
userID := int64(1)
|
||||
apiKeyID := int64(2)
|
||||
accountID := int64(3)
|
||||
groupID := int64(4)
|
||||
model := " gpt-4 "
|
||||
stream := true
|
||||
billingType := int8(2)
|
||||
filters := UsageCleanupFilters{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
UserID: &userID,
|
||||
APIKeyID: &apiKeyID,
|
||||
AccountID: &accountID,
|
||||
GroupID: &groupID,
|
||||
Model: &model,
|
||||
Stream: &stream,
|
||||
BillingType: &billingType,
|
||||
}
|
||||
|
||||
desc := describeUsageCleanupFilters(filters)
|
||||
require.Equal(t, "start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2", desc)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceIsTaskCanceledNotFound(t *testing.T) {
|
||||
repo := &cleanupRepoStub{}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
canceled, err := svc.isTaskCanceled(context.Background(), 9)
|
||||
require.NoError(t, err)
|
||||
require.False(t, canceled)
|
||||
}
|
||||
|
||||
func TestUsageCleanupServiceIsTaskCanceledError(t *testing.T) {
|
||||
repo := &cleanupRepoStub{statusErr: errors.New("status err")}
|
||||
svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
|
||||
|
||||
_, err := svc.isTaskCanceled(context.Background(), 9)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "status err")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
@@ -57,6 +58,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideUsageCleanupService 创建并启动使用记录清理任务服务
|
||||
func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
|
||||
svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
|
||||
svc.Start()
|
||||
return svc
|
||||
}
|
||||
|
||||
// ProvideAccountExpiryService creates and starts AccountExpiryService.
|
||||
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
|
||||
svc := NewAccountExpiryService(accountRepo, time.Minute)
|
||||
@@ -189,6 +197,8 @@ func ProvideOpsScheduledReportService(
|
||||
|
||||
// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力
|
||||
func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator {
|
||||
// Start Pub/Sub subscriber for L1 cache invalidation across instances
|
||||
apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background())
|
||||
return apiKeyService
|
||||
}
|
||||
|
||||
@@ -248,6 +258,7 @@ var ProviderSet = wire.NewSet(
|
||||
ProvideAccountExpiryService,
|
||||
ProvideTimingWheelService,
|
||||
ProvideDashboardAggregationService,
|
||||
ProvideUsageCleanupService,
|
||||
ProvideDeferredService,
|
||||
NewAntigravityQuotaFetcher,
|
||||
NewUserAttributeService,
|
||||
|
||||
Reference in New Issue
Block a user