From 46e5ac9672f2d898d44f7d89349a5faf54a300b8 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 18:54:42 +0800
Subject: [PATCH 001/155] fix(gateway): align the Claude OAuth request adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
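
Align outgoing OAuth traffic with the current Claude Code client: bump the
fingerprint headers to claude-cli/2.1.2, map short model names to the dated
IDs the OAuth endpoint expects, and rewrite tool names in both directions.

Illustrative only, not part of the diff -- a minimal sketch of the new model
ID mapping exported from internal/pkg/claude (see constants.go below):

    package main

    import (
        "fmt"

        "github.com/Wei-Shaw/sub2api/internal/pkg/claude"
    )

    func main() {
        // Short names are expanded to dated IDs for the OAuth request...
        fmt.Println(claude.NormalizeModelID("claude-sonnet-4-5")) // claude-sonnet-4-5-20250929
        // ...IDs not in the override map pass through unchanged...
        fmt.Println(claude.NormalizeModelID("claude-3-haiku")) // claude-3-haiku
        // ...and upstream responses are mapped back to the short name.
        fmt.Println(claude.DenormalizeModelID("claude-opus-4-5-20251101")) // claude-opus-4-5
    }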
---
backend/internal/pkg/claude/constants.go | 44 +-
backend/internal/service/gateway_service.go | 454 ++++++++++++++++++-
backend/internal/service/identity_service.go | 8 +-
3 files changed, 481 insertions(+), 25 deletions(-)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index d1a56a84..15144881 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -25,15 +25,15 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders are the default request headers of the Claude Code client.
var DefaultHeaders = map[string]string{
- "User-Agent": "claude-cli/2.0.62 (external, cli)",
+ "User-Agent": "claude-cli/2.1.2 (external, cli)",
"X-Stainless-Lang": "js",
- "X-Stainless-Package-Version": "0.52.0",
+ "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
- "X-Stainless-Runtime-Version": "v22.14.0",
+ "X-Stainless-Runtime-Version": "v24.3.0",
"X-Stainless-Retry-Count": "0",
- "X-Stainless-Timeout": "60",
+ "X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
@@ -79,3 +79,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel is the default model used in tests
const DefaultTestModel = "claude-sonnet-4-5-20250929"
+
+// ModelIDOverrides maps short model names to the model IDs required by Claude OAuth requests
+var ModelIDOverrides = map[string]string{
+ "claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
+ "claude-opus-4-5": "claude-opus-4-5-20251101",
+ "claude-haiku-4-5": "claude-haiku-4-5-20251001",
+}
+
+// ModelIDReverseOverrides maps upstream model IDs back to their short names
+var ModelIDReverseOverrides = map[string]string{
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
+ "claude-opus-4-5-20251101": "claude-opus-4-5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4-5",
+}
+
+// NormalizeModelID maps a model ID according to the Claude OAuth rules
+func NormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
+
+// DenormalizeModelID converts an upstream model ID back to its short name
+func DenormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDReverseOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index d5eb0e52..1d29b3fd 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -17,12 +17,14 @@ import (
"strings"
"sync/atomic"
"time"
+ "unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
+ "github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -44,6 +46,36 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
+ toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
+ toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
+ toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
+
+ claudeToolNameOverrides = map[string]string{
+ "bash": "Bash",
+ "read": "Read",
+ "edit": "Edit",
+ "write": "Write",
+ "task": "Task",
+ "glob": "Glob",
+ "grep": "Grep",
+ "webfetch": "WebFetch",
+ "websearch": "WebSearch",
+ "todowrite": "TodoWrite",
+ "question": "AskUserQuestion",
+ }
+ openCodeToolOverrides = map[string]string{
+ "Bash": "bash",
+ "Read": "read",
+ "Edit": "edit",
+ "Write": "write",
+ "Task": "task",
+ "Glob": "glob",
+ "Grep": "grep",
+ "WebFetch": "webfetch",
+ "WebSearch": "websearch",
+ "TodoWrite": "todowrite",
+ "AskUserQuestion": "question",
+ }
// claudeCodePromptPrefixes is the list of prefixes used to detect the Claude Code system prompt
// Multiple variants are supported: standard, Agent SDK, Explore Agent, Compact, etc.
@@ -346,6 +378,268 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return newBody
}
+type claudeOAuthNormalizeOptions struct {
+ injectMetadata bool
+ metadataUserID string
+ stripSystemCacheControl bool
+}
+
+func stripToolPrefix(value string) string {
+ if value == "" {
+ return value
+ }
+ return toolPrefixRe.ReplaceAllString(value, "")
+}
+
+func toPascalCase(value string) string {
+ if value == "" {
+ return value
+ }
+ normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
+ tokens := make([]string, 0)
+ for _, token := range strings.Fields(normalized) {
+ expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
+ parts := strings.Fields(expanded)
+ if len(parts) > 0 {
+ tokens = append(tokens, parts...)
+ }
+ }
+ if len(tokens) == 0 {
+ return value
+ }
+ var builder strings.Builder
+ for _, token := range tokens {
+ lower := strings.ToLower(token)
+ if lower == "" {
+ continue
+ }
+ runes := []rune(lower)
+ runes[0] = unicode.ToUpper(runes[0])
+ builder.WriteString(string(runes))
+ }
+ return builder.String()
+}
+
+func toSnakeCase(value string) string {
+ if value == "" {
+ return value
+ }
+ output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
+ output = toolNameBoundaryRe.ReplaceAllString(output, "_")
+ output = strings.Trim(output, "_")
+ return strings.ToLower(output)
+}
+
+func normalizeToolNameForClaude(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ stripped := stripToolPrefix(name)
+ mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
+ if !ok {
+ mapped = toPascalCase(stripped)
+ }
+ if mapped != "" && cache != nil && mapped != stripped {
+ cache[mapped] = stripped
+ }
+ if mapped == "" {
+ return stripped
+ }
+ return mapped
+}
+
+func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ if cache != nil {
+ if mapped, ok := cache[name]; ok {
+ return mapped
+ }
+ }
+ if mapped, ok := openCodeToolOverrides[name]; ok {
+ return mapped
+ }
+ return toSnakeCase(name)
+}
+
+func stripCacheControlFromSystemBlocks(system any) bool {
+ blocks, ok := system.([]any)
+ if !ok {
+ return false
+ }
+ changed := false
+ for _, item := range blocks {
+ block, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if _, exists := block["cache_control"]; !exists {
+ continue
+ }
+ if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
+ continue
+ }
+ delete(block, "cache_control")
+ changed = true
+ }
+ return changed
+}
+
+func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
+ if len(body) == 0 {
+ return body, modelID, nil
+ }
+ var req map[string]any
+ if err := json.Unmarshal(body, &req); err != nil {
+ return body, modelID, nil
+ }
+
+ toolNameMap := make(map[string]string)
+
+ if rawModel, ok := req["model"].(string); ok {
+ normalized := claude.NormalizeModelID(rawModel)
+ if normalized != rawModel {
+ req["model"] = normalized
+ modelID = normalized
+ }
+ }
+
+ if rawTools, exists := req["tools"]; exists {
+ switch tools := rawTools.(type) {
+ case []any:
+ for idx, tool := range tools {
+ toolMap, ok := tool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if name, ok := toolMap["name"].(string); ok {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized != "" && normalized != name {
+ toolMap["name"] = normalized
+ }
+ }
+ tools[idx] = toolMap
+ }
+ req["tools"] = tools
+ case map[string]any:
+ normalizedTools := make(map[string]any, len(tools))
+ for name, value := range tools {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized == "" {
+ normalized = name
+ }
+ if toolMap, ok := value.(map[string]any); ok {
+ if toolName, ok := toolMap["name"].(string); ok {
+ mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
+ if mappedName != "" && mappedName != toolName {
+ toolMap["name"] = mappedName
+ }
+ } else if normalized != name {
+ toolMap["name"] = normalized
+ }
+ normalizedTools[normalized] = toolMap
+ continue
+ }
+ normalizedTools[normalized] = value
+ }
+ req["tools"] = normalizedTools
+ }
+ } else {
+ req["tools"] = []any{}
+ }
+
+ if messages, ok := req["messages"].([]any); ok {
+ for _, msg := range messages {
+ msgMap, ok := msg.(map[string]any)
+ if !ok {
+ continue
+ }
+ content, ok := msgMap["content"].([]any)
+ if !ok {
+ continue
+ }
+ for _, block := range content {
+ blockMap, ok := block.(map[string]any)
+ if !ok {
+ continue
+ }
+ if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
+ continue
+ }
+ if name, ok := blockMap["name"].(string); ok {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized != "" && normalized != name {
+ blockMap["name"] = normalized
+ }
+ }
+ }
+ }
+ }
+
+ if opts.stripSystemCacheControl {
+ if system, ok := req["system"]; ok {
+ _ = stripCacheControlFromSystemBlocks(system)
+ }
+ }
+
+ if opts.injectMetadata && opts.metadataUserID != "" {
+ metadata, ok := req["metadata"].(map[string]any)
+ if !ok {
+ metadata = map[string]any{}
+ req["metadata"] = metadata
+ }
+ if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
+ metadata["user_id"] = opts.metadataUserID
+ }
+ }
+
+ if _, ok := req["temperature"]; ok {
+ delete(req, "temperature")
+ }
+ if _, ok := req["tool_choice"]; ok {
+ delete(req, "tool_choice")
+ }
+
+ newBody, err := json.Marshal(req)
+ if err != nil {
+ return body, modelID, toolNameMap
+ }
+ return newBody, modelID, toolNameMap
+}
+
+func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
+ if parsed == nil || fp == nil || fp.ClientID == "" {
+ return ""
+ }
+ if parsed.MetadataUserID != "" {
+ return ""
+ }
+ accountUUID := account.GetExtraString("account_uuid")
+ if accountUUID == "" {
+ return ""
+ }
+ sessionHash := s.GenerateSessionHash(parsed)
+ sessionID := uuid.NewString()
+ if sessionHash != "" {
+ seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
+ sessionID = generateSessionUUID(seed)
+ }
+ return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
+}
+
+func generateSessionUUID(seed string) string {
+ if seed == "" {
+ return uuid.NewString()
+ }
+ hash := sha256.Sum256([]byte(seed))
+ bytes := hash[:16]
+ bytes[6] = (bytes[6] & 0x0f) | 0x40
+ bytes[8] = (bytes[8] & 0x3f) | 0x80
+ return fmt.Sprintf("%x-%x-%x-%x-%x",
+ bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
+}
+
// SelectAccount selects an account (sticky session + priority)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -1423,21 +1717,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
+ originalModel := reqModel
+ var toolNameMap map[string]string
- // Intelligently inject the Claude Code system prompt (only needed for OAuth/SetupToken accounts)
- // Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not already contain the Claude Code prompt
- if account.IsOAuth() &&
- !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
- !strings.Contains(strings.ToLower(reqModel), "haiku") &&
- !systemIncludesClaudeCodePrompt(parsed.System) {
- body = injectClaudeCodePrompt(body, parsed.System)
+ if account.IsOAuth() {
+ // Intelligently inject the Claude Code system prompt (only needed for OAuth/SetupToken accounts)
+ // Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not already contain the Claude Code prompt
+ if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
+ !strings.Contains(strings.ToLower(reqModel), "haiku") &&
+ !systemIncludesClaudeCodePrompt(parsed.System) {
+ body = injectClaudeCodePrompt(body, parsed.System)
+ }
+
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
+ if s.identityService != nil {
+ fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
+ if err == nil && fp != nil {
+ if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
+ normalizeOpts.injectMetadata = true
+ normalizeOpts.metadataUserID = metadataUserID
+ }
+ }
+ }
+
+ body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Enforce the cache_control block count limit (at most 4)
body = enforceCacheControlLimit(body)
// Apply model mapping (apikey-type accounts only)
- originalModel := reqModel
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1465,7 +1774,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// Build the upstream request (rebuilt on every retry because the request body must be re-read)
- upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
+ upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
if err != nil {
return nil, err
}
@@ -1541,7 +1850,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
+ retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -1572,7 +1881,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
- retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
+ retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
@@ -1785,7 +2094,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
- streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
+ streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -1798,7 +2107,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
- usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
+ usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
if err != nil {
return nil, err
}
@@ -1815,7 +2124,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
-func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
+func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -1884,6 +2193,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
+ if tokenType == "oauth" {
+ applyClaudeOAuthHeaderDefaults(req, reqStream)
+ }
// Handle the anthropic-beta header (OAuth accounts need special treatment)
if tokenType == "oauth" {
@@ -1966,6 +2278,26 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader
}
+func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
+ if req == nil {
+ return
+ }
+ if req.Header.Get("accept") == "" {
+ req.Header.Set("accept", "application/json")
+ }
+ for key, value := range claude.DefaultHeaders {
+ if value == "" {
+ continue
+ }
+ if req.Header.Get(key) == "" {
+ req.Header.Set(key, value)
+ }
+ }
+ if isStream && req.Header.Get("x-stainless-helper-method") == "" {
+ req.Header.Set("x-stainless-helper-method", "stream")
+ }
+}
+
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
@@ -2246,7 +2578,7 @@ type streamingResult struct {
clientDisconnect bool // whether the client disconnected during streaming
}
-func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
+func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2339,6 +2671,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
+ rewriteTools := account.IsOAuth()
clientDisconnected := false // client-disconnect flag; after disconnect, keep reading upstream to obtain complete usage
for {
@@ -2380,11 +2713,14 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
- data = sseDataRe.ReplaceAllString(line, "")
// If a model mapping applies, replace the model field in the response
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
+ if rewriteTools {
+ line = s.replaceToolNamesInSSELine(line, toolNameMap)
+ }
+ data = sseDataRe.ReplaceAllString(line, "")
}
// Write to the client (data and non-data lines handled uniformly)
@@ -2467,6 +2803,61 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData)
}
+func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
+ switch v := value.(type) {
+ case map[string]any:
+ changed := false
+ if blockType, _ := v["type"].(string); blockType == "tool_use" {
+ if name, ok := v["name"].(string); ok {
+ mapped := normalizeToolNameForOpenCode(name, toolNameMap)
+ if mapped != name {
+ v["name"] = mapped
+ changed = true
+ }
+ }
+ }
+ for _, item := range v {
+ if rewriteToolNamesInValue(item, toolNameMap) {
+ changed = true
+ }
+ }
+ return changed
+ case []any:
+ changed := false
+ for _, item := range v {
+ if rewriteToolNamesInValue(item, toolNameMap) {
+ changed = true
+ }
+ }
+ return changed
+ default:
+ return false
+ }
+}
+
+func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
+ if !sseDataRe.MatchString(line) {
+ return line
+ }
+ data := sseDataRe.ReplaceAllString(line, "")
+ if data == "" || data == "[DONE]" {
+ return line
+ }
+
+ var event map[string]any
+ if err := json.Unmarshal([]byte(data), &event); err != nil {
+ return line
+ }
+ if !rewriteToolNamesInValue(event, toolNameMap) {
+ return line
+ }
+ newData, err := json.Marshal(event)
+ if err != nil {
+ return line
+ }
+ return "data: " + string(newData)
+}
+
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// Parse message_start to get input tokens (standard Claude API format)
var msgStart struct {
@@ -2508,7 +2899,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
-func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
+func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2529,6 +2920,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
+ if account.IsOAuth() {
+ body = s.replaceToolNamesInResponseBody(body, toolNameMap)
+ }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -2566,6 +2960,24 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody
}
+func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
+ if len(body) == 0 {
+ return body
+ }
+ var resp map[string]any
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return body
+ }
+ if !rewriteToolNamesInValue(resp, toolNameMap) {
+ return body
+ }
+ newBody, err := json.Marshal(resp)
+ if err != nil {
+ return body
+ }
+ return newBody
+}
+
// RecordUsageInput holds the input parameters for recording usage
type RecordUsageInput struct {
Result *ForwardResult
@@ -2729,6 +3141,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
+ if account.IsOAuth() {
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
+ body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
+ }
+
// Antigravity accounts do not support count_tokens forwarding; return an empty result directly
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
@@ -2917,6 +3334,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
+ if tokenType == "oauth" {
+ applyClaudeOAuthHeaderDefaults(req, false)
+ }
// OAuth accounts: handle the anthropic-beta header
if tokenType == "oauth" {
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index 1ffa8057..4ab1ab96 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -24,13 +24,13 @@ var (
// Default fingerprint values (used when the client provides none)
var defaultFingerprint = Fingerprint{
- UserAgent: "claude-cli/2.0.62 (external, cli)",
+ UserAgent: "claude-cli/2.1.2 (external, cli)",
StainlessLang: "js",
- StainlessPackageVersion: "0.52.0",
+ StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",
StainlessArch: "x64",
StainlessRuntime: "node",
- StainlessRuntimeVersion: "v22.14.0",
+ StainlessRuntimeVersion: "v24.3.0",
}
// Fingerprint represents account fingerprint data
@@ -230,7 +230,7 @@ func generateUUIDFromSeed(seed string) string {
}
// parseUserAgentVersion parses the version number from a user-agent string
-// e.g. claude-cli/2.0.62 -> (2, 0, 62)
+// e.g. claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// Match the xxx/x.y.z format
matches := userAgentVersionRegex.FindStringSubmatch(ua)
From c579439c1ea42636ed7e7447e133a98bedfa7091 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 19:17:07 +0800
Subject: [PATCH 002/155] fix(gateway): distinguish Claude Code clients in the OAuth adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
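
Only mimic the Claude Code client for OAuth accounts when the caller is not
itself Claude Code (shouldMimicClaudeCode = account.IsOAuth() && !isClaudeCode),
and fall back to a regex rewrite when an SSE payload is not valid JSON.

Illustrative only -- a sketch of the text-level fallback using the same
modelFieldRe pattern this diff adds; the input line is made up, and the
replacement is hardcoded here where the diff calls claude.DenormalizeModelID:

    package main

    import (
        "fmt"
        "regexp"
        "strings"
    )

    var modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)

    func main() {
        data := `{"type":"message_start","model":"claude-opus-4-5-20251101"}`
        out := modelFieldRe.ReplaceAllStringFunc(data, func(m string) string {
            sub := modelFieldRe.FindStringSubmatch(m)
            if len(sub) < 2 {
                return m
            }
            // Restore the short model name the client originally sent.
            return strings.Replace(m, sub[1], "claude-opus-4-5", 1)
        })
        fmt.Println(out) // {"type":"message_start","model":"claude-opus-4-5"}
    }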
---
backend/internal/handler/gateway_handler.go | 3 +
backend/internal/pkg/claude/constants.go | 4 +
backend/internal/service/gateway_service.go | 110 +++++++++++++++-----
3 files changed, 90 insertions(+), 27 deletions(-)
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index b60618a8..91d590bf 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -707,6 +707,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
+ // Check whether this is a Claude Code client and store the result in the context
+ SetClaudeCodeClientContext(c, body)
+
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index 15144881..f60412c2 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -9,11 +9,15 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
+ BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader is the default anthropic-beta header of the Claude Code client
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
+// CountTokensBetaHeader is the anthropic-beta header used for count_tokens requests
+const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
+
// HaikuBetaHeader is the anthropic-beta header used by Haiku models (the claude-code beta is not needed)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 1d29b3fd..904b5acd 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -49,6 +49,8 @@ var (
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
+ toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
+ modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -1458,6 +1460,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent)
}
+func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
+ if IsClaudeCodeClient(ctx) {
+ return true
+ }
+ if parsed == nil || c == nil {
+ return false
+ }
+ return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
+}
+
// systemIncludesClaudeCodePrompt checks whether system already contains the Claude Code prompt
// Prefix matching supports multiple variants (standard, Agent SDK, etc.)
func systemIncludesClaudeCodePrompt(system any) bool {
@@ -1720,11 +1732,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
originalModel := reqModel
var toolNameMap map[string]string
- if account.IsOAuth() {
+ isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
+ shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
+
+ if shouldMimicClaudeCode {
// Intelligently inject the Claude Code system prompt (only needed for OAuth/SetupToken accounts)
// Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not already contain the Claude Code prompt
- if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
- !strings.Contains(strings.ToLower(reqModel), "haiku") &&
+ if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
}
@@ -1774,7 +1788,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// Build the upstream request (rebuilt on every retry because the request body must be re-read)
- upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
+ upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
@@ -1850,7 +1864,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
+ retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -1881,7 +1895,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
- retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
+ retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
@@ -2094,7 +2108,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
- streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
+ streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -2107,7 +2121,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
- usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
+ usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
@@ -2124,7 +2138,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
-func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
+func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -2140,7 +2154,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth accounts: apply the unified fingerprint
var fingerprint *Fingerprint
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
// 1. Get or create the fingerprint (contains a randomly generated ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
@@ -2193,12 +2207,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// Handle the anthropic-beta header (OAuth accounts need special treatment)
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API key: fill in on demand only when the request explicitly uses beta features and the client did not supply one (disabled by default)
@@ -2578,7 +2592,7 @@ type streamingResult struct {
clientDisconnect bool // whether the client disconnected during streaming
}
-func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
+func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2671,7 +2685,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
- rewriteTools := account.IsOAuth()
+ rewriteTools := mimicClaudeCode
clientDisconnected := false // client-disconnect flag; after disconnect, keep reading upstream to obtain complete usage
for {
@@ -2835,6 +2849,37 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
}
}
+func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
+ if text == "" {
+ return text
+ }
+ output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
+ submatches := toolNameFieldRe.FindStringSubmatch(match)
+ if len(submatches) < 2 {
+ return match
+ }
+ name := submatches[1]
+ mapped := normalizeToolNameForOpenCode(name, toolNameMap)
+ if mapped == name {
+ return match
+ }
+ return strings.Replace(match, name, mapped, 1)
+ })
+ output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
+ submatches := modelFieldRe.FindStringSubmatch(match)
+ if len(submatches) < 2 {
+ return match
+ }
+ model := submatches[1]
+ mapped := claude.DenormalizeModelID(model)
+ if mapped == model {
+ return match
+ }
+ return strings.Replace(match, model, mapped, 1)
+ })
+ return output
+}
+
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
if !sseDataRe.MatchString(line) {
return line
@@ -2846,7 +2891,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
- return line
+ replaced := replaceToolNamesInText(data, toolNameMap)
+ if replaced == data {
+ return line
+ }
+ return "data: " + replaced
}
if !rewriteToolNamesInValue(event, toolNameMap) {
return line
@@ -2899,7 +2948,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
-func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
+func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2920,7 +2969,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- if account.IsOAuth() {
+ if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
@@ -2966,7 +3015,11 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
- return body
+ replaced := replaceToolNamesInText(string(body), toolNameMap)
+ if replaced == string(body) {
+ return body
+ }
+ return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
@@ -3141,7 +3194,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
- if account.IsOAuth() {
+ isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
+ shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
+
+ if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
@@ -3172,7 +3228,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// Build the upstream request
- upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
+ upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
@@ -3205,7 +3261,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
+ retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -3270,7 +3326,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest builds the upstream count_tokens request
-func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
+func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
@@ -3285,7 +3341,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth accounts: apply the unified fingerprint and rewrite the userID
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
@@ -3320,7 +3376,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth accounts: apply the fingerprint to the request headers
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
@@ -3334,13 +3390,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth accounts: handle the anthropic-beta header
- if tokenType == "oauth" {
- req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
+ if tokenType == "oauth" && mimicClaudeCode {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API key: on-demand beta injection, kept in sync with messages (disabled by default)
if requestNeedsBetaFeatures(body) {
From 98b65e67f21189f441f92dec88ed40b3ba7e8561 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 21:42:13 +0800
Subject: [PATCH 003/155] fix(gateway): avoid injecting invalid SSE on client
cancel
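
When the client cancels mid-stream, reads from the upstream body surface
context.Canceled; injecting a synthetic "event: error" at that point produces
SSE that OpenAI-protocol clients cannot parse, so return the collected usage
instead.

Illustrative only -- why errors.Is is the right check here (it also matches
wrapped context errors, as network reads typically return them):

    package main

    import (
        "context"
        "errors"
        "fmt"
    )

    func main() {
        ctx, cancel := context.WithCancel(context.Background())
        cancel()
        // Reads from a request-scoped body usually wrap the context error.
        err := fmt.Errorf("read tcp 127.0.0.1:0: %w", ctx.Err())
        fmt.Println(errors.Is(err, context.Canceled)) // true
    }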
---
.../service/openai_gateway_service.go | 6 +++
.../service/openai_gateway_service_test.go | 37 +++++++++++++++++++
2 files changed, 43 insertions(+)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 04a90fdd..d49be282 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -1064,6 +1064,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
if ev.err != nil {
+ // When the client disconnects or cancels the request, upstream reads usually return context canceled.
+ // SSE events on /v1/responses must conform to the OpenAI protocol; do not inject a custom error event here, so downstream SDK parsing does not break.
+ if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
+ log.Printf("Context canceled during streaming, returning collected usage")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 42b88b7d..ead6e143 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -33,6 +33,11 @@ type stubConcurrencyCache struct {
ConcurrencyCache
}
+type cancelReadCloser struct{}
+
+func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
+func (c cancelReadCloser) Close() error { return nil }
+
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
@@ -174,6 +179,38 @@ func TestOpenAIStreamingTimeout(t *testing.T) {
}
}
+func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx)
+
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: cancelReadCloser{},
+ Header: http.Header{},
+ }
+
+ _, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
+ if err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "stream_read_error") {
+ t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
+ }
+}
+
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
From c11f14f3a030c30846183704ccd6193785899bd4 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 21:51:14 +0800
Subject: [PATCH 004/155] fix(gateway): drain upstream after client disconnect
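
A client disconnect used to abort the stream immediately, losing the final
usage frame that billing depends on. Mark the client as disconnected on the
first failed write, stop writing and flushing, but keep reading upstream
until the usage arrives or the stream ends.

Illustrative only -- the write-side pattern this patch converges on, excerpted
in spirit from the diff below (w, line, and flusher come from the surrounding
handler):

    if !clientDisconnected {
        if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
            clientDisconnected = true // stop writing, keep draining for usage
        } else {
            flusher.Flush()
        }
    }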
---
.../service/openai_gateway_service.go | 43 ++++++++++----
.../service/openai_gateway_service_test.go | 59 +++++++++++++++++++
2 files changed, 91 insertions(+), 11 deletions(-)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index d49be282..fb811e9e 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -1046,8 +1046,9 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
// Send the error event only once to avoid protocol corruption from repeated writes (best-effort client notification on write failure)
errorEventSent := false
+ clientDisconnected := false // after client disconnect, keep draining upstream to collect usage
sendErrorEvent := func(reason string) {
- if errorEventSent {
+ if errorEventSent || clientDisconnected {
return
}
errorEventSent = true
@@ -1070,6 +1071,11 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
log.Printf("Context canceled during streaming, returning collected usage")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
}
+ // Once the client has disconnected, an upstream error only affects UX, not billing; return the usage collected so far
+ if clientDisconnected {
+ log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err)
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
if errors.Is(ev.err, bufio.ErrTooLong) {
log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err)
sendErrorEvent("response_too_large")
@@ -1091,12 +1097,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
- // Forward line
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- sendErrorEvent("write_failed")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ // Write to the client (keep draining upstream after the client disconnects)
+ if !clientDisconnected {
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ clientDisconnected = true
+ log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+ } else {
+ flusher.Flush()
+ }
}
- flusher.Flush()
// Record first token time
if firstTokenMs == nil && data != "" && data != "[DONE]" {
@@ -1106,11 +1115,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
s.parseSSEUsage(data, usage)
} else {
// Forward non-data lines as-is
- if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
- sendErrorEvent("write_failed")
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ if !clientDisconnected {
+ if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
+ clientDisconnected = true
+ log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+ } else {
+ flusher.Flush()
+ }
}
- flusher.Flush()
}
case <-intervalCh:
@@ -1118,6 +1130,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
if time.Since(lastRead) < streamInterval {
continue
}
+ if clientDisconnected {
+ log.Printf("Upstream timeout after client disconnect, returning collected usage")
+ return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil
+ }
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// Handle the stream timeout; the account may be marked temporarily unschedulable or errored
if s.rateLimitService != nil {
@@ -1127,11 +1143,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
case <-keepaliveCh:
+ if clientDisconnected {
+ continue
+ }
if time.Since(lastDataAt) < keepaliveInterval {
continue
}
if _, err := fmt.Fprint(w, ":\n\n"); err != nil {
- return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err
+ clientDisconnected = true
+ log.Printf("Client disconnected during streaming, continuing to drain upstream for billing")
+ continue
}
flusher.Flush()
}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index ead6e143..3ec37544 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -38,6 +38,20 @@ type cancelReadCloser struct{}
func (c cancelReadCloser) Read(p []byte) (int, error) { return 0, context.Canceled }
func (c cancelReadCloser) Close() error { return nil }
+type failingGinWriter struct {
+ gin.ResponseWriter
+ failAfter int
+ writes int
+}
+
+func (w *failingGinWriter) Write(p []byte) (int, error) {
+ if w.writes >= w.failAfter {
+ return 0, errors.New("write failed")
+ }
+ w.writes++
+ return w.ResponseWriter.Write(p)
+}
+
func (c stubConcurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
}
@@ -211,6 +225,51 @@ func TestOpenAIStreamingContextCanceledDoesNotInjectErrorEvent(t *testing.T) {
}
}
+func TestOpenAIStreamingClientDisconnectDrainsUpstreamUsage(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Gateway: config.GatewayConfig{
+ StreamDataIntervalTimeout: 0,
+ StreamKeepaliveInterval: 0,
+ MaxLineSize: defaultMaxLineSize,
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
+ c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 0}
+
+ pr, pw := io.Pipe()
+ resp := &http.Response{
+ StatusCode: http.StatusOK,
+ Body: pr,
+ Header: http.Header{},
+ }
+
+ go func() {
+ defer func() { _ = pw.Close() }()
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.in_progress\",\"response\":{}}\n\n"))
+ _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":3,\"output_tokens\":5,\"input_tokens_details\":{\"cached_tokens\":1}}}}\n\n"))
+ }()
+
+ result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model")
+ _ = pr.Close()
+ if err != nil {
+ t.Fatalf("expected nil error, got %v", err)
+ }
+ if result == nil || result.usage == nil {
+ t.Fatalf("expected usage result")
+ }
+ if result.usage.InputTokens != 3 || result.usage.OutputTokens != 5 || result.usage.CacheReadInputTokens != 1 {
+ t.Fatalf("unexpected usage: %+v", *result.usage)
+ }
+ if strings.Contains(rec.Body.String(), "event: error") || strings.Contains(rec.Body.String(), "write_failed") {
+ t.Fatalf("expected no injected SSE error event, got %q", rec.Body.String())
+ }
+}
+
func TestOpenAIStreamingTooLong(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{
From 65fd0d15ae0f5b1b454d27a02e7df3e8b5670b2d Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Fri, 16 Jan 2026 00:41:29 +0800
Subject: [PATCH 005/155] fix(gateway): complete OAuth compatibility for non-Claude Code clients
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
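
Non-Claude Code clients (e.g. OpenCode) send camelCase tool parameter keys,
self-referential "opencode" strings, and absolute workstation paths in tool
descriptions; normalize all of these before the request reaches the OAuth
endpoint, and remember the original spellings so responses can be mapped back.

Illustrative only -- a self-contained sketch of the snake_case conversion the
schema normalization relies on, using the same patterns as
toolNameCamelRe/toolNameBoundaryRe in the diff:

    package main

    import (
        "fmt"
        "regexp"
        "strings"
    )

    var (
        boundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
        camelRe    = regexp.MustCompile(`([a-z0-9])([A-Z])`)
    )

    func toSnakeCase(v string) string {
        // Split camelCase boundaries, collapse non-alphanumerics, lowercase.
        v = camelRe.ReplaceAllString(v, "$1_$2")
        v = boundaryRe.ReplaceAllString(v, "_")
        return strings.ToLower(strings.Trim(v, "_"))
    }

    func main() {
        fmt.Println(toSnakeCase("filePath"))  // file_path
        fmt.Println(toSnakeCase("TodoWrite")) // todo_write
    }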
---
backend/internal/pkg/claude/constants.go | 6 +
backend/internal/service/account.go | 16 ++
backend/internal/service/gateway_service.go | 239 +++++++++++++++++---
3 files changed, 232 insertions(+), 29 deletions(-)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index f60412c2..0c6e9b4c 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -15,6 +15,12 @@ const (
// DefaultBetaHeader is the default anthropic-beta header of the Claude Code client
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
+// MessageBetaHeaderNoTools is the beta header for /v1/messages when no tools are present
+const MessageBetaHeaderNoTools = BetaOAuth + "," + BetaInterleavedThinking
+
+// MessageBetaHeaderWithTools is the beta header for /v1/messages when tools are present
+const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
+
+// CountTokensBetaHeader is the anthropic-beta header used for count_tokens requests
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index cfce9bfa..435eecd9 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -364,6 +364,22 @@ func (a *Account) GetExtraString(key string) string {
return ""
}
+func (a *Account) GetClaudeUserID() string {
+ if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
+ return v
+ }
+ return ""
+}
+
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 904b5acd..790d9fa2 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -51,6 +51,9 @@ var (
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
+ toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
+ toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
+ opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -451,6 +454,22 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ stripped := stripToolPrefix(name)
+ if cache != nil {
+ if mapped, ok := cache[stripped]; ok {
+ return mapped
+ }
+ }
+ if mapped, ok := openCodeToolOverrides[stripped]; ok {
+ return mapped
+ }
+ return toSnakeCase(stripped)
+}
+
+func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
@@ -459,10 +478,63 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return mapped
}
}
- if mapped, ok := openCodeToolOverrides[name]; ok {
- return mapped
+ return name
+}
+
+func sanitizeOpenCodeText(text string) string {
+ if text == "" {
+ return text
+ }
+ text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
+ text = opencodeTextRe.ReplaceAllString(text, "Claude")
+ return text
+}
+
+func sanitizeToolDescription(description string) string {
+ if description == "" {
+ return description
+ }
+ description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
+ description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
+ return sanitizeOpenCodeText(description)
+}
+
+func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
+ schema, ok := inputSchema.(map[string]any)
+ if !ok {
+ return
+ }
+ properties, ok := schema["properties"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ newProperties := make(map[string]any, len(properties))
+ for key, value := range properties {
+ snakeKey := toSnakeCase(key)
+ newProperties[snakeKey] = value
+ if snakeKey != key && cache != nil {
+ cache[snakeKey] = key
+ }
+ }
+ schema["properties"] = newProperties
+
+ if required, ok := schema["required"].([]any); ok {
+ newRequired := make([]any, 0, len(required))
+ for _, item := range required {
+ name, ok := item.(string)
+ if !ok {
+ newRequired = append(newRequired, item)
+ continue
+ }
+ snakeName := toSnakeCase(name)
+ newRequired = append(newRequired, snakeName)
+ if snakeName != name && cache != nil {
+ cache[snakeName] = name
+ }
+ }
+ schema["required"] = newRequired
}
- return toSnakeCase(name)
}
func stripCacheControlFromSystemBlocks(system any) bool {
@@ -479,9 +551,6 @@ func stripCacheControlFromSystemBlocks(system any) bool {
if _, exists := block["cache_control"]; !exists {
continue
}
- if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
- continue
- }
delete(block, "cache_control")
changed = true
}
@@ -499,6 +568,34 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolNameMap := make(map[string]string)
+ if system, ok := req["system"]; ok {
+ switch v := system.(type) {
+ case string:
+ sanitized := sanitizeOpenCodeText(v)
+ if sanitized != v {
+ req["system"] = sanitized
+ }
+ case []any:
+ for _, item := range v {
+ block, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if blockType, _ := block["type"].(string); blockType != "text" {
+ continue
+ }
+ text, ok := block["text"].(string)
+ if !ok || text == "" {
+ continue
+ }
+ sanitized := sanitizeOpenCodeText(text)
+ if sanitized != text {
+ block["text"] = sanitized
+ }
+ }
+ }
+ }
+
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
@@ -521,6 +618,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolMap["name"] = normalized
}
}
+ if desc, ok := toolMap["description"].(string); ok {
+ sanitized := sanitizeToolDescription(desc)
+ if sanitized != desc {
+ toolMap["description"] = sanitized
+ }
+ }
+ if schema, ok := toolMap["input_schema"]; ok {
+ normalizeToolInputSchema(schema, toolNameMap)
+ }
tools[idx] = toolMap
}
req["tools"] = tools
@@ -532,13 +638,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
- if toolName, ok := toolMap["name"].(string); ok {
- mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
- if mappedName != "" && mappedName != toolName {
- toolMap["name"] = mappedName
+ toolMap["name"] = normalized
+ if desc, ok := toolMap["description"].(string); ok {
+ sanitized := sanitizeToolDescription(desc)
+ if sanitized != desc {
+ toolMap["description"] = sanitized
}
- } else if normalized != name {
- toolMap["name"] = normalized
+ }
+ if schema, ok := toolMap["input_schema"]; ok {
+ normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
@@ -611,7 +719,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
- if parsed == nil || fp == nil || fp.ClientID == "" {
+ if parsed == nil || account == nil {
return ""
}
if parsed.MetadataUserID != "" {
@@ -621,13 +729,22 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
if accountUUID == "" {
return ""
}
+
+ userID := strings.TrimSpace(account.GetClaudeUserID())
+ if userID == "" && fp != nil {
+ userID = fp.ClientID
+ }
+ if userID == "" {
+ return ""
+ }
+
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
- return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
+ return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
func generateSessionUUID(seed string) string {
@@ -2213,7 +2330,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// Handle the anthropic-beta header (OAuth accounts need special handling)
if tokenType == "oauth" && mimicClaudeCode {
- req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
+ if requestHasTools(body) {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ } else {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API key: fill in on demand only when the request explicitly uses beta features and the client did not provide a header (off by default)
if requestNeedsBetaFeatures(body) {
@@ -2284,6 +2405,20 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
+func requestHasTools(body []byte) bool {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.Exists() {
+ return false
+ }
+ if tools.IsArray() {
+ return len(tools.Array()) > 0
+ }
+ if tools.IsObject() {
+ return len(tools.Map()) > 0
+ }
+ return false
+}
+
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
@@ -2817,6 +2952,45 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData)
}
+func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
+ switch v := value.(type) {
+ case map[string]any:
+ changed := false
+ rewritten := make(map[string]any, len(v))
+ for key, item := range v {
+ newKey := normalizeParamNameForOpenCode(key, cache)
+ newItem, childChanged := rewriteParamKeysInValue(item, cache)
+ if childChanged {
+ changed = true
+ }
+ if newKey != key {
+ changed = true
+ }
+ rewritten[newKey] = newItem
+ }
+ if !changed {
+ return value, false
+ }
+ return rewritten, true
+ case []any:
+ changed := false
+ rewritten := make([]any, len(v))
+ for idx, item := range v {
+ newItem, childChanged := rewriteParamKeysInValue(item, cache)
+ if childChanged {
+ changed = true
+ }
+ rewritten[idx] = newItem
+ }
+ if !changed {
+ return value, false
+ }
+ return rewritten, true
+ default:
+ return value, false
+ }
+}
+
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
@@ -2829,6 +3003,15 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
changed = true
}
}
+ if input, ok := v["input"].(map[string]any); ok {
+ rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
+ if inputChanged {
+ if m, ok := rewrittenInput.(map[string]any); ok {
+ v["input"] = m
+ changed = true
+ }
+ }
+ }
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
@@ -2877,6 +3060,15 @@ func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
}
return strings.Replace(match, model, mapped, 1)
})
+
+ for mapped, original := range toolNameMap {
+ if mapped == "" || original == "" || mapped == original {
+ continue
+ }
+ output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
+ output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
+ }
+
return output
}
@@ -2889,22 +3081,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
return line
}
- var event map[string]any
- if err := json.Unmarshal([]byte(data), &event); err != nil {
- replaced := replaceToolNamesInText(data, toolNameMap)
- if replaced == data {
- return line
- }
- return "data: " + replaced
- }
- if !rewriteToolNamesInValue(event, toolNameMap) {
+ replaced := replaceToolNamesInText(data, toolNameMap)
+ if replaced == data {
return line
}
- newData, err := json.Marshal(event)
- if err != nil {
- return line
- }
- return "data: " + string(newData)
+ return "data: " + replaced
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
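The simplification above makes replaceToolNamesInSSELine rely entirely on the text-level pass. A plausible reason: tool parameter keys also appear inside the escaped partial_json string of input_json_delta events, where a parse-and-rewrite pass cannot reach them, which is what the escaped \"key\": replacements added in the preceding hunk target. A self-contained sketch of that escaped-key case; the file_path/filePath mapping is an invented example of a cache entry:

package main

import (
	"fmt"
	"strings"
)

// restoreParamKeys maps snake_case keys sent upstream back to the client's
// original casing. It rewrites plain JSON keys and the escaped \"key\": form
// found inside the partial_json string of an input_json_delta event.
func restoreParamKeys(data string, keyMap map[string]string) string {
	for mapped, original := range keyMap {
		if mapped == "" || original == "" || mapped == original {
			continue
		}
		data = strings.ReplaceAll(data, `"`+mapped+`":`, `"`+original+`":`)
		data = strings.ReplaceAll(data, `\"`+mapped+`\":`, `\"`+original+`\":`)
	}
	return data
}

func main() {
	keyMap := map[string]string{"file_path": "filePath"} // hypothetical cache entry
	line := `{"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"file_path\":\"ma"}}`
	fmt.Println(restoreParamKeys(line, keyMap))
}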
From bd854e1750e568c4a02b3a276e68bcd6336f5368 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Fri, 16 Jan 2026 23:15:52 +0800
Subject: [PATCH 006/155] fix(gateway): complete the oauth beta for Claude Code OAuth
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/service/gateway_service.go | 34 ++++++++++++++++-----
1 file changed, 27 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 790d9fa2..aa811bf5 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -2328,12 +2328,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
- // Handle the anthropic-beta header (OAuth accounts need special handling)
- if tokenType == "oauth" && mimicClaudeCode {
- if requestHasTools(body) {
- req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ // Handle the anthropic-beta header (OAuth accounts must include the oauth beta)
+ if tokenType == "oauth" {
+ if mimicClaudeCode {
+ // Non-Claude Code client: generate the beta header per Claude Code rules
+ if requestHasTools(body) {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ } else {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ }
} else {
- req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ // Claude Code client: pass the original header through as far as possible, only topping up the oauth beta
+ clientBetaHeader := req.Header.Get("anthropic-beta")
+ req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -3576,8 +3583,21 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth accounts: handle the anthropic-beta header
- if tokenType == "oauth" && mimicClaudeCode {
- req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ if tokenType == "oauth" {
+ if mimicClaudeCode {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ } else {
+ clientBetaHeader := req.Header.Get("anthropic-beta")
+ if clientBetaHeader == "" {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ } else {
+ beta := s.getBetaHeader(modelID, clientBetaHeader)
+ if !strings.Contains(beta, claude.BetaTokenCounting) {
+ beta = beta + "," + claude.BetaTokenCounting
+ }
+ req.Header.Set("anthropic-beta", beta)
+ }
+ }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API key: on-demand beta injection, kept in sync with messages (off by default)
if requestNeedsBetaFeatures(body) {
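For count_tokens, non-Claude Code OAuth clients keep their own beta list: the hunk falls back to claude.CountTokensBetaHeader only when the client sent nothing, and otherwise merges the token-counting beta into whatever s.getBetaHeader produced. A reduced sketch of that merge step (the constant value is copied from this series; the getBetaHeader stage is omitted):

package main

import (
	"fmt"
	"strings"
)

// betaTokenCounting mirrors claude.BetaTokenCounting elsewhere in this series.
const betaTokenCounting = "token-counting-2024-11-01"

// ensureTokenCountingBeta keeps the client's beta list and appends the
// token-counting beta only when it is missing, as in the hunk above.
func ensureTokenCountingBeta(beta string) string {
	if strings.Contains(beta, betaTokenCounting) {
		return beta
	}
	return beta + "," + betaTokenCounting
}

func main() {
	fmt.Println(ensureTokenCountingBeta("interleaved-thinking-2025-05-14"))
	// interleaved-thinking-2025-05-14,token-counting-2024-11-01
}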
From 2a7d04fec4f452bc20b73ab0fa04da9ef6fd7870 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 18:54:42 +0800
Subject: [PATCH 007/155] fix(gateway): align Claude OAuth request adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/pkg/claude/constants.go | 44 +-
backend/internal/service/gateway_service.go | 455 ++++++++++++++++++-
backend/internal/service/identity_service.go | 8 +-
3 files changed, 481 insertions(+), 26 deletions(-)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index d1a56a84..15144881 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -25,15 +25,15 @@ const APIKeyHaikuBetaHeader = BetaInterleavedThinking
// DefaultHeaders are the default request headers of the Claude Code client.
var DefaultHeaders = map[string]string{
- "User-Agent": "claude-cli/2.0.62 (external, cli)",
+ "User-Agent": "claude-cli/2.1.2 (external, cli)",
"X-Stainless-Lang": "js",
- "X-Stainless-Package-Version": "0.52.0",
+ "X-Stainless-Package-Version": "0.70.0",
"X-Stainless-OS": "Linux",
"X-Stainless-Arch": "x64",
"X-Stainless-Runtime": "node",
- "X-Stainless-Runtime-Version": "v22.14.0",
+ "X-Stainless-Runtime-Version": "v24.3.0",
"X-Stainless-Retry-Count": "0",
- "X-Stainless-Timeout": "60",
+ "X-Stainless-Timeout": "600",
"X-App": "cli",
"Anthropic-Dangerous-Direct-Browser-Access": "true",
}
@@ -79,3 +79,39 @@ func DefaultModelIDs() []string {
// DefaultTestModel is the default model used in tests
const DefaultTestModel = "claude-sonnet-4-5-20250929"
+
+// ModelIDOverrides maps model IDs as required by Claude OAuth requests
+var ModelIDOverrides = map[string]string{
+ "claude-sonnet-4-5": "claude-sonnet-4-5-20250929",
+ "claude-opus-4-5": "claude-opus-4-5-20251101",
+ "claude-haiku-4-5": "claude-haiku-4-5-20251001",
+}
+
+// ModelIDReverseOverrides restores upstream model IDs to their short names
+var ModelIDReverseOverrides = map[string]string{
+ "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
+ "claude-opus-4-5-20251101": "claude-opus-4-5",
+ "claude-haiku-4-5-20251001": "claude-haiku-4-5",
+}
+
+// NormalizeModelID maps a model ID according to Claude OAuth rules
+func NormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
+
+// DenormalizeModelID converts an upstream model ID to its short name
+func DenormalizeModelID(id string) string {
+ if id == "" {
+ return id
+ }
+ if mapped, ok := ModelIDReverseOverrides[id]; ok {
+ return mapped
+ }
+ return id
+}
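ModelIDOverrides and ModelIDReverseOverrides are meant to be exact inverses, so an ID normalized on the way upstream is restored for the client on the way back. A reduced round-trip sketch with a single entry and standalone names:

package main

import "fmt"

var overrides = map[string]string{"claude-sonnet-4-5": "claude-sonnet-4-5-20250929"}
var reverse = map[string]string{"claude-sonnet-4-5-20250929": "claude-sonnet-4-5"}

// normalize maps a short client name to the dated upstream ID; unknown IDs
// (and already-dated IDs) pass through unchanged, like NormalizeModelID above.
func normalize(id string) string {
	if mapped, ok := overrides[id]; ok {
		return mapped
	}
	return id
}

// denormalize restores the short name, like DenormalizeModelID above.
func denormalize(id string) string {
	if mapped, ok := reverse[id]; ok {
		return mapped
	}
	return id
}

func main() {
	up := normalize("claude-sonnet-4-5")
	fmt.Println(up)              // claude-sonnet-4-5-20250929
	fmt.Println(denormalize(up)) // claude-sonnet-4-5: the round trip is lossless
}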
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 1e3221d3..899a0fc5 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -18,12 +18,14 @@ import (
"strings"
"sync/atomic"
"time"
+ "unicode"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
+ "github.com/google/uuid"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -60,6 +62,36 @@ var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
claudeCliUserAgentRe = regexp.MustCompile(`^claude-cli/\d+\.\d+\.\d+`)
+ toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
+ toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
+ toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
+
+ claudeToolNameOverrides = map[string]string{
+ "bash": "Bash",
+ "read": "Read",
+ "edit": "Edit",
+ "write": "Write",
+ "task": "Task",
+ "glob": "Glob",
+ "grep": "Grep",
+ "webfetch": "WebFetch",
+ "websearch": "WebSearch",
+ "todowrite": "TodoWrite",
+ "question": "AskUserQuestion",
+ }
+ openCodeToolOverrides = map[string]string{
+ "Bash": "bash",
+ "Read": "read",
+ "Edit": "edit",
+ "Write": "write",
+ "Task": "task",
+ "Glob": "glob",
+ "Grep": "grep",
+ "WebFetch": "webfetch",
+ "WebSearch": "websearch",
+ "TodoWrite": "todowrite",
+ "AskUserQuestion": "question",
+ }
// claudeCodePromptPrefixes is the list of prefixes used to detect the Claude Code system prompt
// Several variants are supported: standard, Agent SDK, Explore Agent, Compact, etc.
@@ -365,6 +397,268 @@ func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte
return newBody
}
+type claudeOAuthNormalizeOptions struct {
+ injectMetadata bool
+ metadataUserID string
+ stripSystemCacheControl bool
+}
+
+func stripToolPrefix(value string) string {
+ if value == "" {
+ return value
+ }
+ return toolPrefixRe.ReplaceAllString(value, "")
+}
+
+func toPascalCase(value string) string {
+ if value == "" {
+ return value
+ }
+ normalized := toolNameBoundaryRe.ReplaceAllString(value, " ")
+ tokens := make([]string, 0)
+ for _, token := range strings.Fields(normalized) {
+ expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2")
+ parts := strings.Fields(expanded)
+ if len(parts) > 0 {
+ tokens = append(tokens, parts...)
+ }
+ }
+ if len(tokens) == 0 {
+ return value
+ }
+ var builder strings.Builder
+ for _, token := range tokens {
+ lower := strings.ToLower(token)
+ if lower == "" {
+ continue
+ }
+ runes := []rune(lower)
+ runes[0] = unicode.ToUpper(runes[0])
+ builder.WriteString(string(runes))
+ }
+ return builder.String()
+}
+
+func toSnakeCase(value string) string {
+ if value == "" {
+ return value
+ }
+ output := toolNameCamelRe.ReplaceAllString(value, "$1_$2")
+ output = toolNameBoundaryRe.ReplaceAllString(output, "_")
+ output = strings.Trim(output, "_")
+ return strings.ToLower(output)
+}
+
+func normalizeToolNameForClaude(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ stripped := stripToolPrefix(name)
+ mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)]
+ if !ok {
+ mapped = toPascalCase(stripped)
+ }
+ if mapped != "" && cache != nil && mapped != stripped {
+ cache[mapped] = stripped
+ }
+ if mapped == "" {
+ return stripped
+ }
+ return mapped
+}
+
+func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ if cache != nil {
+ if mapped, ok := cache[name]; ok {
+ return mapped
+ }
+ }
+ if mapped, ok := openCodeToolOverrides[name]; ok {
+ return mapped
+ }
+ return toSnakeCase(name)
+}
+
+func stripCacheControlFromSystemBlocks(system any) bool {
+ blocks, ok := system.([]any)
+ if !ok {
+ return false
+ }
+ changed := false
+ for _, item := range blocks {
+ block, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if _, exists := block["cache_control"]; !exists {
+ continue
+ }
+ if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
+ continue
+ }
+ delete(block, "cache_control")
+ changed = true
+ }
+ return changed
+}
+
+func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAuthNormalizeOptions) ([]byte, string, map[string]string) {
+ if len(body) == 0 {
+ return body, modelID, nil
+ }
+ var req map[string]any
+ if err := json.Unmarshal(body, &req); err != nil {
+ return body, modelID, nil
+ }
+
+ toolNameMap := make(map[string]string)
+
+ if rawModel, ok := req["model"].(string); ok {
+ normalized := claude.NormalizeModelID(rawModel)
+ if normalized != rawModel {
+ req["model"] = normalized
+ modelID = normalized
+ }
+ }
+
+ if rawTools, exists := req["tools"]; exists {
+ switch tools := rawTools.(type) {
+ case []any:
+ for idx, tool := range tools {
+ toolMap, ok := tool.(map[string]any)
+ if !ok {
+ continue
+ }
+ if name, ok := toolMap["name"].(string); ok {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized != "" && normalized != name {
+ toolMap["name"] = normalized
+ }
+ }
+ tools[idx] = toolMap
+ }
+ req["tools"] = tools
+ case map[string]any:
+ normalizedTools := make(map[string]any, len(tools))
+ for name, value := range tools {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized == "" {
+ normalized = name
+ }
+ if toolMap, ok := value.(map[string]any); ok {
+ if toolName, ok := toolMap["name"].(string); ok {
+ mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
+ if mappedName != "" && mappedName != toolName {
+ toolMap["name"] = mappedName
+ }
+ } else if normalized != name {
+ toolMap["name"] = normalized
+ }
+ normalizedTools[normalized] = toolMap
+ continue
+ }
+ normalizedTools[normalized] = value
+ }
+ req["tools"] = normalizedTools
+ }
+ } else {
+ req["tools"] = []any{}
+ }
+
+ if messages, ok := req["messages"].([]any); ok {
+ for _, msg := range messages {
+ msgMap, ok := msg.(map[string]any)
+ if !ok {
+ continue
+ }
+ content, ok := msgMap["content"].([]any)
+ if !ok {
+ continue
+ }
+ for _, block := range content {
+ blockMap, ok := block.(map[string]any)
+ if !ok {
+ continue
+ }
+ if blockType, _ := blockMap["type"].(string); blockType != "tool_use" {
+ continue
+ }
+ if name, ok := blockMap["name"].(string); ok {
+ normalized := normalizeToolNameForClaude(name, toolNameMap)
+ if normalized != "" && normalized != name {
+ blockMap["name"] = normalized
+ }
+ }
+ }
+ }
+ }
+
+ if opts.stripSystemCacheControl {
+ if system, ok := req["system"]; ok {
+ _ = stripCacheControlFromSystemBlocks(system)
+ }
+ }
+
+ if opts.injectMetadata && opts.metadataUserID != "" {
+ metadata, ok := req["metadata"].(map[string]any)
+ if !ok {
+ metadata = map[string]any{}
+ req["metadata"] = metadata
+ }
+ if existing, ok := metadata["user_id"].(string); !ok || existing == "" {
+ metadata["user_id"] = opts.metadataUserID
+ }
+ }
+
+ if _, ok := req["temperature"]; ok {
+ delete(req, "temperature")
+ }
+ if _, ok := req["tool_choice"]; ok {
+ delete(req, "tool_choice")
+ }
+
+ newBody, err := json.Marshal(req)
+ if err != nil {
+ return body, modelID, toolNameMap
+ }
+ return newBody, modelID, toolNameMap
+}
+
+func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
+ if parsed == nil || fp == nil || fp.ClientID == "" {
+ return ""
+ }
+ if parsed.MetadataUserID != "" {
+ return ""
+ }
+ accountUUID := account.GetExtraString("account_uuid")
+ if accountUUID == "" {
+ return ""
+ }
+ sessionHash := s.GenerateSessionHash(parsed)
+ sessionID := uuid.NewString()
+ if sessionHash != "" {
+ seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
+ sessionID = generateSessionUUID(seed)
+ }
+ return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
+}
+
+func generateSessionUUID(seed string) string {
+ if seed == "" {
+ return uuid.NewString()
+ }
+ hash := sha256.Sum256([]byte(seed))
+ bytes := hash[:16]
+ bytes[6] = (bytes[6] & 0x0f) | 0x40
+ bytes[8] = (bytes[8] & 0x3f) | 0x80
+ return fmt.Sprintf("%x-%x-%x-%x-%x",
+ bytes[0:4], bytes[4:6], bytes[6:8], bytes[8:10], bytes[10:16])
+}
+
// SelectAccount selects an account (sticky session + priority)
func (s *GatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -1906,21 +2200,36 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
+ originalModel := reqModel
+ var toolNameMap map[string]string
- // Smart-inject the Claude Code system prompt (only OAuth/SetupToken accounts need it)
- // Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not yet contain the Claude Code prompt
- if account.IsOAuth() &&
- !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
- !strings.Contains(strings.ToLower(reqModel), "haiku") &&
- !systemIncludesClaudeCodePrompt(parsed.System) {
- body = injectClaudeCodePrompt(body, parsed.System)
+ if account.IsOAuth() {
+ // Smart-inject the Claude Code system prompt (only OAuth/SetupToken accounts need it)
+ // Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not yet contain the Claude Code prompt
+ if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
+ !strings.Contains(strings.ToLower(reqModel), "haiku") &&
+ !systemIncludesClaudeCodePrompt(parsed.System) {
+ body = injectClaudeCodePrompt(body, parsed.System)
+ }
+
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
+ if s.identityService != nil {
+ fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
+ if err == nil && fp != nil {
+ if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
+ normalizeOpts.injectMetadata = true
+ normalizeOpts.metadataUserID = metadataUserID
+ }
+ }
+ }
+
+ body, reqModel, toolNameMap = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
// Enforce the cache_control block count limit (at most 4)
body = enforceCacheControlLimit(body)
// Apply model mapping (apikey-type accounts only)
- originalModel := reqModel
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1948,10 +2257,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// Build the upstream request (it must be rebuilt on each retry because the body has to be re-read)
- upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
+ upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
-
if err != nil {
return nil, err
}
@@ -2029,7 +2337,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
+ retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -2061,7 +2369,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
- retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel)
+ retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
@@ -2278,7 +2586,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
- streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
+ streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -2291,7 +2599,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
- usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
+ usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
if err != nil {
return nil, err
}
@@ -2308,7 +2616,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
-func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
+func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -2377,6 +2685,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
+ if tokenType == "oauth" {
+ applyClaudeOAuthHeaderDefaults(req, reqStream)
+ }
// Handle the anthropic-beta header (OAuth accounts need special handling)
if tokenType == "oauth" {
@@ -2459,6 +2770,26 @@ func defaultAPIKeyBetaHeader(body []byte) string {
return claude.APIKeyBetaHeader
}
+func applyClaudeOAuthHeaderDefaults(req *http.Request, isStream bool) {
+ if req == nil {
+ return
+ }
+ if req.Header.Get("accept") == "" {
+ req.Header.Set("accept", "application/json")
+ }
+ for key, value := range claude.DefaultHeaders {
+ if value == "" {
+ continue
+ }
+ if req.Header.Get(key) == "" {
+ req.Header.Set(key, value)
+ }
+ }
+ if isStream && req.Header.Get("x-stainless-helper-method") == "" {
+ req.Header.Set("x-stainless-helper-method", "stream")
+ }
+}
+
func truncateForLog(b []byte, maxBytes int) string {
if maxBytes <= 0 {
maxBytes = 2048
@@ -2739,7 +3070,7 @@ type streamingResult struct {
clientDisconnect bool // whether the client disconnected during streaming
}
-func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*streamingResult, error) {
+func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -2832,6 +3163,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
+ rewriteTools := account.IsOAuth()
clientDisconnected := false // client-disconnect flag; keep reading upstream after a disconnect to obtain complete usage
for {
@@ -2873,11 +3205,14 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
// Extract data from SSE line (supports both "data: " and "data:" formats)
var data string
if sseDataRe.MatchString(line) {
- data = sseDataRe.ReplaceAllString(line, "")
// If a model mapping exists, replace the model field in the response
if needModelReplace {
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
+ if rewriteTools {
+ line = s.replaceToolNamesInSSELine(line, toolNameMap)
+ }
+ data = sseDataRe.ReplaceAllString(line, "")
}
// Write to the client (handles data and non-data lines uniformly)
@@ -2960,6 +3295,61 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData)
}
+func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
+ switch v := value.(type) {
+ case map[string]any:
+ changed := false
+ if blockType, _ := v["type"].(string); blockType == "tool_use" {
+ if name, ok := v["name"].(string); ok {
+ mapped := normalizeToolNameForOpenCode(name, toolNameMap)
+ if mapped != name {
+ v["name"] = mapped
+ changed = true
+ }
+ }
+ }
+ for _, item := range v {
+ if rewriteToolNamesInValue(item, toolNameMap) {
+ changed = true
+ }
+ }
+ return changed
+ case []any:
+ changed := false
+ for _, item := range v {
+ if rewriteToolNamesInValue(item, toolNameMap) {
+ changed = true
+ }
+ }
+ return changed
+ default:
+ return false
+ }
+}
+
+func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
+ if !sseDataRe.MatchString(line) {
+ return line
+ }
+ data := sseDataRe.ReplaceAllString(line, "")
+ if data == "" || data == "[DONE]" {
+ return line
+ }
+
+ var event map[string]any
+ if err := json.Unmarshal([]byte(data), &event); err != nil {
+ return line
+ }
+ if !rewriteToolNamesInValue(event, toolNameMap) {
+ return line
+ }
+ newData, err := json.Marshal(event)
+ if err != nil {
+ return line
+ }
+ return "data: " + string(newData)
+}
+
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
// Parse message_start for input tokens (standard Claude API format)
var msgStart struct {
@@ -3001,7 +3391,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
-func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) {
+func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3022,6 +3412,9 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
+ if account.IsOAuth() {
+ body = s.replaceToolNamesInResponseBody(body, toolNameMap)
+ }
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -3059,6 +3452,24 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
return newBody
}
+func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap map[string]string) []byte {
+ if len(body) == 0 {
+ return body
+ }
+ var resp map[string]any
+ if err := json.Unmarshal(body, &resp); err != nil {
+ return body
+ }
+ if !rewriteToolNamesInValue(resp, toolNameMap) {
+ return body
+ }
+ newBody, err := json.Marshal(resp)
+ if err != nil {
+ return body
+ }
+ return newBody
+}
+
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
@@ -3224,6 +3635,11 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
+ if account.IsOAuth() {
+ normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
+ body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
+ }
+
// Antigravity accounts do not support count_tokens forwarding; return a zero count directly
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
@@ -3412,6 +3828,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
+ if tokenType == "oauth" {
+ applyClaudeOAuthHeaderDefaults(req, false)
+ }
// OAuth accounts: handle the anthropic-beta header
if tokenType == "oauth" {
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index 1ffa8057..4ab1ab96 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -24,13 +24,13 @@ var (
// Default fingerprint values (used when the client does not provide any)
var defaultFingerprint = Fingerprint{
- UserAgent: "claude-cli/2.0.62 (external, cli)",
+ UserAgent: "claude-cli/2.1.2 (external, cli)",
StainlessLang: "js",
- StainlessPackageVersion: "0.52.0",
+ StainlessPackageVersion: "0.70.0",
StainlessOS: "Linux",
StainlessArch: "x64",
StainlessRuntime: "node",
- StainlessRuntimeVersion: "v22.14.0",
+ StainlessRuntimeVersion: "v24.3.0",
}
// Fingerprint represents account fingerprint data
@@ -230,7 +230,7 @@ func generateUUIDFromSeed(seed string) string {
}
// parseUserAgentVersion parses the user-agent version number
-// e.g. claude-cli/2.0.62 -> (2, 0, 62)
+// e.g. claude-cli/2.1.2 -> (2, 1, 2)
func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
// Match the xxx/x.y.z format
matches := userAgentVersionRegex.FindStringSubmatch(ua)
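generateSessionUUID gives each account/session-hash pair a stable, UUIDv4-shaped ID, so retries of the same logical session report the same metadata user_id upstream. A standalone sketch of the derivation; the seed follows the "<account.ID>::<sessionHash>" format used by buildOAuthMetadataUserID, with made-up values:

package main

import (
	"crypto/sha256"
	"fmt"
)

// sessionUUIDFromSeed hashes the seed, keeps 16 bytes, and forces the
// version-4 and RFC 4122 variant bits so the result looks like a random
// UUID while remaining deterministic for a given seed.
func sessionUUIDFromSeed(seed string) string {
	hash := sha256.Sum256([]byte(seed))
	b := hash[:16]
	b[6] = (b[6] & 0x0f) | 0x40 // version 4
	b[8] = (b[8] & 0x3f) | 0x80 // RFC 4122 variant
	return fmt.Sprintf("%x-%x-%x-%x-%x", b[0:4], b[4:6], b[6:8], b[8:10], b[10:16])
}

func main() {
	seed := "42::0f3a9c" // hypothetical "<account.ID>::<sessionHash>"
	fmt.Println(sessionUUIDFromSeed(seed))
	fmt.Println(sessionUUIDFromSeed(seed)) // identical output: same seed, same UUID
}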
From b8c48fb4775785e4bb607585d2f77fde03444fcc Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Thu, 15 Jan 2026 19:17:07 +0800
Subject: [PATCH 008/155] fix(gateway): differentiate the Claude Code OAuth adaptation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/handler/gateway_handler.go | 3 +
backend/internal/pkg/claude/constants.go | 4 +
backend/internal/service/gateway_service.go | 110 +++++++++++++++-----
3 files changed, 90 insertions(+), 27 deletions(-)
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index b60618a8..91d590bf 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -707,6 +707,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
+ // Check whether this is a Claude Code client and record it in the context
+ SetClaudeCodeClientContext(c, body)
+
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index 15144881..f60412c2 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -9,11 +9,15 @@ const (
BetaClaudeCode = "claude-code-20250219"
BetaInterleavedThinking = "interleaved-thinking-2025-05-14"
BetaFineGrainedToolStreaming = "fine-grained-tool-streaming-2025-05-14"
+ BetaTokenCounting = "token-counting-2024-11-01"
)
// DefaultBetaHeader is the default anthropic-beta header of the Claude Code client
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
+// CountTokensBetaHeader is the anthropic-beta header used for count_tokens requests
+const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
+
// HaikuBetaHeader is the anthropic-beta header used for Haiku models (no claude-code beta needed)
const HaikuBetaHeader = BetaOAuth + "," + BetaInterleavedThinking
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 899a0fc5..93dc59dc 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -65,6 +65,8 @@ var (
toolPrefixRe = regexp.MustCompile(`(?i)^(?:oc_|mcp_)`)
toolNameBoundaryRe = regexp.MustCompile(`[^a-zA-Z0-9]+`)
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
+ toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
+ modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -1941,6 +1943,16 @@ func isClaudeCodeClient(userAgent string, metadataUserID string) bool {
return claudeCliUserAgentRe.MatchString(userAgent)
}
+func isClaudeCodeRequest(ctx context.Context, c *gin.Context, parsed *ParsedRequest) bool {
+ if IsClaudeCodeClient(ctx) {
+ return true
+ }
+ if parsed == nil || c == nil {
+ return false
+ }
+ return isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID)
+}
+
// systemIncludesClaudeCodePrompt checks whether system already contains the Claude Code prompt
// Prefix matching is used to support several variants (standard, Agent SDK, etc.)
func systemIncludesClaudeCodePrompt(system any) bool {
@@ -2203,11 +2215,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
originalModel := reqModel
var toolNameMap map[string]string
- if account.IsOAuth() {
+ isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
+ shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
+
+ if shouldMimicClaudeCode {
// Smart-inject the Claude Code system prompt (only OAuth/SetupToken accounts need it)
- // Conditions: 1) OAuth/SetupToken account 2) not a Claude Code client 3) not a Haiku model 4) system does not yet contain the Claude Code prompt
- if !isClaudeCodeClient(c.GetHeader("User-Agent"), parsed.MetadataUserID) &&
- !strings.Contains(strings.ToLower(reqModel), "haiku") &&
+ if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
!systemIncludesClaudeCodePrompt(parsed.System) {
body = injectClaudeCodePrompt(body, parsed.System)
}
@@ -2257,7 +2271,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// Build the upstream request (it must be rebuilt on each retry because the body has to be re-read)
- upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream)
+ upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
// Capture upstream request body for ops retry of this attempt.
c.Set(OpsUpstreamRequestBodyKey, string(body))
if err != nil {
@@ -2337,7 +2351,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// also downgrade tool_use/tool_result blocks to text.
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream)
+ retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -2369,7 +2383,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed {
log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID)
filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body)
- retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream)
+ retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if buildErr2 == nil {
retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency)
if retryErr2 == nil {
@@ -2586,7 +2600,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var firstTokenMs *int
var clientDisconnect bool
if reqStream {
- streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap)
+ streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -2599,7 +2613,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
firstTokenMs = streamResult.firstTokenMs
clientDisconnect = streamResult.clientDisconnect
} else {
- usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap)
+ usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel, toolNameMap, shouldMimicClaudeCode)
if err != nil {
return nil, err
}
@@ -2616,7 +2630,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}, nil
}
-func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool) (*http.Request, error) {
+func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPIURL
if account.Type == AccountTypeAPIKey {
@@ -2632,7 +2646,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// OAuth accounts: apply the unified fingerprint
var fingerprint *Fingerprint
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
// 1. Get or create the fingerprint (includes a randomly generated ClientID)
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err != nil {
@@ -2685,12 +2699,12 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
// Handle the anthropic-beta header (OAuth accounts need special handling)
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -3070,7 +3084,7 @@ type streamingResult struct {
clientDisconnect bool // whether the client disconnected during streaming
}
-func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string) (*streamingResult, error) {
+func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*streamingResult, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3163,7 +3177,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
}
needModelReplace := originalModel != mappedModel
- rewriteTools := account.IsOAuth()
+ rewriteTools := mimicClaudeCode
clientDisconnected := false // client-disconnect flag; keep reading upstream after a disconnect to obtain complete usage
for {
@@ -3327,6 +3341,37 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
}
}
+func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
+ if text == "" {
+ return text
+ }
+ output := toolNameFieldRe.ReplaceAllStringFunc(text, func(match string) string {
+ submatches := toolNameFieldRe.FindStringSubmatch(match)
+ if len(submatches) < 2 {
+ return match
+ }
+ name := submatches[1]
+ mapped := normalizeToolNameForOpenCode(name, toolNameMap)
+ if mapped == name {
+ return match
+ }
+ return strings.Replace(match, name, mapped, 1)
+ })
+ output = modelFieldRe.ReplaceAllStringFunc(output, func(match string) string {
+ submatches := modelFieldRe.FindStringSubmatch(match)
+ if len(submatches) < 2 {
+ return match
+ }
+ model := submatches[1]
+ mapped := claude.DenormalizeModelID(model)
+ if mapped == model {
+ return match
+ }
+ return strings.Replace(match, model, mapped, 1)
+ })
+ return output
+}
+
func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[string]string) string {
if !sseDataRe.MatchString(line) {
return line
@@ -3338,7 +3383,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
var event map[string]any
if err := json.Unmarshal([]byte(data), &event); err != nil {
- return line
+ replaced := replaceToolNamesInText(data, toolNameMap)
+ if replaced == data {
+ return line
+ }
+ return "data: " + replaced
}
if !rewriteToolNamesInValue(event, toolNameMap) {
return line
@@ -3391,7 +3440,7 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
}
}
-func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string) (*ClaudeUsage, error) {
+func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string, toolNameMap map[string]string, mimicClaudeCode bool) (*ClaudeUsage, error) {
// Update the 5h window state
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
@@ -3412,7 +3461,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
- if account.IsOAuth() {
+ if mimicClaudeCode {
body = s.replaceToolNamesInResponseBody(body, toolNameMap)
}
@@ -3458,7 +3507,11 @@ func (s *GatewayService) replaceToolNamesInResponseBody(body []byte, toolNameMap
}
var resp map[string]any
if err := json.Unmarshal(body, &resp); err != nil {
- return body
+ replaced := replaceToolNamesInText(string(body), toolNameMap)
+ if replaced == string(body) {
+ return body
+ }
+ return []byte(replaced)
}
if !rewriteToolNamesInValue(resp, toolNameMap) {
return body
@@ -3635,7 +3688,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
- if account.IsOAuth() {
+ isClaudeCode := isClaudeCodeRequest(ctx, c, parsed)
+ shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
+
+ if shouldMimicClaudeCode {
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
body, reqModel, _ = normalizeClaudeOAuthRequestBody(body, reqModel, normalizeOpts)
}
@@ -3666,7 +3722,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// Build the upstream request
- upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
+ upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel, shouldMimicClaudeCode)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
@@ -3699,7 +3755,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocksForRetry(body)
- retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
+ retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
@@ -3764,7 +3820,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest builds the upstream count_tokens request
-func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
+func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) {
// Determine the target URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeAPIKey {
@@ -3779,7 +3835,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth accounts: apply the unified fingerprint and rewrite the userID
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if err == nil {
accountUUID := account.GetExtraString("account_uuid")
@@ -3814,7 +3870,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth accounts: apply the fingerprint to the request headers
- if account.IsOAuth() && s.identityService != nil {
+ if account.IsOAuth() && mimicClaudeCode && s.identityService != nil {
fp, _ := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
if fp != nil {
s.identityService.ApplyFingerprint(req, fp)
@@ -3828,13 +3884,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
if req.Header.Get("anthropic-version") == "" {
req.Header.Set("anthropic-version", "2023-06-01")
}
- if tokenType == "oauth" {
+ if tokenType == "oauth" && mimicClaudeCode {
applyClaudeOAuthHeaderDefaults(req, false)
}
// OAuth accounts: handle the anthropic-beta header
- if tokenType == "oauth" {
- req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
+ if tokenType == "oauth" && mimicClaudeCode {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
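Everything in this patch hangs off a single predicate: full Claude Code mimicry (fingerprint, forced beta header, tool renaming, response rewriting) applies only to OAuth accounts serving a client that is not already Claude Code. A trivial sketch of the rule:

package main

import "fmt"

// shouldMimicClaudeCode mirrors the predicate introduced above: genuine
// Claude Code traffic and API-key accounts are passed through untouched.
func shouldMimicClaudeCode(isOAuth, isClaudeCodeClient bool) bool {
	return isOAuth && !isClaudeCodeClient
}

func main() {
	fmt.Println(shouldMimicClaudeCode(true, false))  // true: e.g. OpenCode on an OAuth account
	fmt.Println(shouldMimicClaudeCode(true, true))   // false: genuine Claude Code client
	fmt.Println(shouldMimicClaudeCode(false, false)) // false: API-key account
}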
From 0962ba43c0fcc517225d716b056cc3dd3d71125f Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Fri, 16 Jan 2026 00:41:29 +0800
Subject: [PATCH 009/155] fix(gateway): complete non-Claude Code OAuth compatibility
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/pkg/claude/constants.go | 6 +
backend/internal/service/account.go | 16 ++
backend/internal/service/gateway_service.go | 239 +++++++++++++++++---
3 files changed, 232 insertions(+), 29 deletions(-)
diff --git a/backend/internal/pkg/claude/constants.go b/backend/internal/pkg/claude/constants.go
index f60412c2..0c6e9b4c 100644
--- a/backend/internal/pkg/claude/constants.go
+++ b/backend/internal/pkg/claude/constants.go
@@ -15,6 +15,12 @@ const (
// DefaultBetaHeader is the default anthropic-beta header of the Claude Code client
const DefaultBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaFineGrainedToolStreaming
+// MessageBetaHeaderNoTools is the beta header for /v1/messages when the request has no tools
+const MessageBetaHeaderNoTools = BetaOAuth + "," + BetaInterleavedThinking
+
+// MessageBetaHeaderWithTools is the beta header for /v1/messages when the request has tools
+const MessageBetaHeaderWithTools = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking
+
// CountTokensBetaHeader is the anthropic-beta header used for count_tokens requests
const CountTokensBetaHeader = BetaClaudeCode + "," + BetaOAuth + "," + BetaInterleavedThinking + "," + BetaTokenCounting
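Which of the two message headers is sent depends on whether the request actually carries tools; the requestHasTools helper added later in this patch treats an empty tools array or object as no tools. A sketch of the selection using gjson (the <oauth-beta> placeholder stands in for BetaOAuth, whose literal value is not shown in this excerpt):

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
)

const (
	withTools = "claude-code-20250219,<oauth-beta>,interleaved-thinking-2025-05-14"
	noTools   = "<oauth-beta>,interleaved-thinking-2025-05-14"
)

// hasTools mirrors requestHasTools below: tools must exist and be a
// non-empty array or object.
func hasTools(body []byte) bool {
	tools := gjson.GetBytes(body, "tools")
	switch {
	case !tools.Exists():
		return false
	case tools.IsArray():
		return len(tools.Array()) > 0
	case tools.IsObject():
		return len(tools.Map()) > 0
	}
	return false
}

func main() {
	for _, body := range []string{`{"tools":[]}`, `{"tools":[{"name":"Bash"}]}`} {
		header := noTools
		if hasTools([]byte(body)) {
			header = withTools
		}
		fmt.Println(body, "->", header)
	}
}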
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 0d7a9cf9..9f965682 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -381,6 +381,22 @@ func (a *Account) GetExtraString(key string) string {
return ""
}
+func (a *Account) GetClaudeUserID() string {
+ if v := strings.TrimSpace(a.GetExtraString("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetExtraString("anthropic_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("claude_user_id")); v != "" {
+ return v
+ }
+ if v := strings.TrimSpace(a.GetCredential("anthropic_user_id")); v != "" {
+ return v
+ }
+ return ""
+}
+
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
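GetClaudeUserID is a four-step fallback: account extras are consulted before credentials, and claude_user_id beats anthropic_user_id within each source. The pattern in isolation, with hypothetical field values:

package main

import (
	"fmt"
	"strings"
)

// firstNonEmpty returns the first candidate that is non-empty after
// trimming, matching the lookup order in GetClaudeUserID above.
func firstNonEmpty(candidates ...string) string {
	for _, c := range candidates {
		if v := strings.TrimSpace(c); v != "" {
			return v
		}
	}
	return ""
}

func main() {
	// Hypothetical account: only the credential-level anthropic_user_id is set.
	extraClaude, extraAnthropic := "", ""
	credClaude, credAnthropic := "", " user_abc "
	fmt.Println(firstNonEmpty(extraClaude, extraAnthropic, credClaude, credAnthropic)) // user_abc
}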
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 93dc59dc..71ad0d00 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -67,6 +67,9 @@ var (
toolNameCamelRe = regexp.MustCompile(`([a-z0-9])([A-Z])`)
toolNameFieldRe = regexp.MustCompile(`"name"\s*:\s*"([^"]+)"`)
modelFieldRe = regexp.MustCompile(`"model"\s*:\s*"([^"]+)"`)
+ toolDescAbsPathRe = regexp.MustCompile(`/\/?(?:home|Users|tmp|var|opt|usr|etc)\/[^\s,\)"'\]]+`)
+ toolDescWinPathRe = regexp.MustCompile(`(?i)[A-Z]:\\[^\s,\)"'\]]+`)
+ opencodeTextRe = regexp.MustCompile(`(?i)opencode`)
claudeToolNameOverrides = map[string]string{
"bash": "Bash",
@@ -470,6 +473,22 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string {
}
func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
+ if name == "" {
+ return name
+ }
+ stripped := stripToolPrefix(name)
+ if cache != nil {
+ if mapped, ok := cache[stripped]; ok {
+ return mapped
+ }
+ }
+ if mapped, ok := openCodeToolOverrides[stripped]; ok {
+ return mapped
+ }
+ return toSnakeCase(stripped)
+}
+
+func normalizeParamNameForOpenCode(name string, cache map[string]string) string {
if name == "" {
return name
}
@@ -478,10 +497,63 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string {
return mapped
}
}
- if mapped, ok := openCodeToolOverrides[name]; ok {
- return mapped
+ return name
+}
+
+func sanitizeOpenCodeText(text string) string {
+ if text == "" {
+ return text
+ }
+ text = strings.ReplaceAll(text, "OpenCode", "Claude Code")
+ text = opencodeTextRe.ReplaceAllString(text, "Claude")
+ return text
+}
+
+func sanitizeToolDescription(description string) string {
+ if description == "" {
+ return description
+ }
+ description = toolDescAbsPathRe.ReplaceAllString(description, "[path]")
+ description = toolDescWinPathRe.ReplaceAllString(description, "[path]")
+ return sanitizeOpenCodeText(description)
+}
+
+func normalizeToolInputSchema(inputSchema any, cache map[string]string) {
+ schema, ok := inputSchema.(map[string]any)
+ if !ok {
+ return
+ }
+ properties, ok := schema["properties"].(map[string]any)
+ if !ok {
+ return
+ }
+
+ newProperties := make(map[string]any, len(properties))
+ for key, value := range properties {
+ snakeKey := toSnakeCase(key)
+ newProperties[snakeKey] = value
+ if snakeKey != key && cache != nil {
+ cache[snakeKey] = key
+ }
+ }
+ schema["properties"] = newProperties
+
+ if required, ok := schema["required"].([]any); ok {
+ newRequired := make([]any, 0, len(required))
+ for _, item := range required {
+ name, ok := item.(string)
+ if !ok {
+ newRequired = append(newRequired, item)
+ continue
+ }
+ snakeName := toSnakeCase(name)
+ newRequired = append(newRequired, snakeName)
+ if snakeName != name && cache != nil {
+ cache[snakeName] = name
+ }
+ }
+ schema["required"] = newRequired
}
- return toSnakeCase(name)
}
func stripCacheControlFromSystemBlocks(system any) bool {
@@ -498,9 +570,6 @@ func stripCacheControlFromSystemBlocks(system any) bool {
if _, exists := block["cache_control"]; !exists {
continue
}
- if text, ok := block["text"].(string); ok && text == claudeCodeSystemPrompt {
- continue
- }
delete(block, "cache_control")
changed = true
}
@@ -518,6 +587,34 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolNameMap := make(map[string]string)
+ if system, ok := req["system"]; ok {
+ switch v := system.(type) {
+ case string:
+ sanitized := sanitizeOpenCodeText(v)
+ if sanitized != v {
+ req["system"] = sanitized
+ }
+ case []any:
+ for _, item := range v {
+ block, ok := item.(map[string]any)
+ if !ok {
+ continue
+ }
+ if blockType, _ := block["type"].(string); blockType != "text" {
+ continue
+ }
+ text, ok := block["text"].(string)
+ if !ok || text == "" {
+ continue
+ }
+ sanitized := sanitizeOpenCodeText(text)
+ if sanitized != text {
+ block["text"] = sanitized
+ }
+ }
+ }
+ }
+
if rawModel, ok := req["model"].(string); ok {
normalized := claude.NormalizeModelID(rawModel)
if normalized != rawModel {
@@ -540,6 +637,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
toolMap["name"] = normalized
}
}
+ if desc, ok := toolMap["description"].(string); ok {
+ sanitized := sanitizeToolDescription(desc)
+ if sanitized != desc {
+ toolMap["description"] = sanitized
+ }
+ }
+ if schema, ok := toolMap["input_schema"]; ok {
+ normalizeToolInputSchema(schema, toolNameMap)
+ }
tools[idx] = toolMap
}
req["tools"] = tools
@@ -551,13 +657,15 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
normalized = name
}
if toolMap, ok := value.(map[string]any); ok {
- if toolName, ok := toolMap["name"].(string); ok {
- mappedName := normalizeToolNameForClaude(toolName, toolNameMap)
- if mappedName != "" && mappedName != toolName {
- toolMap["name"] = mappedName
+ toolMap["name"] = normalized
+ if desc, ok := toolMap["description"].(string); ok {
+ sanitized := sanitizeToolDescription(desc)
+ if sanitized != desc {
+ toolMap["description"] = sanitized
}
- } else if normalized != name {
- toolMap["name"] = normalized
+ }
+ if schema, ok := toolMap["input_schema"]; ok {
+ normalizeToolInputSchema(schema, toolNameMap)
}
normalizedTools[normalized] = toolMap
continue
@@ -630,7 +738,7 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account *Account, fp *Fingerprint) string {
- if parsed == nil || fp == nil || fp.ClientID == "" {
+ if parsed == nil || account == nil {
return ""
}
if parsed.MetadataUserID != "" {
@@ -640,13 +748,22 @@ func (s *GatewayService) buildOAuthMetadataUserID(parsed *ParsedRequest, account
if accountUUID == "" {
return ""
}
+
+ userID := strings.TrimSpace(account.GetClaudeUserID())
+ if userID == "" && fp != nil {
+ userID = fp.ClientID
+ }
+ if userID == "" {
+ return ""
+ }
+
sessionHash := s.GenerateSessionHash(parsed)
sessionID := uuid.NewString()
if sessionHash != "" {
seed := fmt.Sprintf("%d::%s", account.ID, sessionHash)
sessionID = generateSessionUUID(seed)
}
- return fmt.Sprintf("user_%s_account_%s_session_%s", fp.ClientID, accountUUID, sessionID)
+ return fmt.Sprintf("user_%s_account_%s_session_%s", userID, accountUUID, sessionID)
}
func generateSessionUUID(seed string) string {
@@ -2705,7 +2822,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta header(OAuth账号需要特殊处理)
if tokenType == "oauth" && mimicClaudeCode {
- req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
+ if requestHasTools(body) {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ } else {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
@@ -2776,6 +2897,20 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
+func requestHasTools(body []byte) bool {
+ tools := gjson.GetBytes(body, "tools")
+ if !tools.Exists() {
+ return false
+ }
+ if tools.IsArray() {
+ return len(tools.Array()) > 0
+ }
+ if tools.IsObject() {
+ return len(tools.Map()) > 0
+ }
+ return false
+}
+
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
@@ -3309,6 +3444,45 @@ func (s *GatewayService) replaceModelInSSELine(line, fromModel, toModel string)
return "data: " + string(newData)
}
+func rewriteParamKeysInValue(value any, cache map[string]string) (any, bool) {
+ switch v := value.(type) {
+ case map[string]any:
+ changed := false
+ rewritten := make(map[string]any, len(v))
+ for key, item := range v {
+ newKey := normalizeParamNameForOpenCode(key, cache)
+ newItem, childChanged := rewriteParamKeysInValue(item, cache)
+ if childChanged {
+ changed = true
+ }
+ if newKey != key {
+ changed = true
+ }
+ rewritten[newKey] = newItem
+ }
+ if !changed {
+ return value, false
+ }
+ return rewritten, true
+ case []any:
+ changed := false
+ rewritten := make([]any, len(v))
+ for idx, item := range v {
+ newItem, childChanged := rewriteParamKeysInValue(item, cache)
+ if childChanged {
+ changed = true
+ }
+ rewritten[idx] = newItem
+ }
+ if !changed {
+ return value, false
+ }
+ return rewritten, true
+ default:
+ return value, false
+ }
+}
+
func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
switch v := value.(type) {
case map[string]any:
@@ -3321,6 +3495,15 @@ func rewriteToolNamesInValue(value any, toolNameMap map[string]string) bool {
changed = true
}
}
+ if input, ok := v["input"].(map[string]any); ok {
+ rewrittenInput, inputChanged := rewriteParamKeysInValue(input, toolNameMap)
+ if inputChanged {
+ if m, ok := rewrittenInput.(map[string]any); ok {
+ v["input"] = m
+ changed = true
+ }
+ }
+ }
}
for _, item := range v {
if rewriteToolNamesInValue(item, toolNameMap) {
@@ -3369,6 +3552,15 @@ func replaceToolNamesInText(text string, toolNameMap map[string]string) string {
}
return strings.Replace(match, model, mapped, 1)
})
+
+ for mapped, original := range toolNameMap {
+ if mapped == "" || original == "" || mapped == original {
+ continue
+ }
+ output = strings.ReplaceAll(output, "\""+mapped+"\":", "\""+original+"\":")
+ output = strings.ReplaceAll(output, "\\\""+mapped+"\\\":", "\\\""+original+"\\\":")
+ }
+
return output
}
@@ -3381,22 +3573,11 @@ func (s *GatewayService) replaceToolNamesInSSELine(line string, toolNameMap map[
return line
}
- var event map[string]any
- if err := json.Unmarshal([]byte(data), &event); err != nil {
- replaced := replaceToolNamesInText(data, toolNameMap)
- if replaced == data {
- return line
- }
- return "data: " + replaced
- }
- if !rewriteToolNamesInValue(event, toolNameMap) {
+ replaced := replaceToolNamesInText(data, toolNameMap)
+ if replaced == data {
return line
}
- newData, err := json.Marshal(event)
- if err != nil {
- return line
- }
- return "data: " + string(newData)
+ return "data: " + replaced
}
func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) {
From 0c011b889b980ba4626703af4d54e1879cfd3f9c Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Fri, 16 Jan 2026 23:15:52 +0800
Subject: [PATCH 010/155] fix(gateway): add the missing oauth beta for Claude Code OAuth
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/service/gateway_service.go | 34 ++++++++++++++++-----
1 file changed, 27 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 71ad0d00..8b4871c9 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -2820,12 +2820,19 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
applyClaudeOAuthHeaderDefaults(req, reqStream)
}
- // 处理anthropic-beta header(OAuth账号需要特殊处理)
- if tokenType == "oauth" && mimicClaudeCode {
- if requestHasTools(body) {
- req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ // 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
+ if tokenType == "oauth" {
+ if mimicClaudeCode {
+ // 非 Claude Code 客户端:按 Claude Code 规则生成 beta header
+ if requestHasTools(body) {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderWithTools)
+ } else {
+ req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ }
} else {
- req.Header.Set("anthropic-beta", claude.MessageBetaHeaderNoTools)
+ // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
+ clientBetaHeader := req.Header.Get("anthropic-beta")
+ req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, clientBetaHeader))
}
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
@@ -4070,8 +4077,21 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
}
// OAuth 账号:处理 anthropic-beta header
- if tokenType == "oauth" && mimicClaudeCode {
- req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ if tokenType == "oauth" {
+ if mimicClaudeCode {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ } else {
+ clientBetaHeader := req.Header.Get("anthropic-beta")
+ if clientBetaHeader == "" {
+ req.Header.Set("anthropic-beta", claude.CountTokensBetaHeader)
+ } else {
+ beta := s.getBetaHeader(modelID, clientBetaHeader)
+ if !strings.Contains(beta, claude.BetaTokenCounting) {
+ beta = beta + "," + claude.BetaTokenCounting
+ }
+ req.Header.Set("anthropic-beta", beta)
+ }
+ }
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
From 28e46e0e7cd9337a89ce221d3afd8e27baf95168 Mon Sep 17 00:00:00 2001
From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com>
Date: Fri, 16 Jan 2026 23:47:42 +0800
Subject: [PATCH 011/155] fix(gemini): update the Gemini model list configuration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Remove the deprecated 1.5-series models
- Adjust the model priority order (2.0 Flash > 2.5 Flash > 2.5 Pro > 3.0 Preview)
- Sync the model configuration between frontend and backend
- Update the related test cases and the default-model selection logic (a sketch of the fallback follows below)
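For illustration, the selection fallback restated in Go (the shipped code is the Vue find-chain in AccountTestModal.vue; this helper is hypothetical):

    package main

    import "fmt"

    // firstAvailable mirrors the find-chain in AccountTestModal.vue: pick the
    // first preferred model that is actually available, else the first available.
    func firstAvailable(available, preferred []string) string {
        set := make(map[string]bool, len(available))
        for _, id := range available {
            set[id] = true
        }
        for _, id := range preferred {
            if set[id] {
                return id
            }
        }
        if len(available) > 0 {
            return available[0]
        }
        return ""
    }

    func main() {
        preferred := []string{
            "gemini-2.0-flash", "gemini-2.5-flash", "gemini-2.5-pro",
            "gemini-3-flash-preview", "gemini-3-pro-preview",
        }
        available := []string{"gemini-2.5-pro", "gemini-3-pro-preview"}
        fmt.Println(firstAvailable(available, preferred)) // gemini-2.5-pro
    }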
---
backend/internal/pkg/gemini/models.go | 11 +++-----
backend/internal/pkg/geminicli/models.go | 4 +--
.../service/antigravity_model_mapping_test.go | 8 +++---
.../service/gemini_multiplatform_test.go | 2 +-
backend/internal/service/pricing_service.go | 4 +--
.../components/account/AccountTestModal.vue | 5 +++-
.../admin/account/AccountTestModal.vue | 5 +++-
frontend/src/components/keys/UseKeyModal.vue | 26 ++++++++++++-------
frontend/src/composables/useModelWhitelist.ts | 19 +++++++-------
9 files changed, 47 insertions(+), 37 deletions(-)
diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go
index e251c8d8..424e8ddb 100644
--- a/backend/internal/pkg/gemini/models.go
+++ b/backend/internal/pkg/gemini/models.go
@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
- {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
- {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
- {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
- {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
- {Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
- {Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
- {Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
+ {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
+ {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
+ {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
+ {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
}
}
diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go
index 922988c7..08e69886 100644
--- a/backend/internal/pkg/geminicli/models.go
+++ b/backend/internal/pkg/geminicli/models.go
@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
- {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
- {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
+ {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
+ {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.
diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go
index 39000e4f..179a3520 100644
--- a/backend/internal/service/antigravity_model_mapping_test.go
+++ b/backend/internal/service/antigravity_model_mapping_test.go
@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
- {"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
+ {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "gemini-2.5-flash",
},
{
- name: "Gemini透传 - gemini-1.5-pro",
- requestedModel: "gemini-1.5-pro",
+ name: "Gemini透传 - gemini-2.5-pro",
+ requestedModel: "gemini-2.5-pro",
accountMapping: nil,
- expected: "gemini-1.5-pro",
+ expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",
diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go
index 03f5d757..f2ea5859 100644
--- a/backend/internal/service/gemini_multiplatform_test.go
+++ b/backend/internal/service/gemini_multiplatform_test.go
@@ -599,7 +599,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
- Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
+ Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,
diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go
index 392fb65c..0ade72cd 100644
--- a/backend/internal/service/pricing_service.go
+++ b/backend/internal/service/pricing_service.go
@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func normalizeModelNameForPricing(model string) string {
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
- // - publishers/google/models/gemini-1.5-pro
- // - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
+ // - publishers/google/models/gemini-2.5-pro
+ // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")
diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue
index 42f3c1b9..dfa1503e 100644
--- a/frontend/src/components/account/AccountTestModal.vue
+++ b/frontend/src/components/account/AccountTestModal.vue
@@ -292,8 +292,11 @@ const loadAvailableModels = async () => {
if (availableModels.value.length > 0) {
if (props.account.platform === 'gemini') {
const preferred =
+ availableModels.value.find((m) => m.id === 'gemini-2.0-flash') ||
+ availableModels.value.find((m) => m.id === 'gemini-2.5-flash') ||
availableModels.value.find((m) => m.id === 'gemini-2.5-pro') ||
- availableModels.value.find((m) => m.id === 'gemini-3-pro')
+ availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
+ availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
selectedModelId.value = preferred?.id || availableModels.value[0].id
} else {
// Try to select Sonnet as default, otherwise use first model
diff --git a/frontend/src/components/admin/account/AccountTestModal.vue b/frontend/src/components/admin/account/AccountTestModal.vue
index 2cb1c5a5..feb09654 100644
--- a/frontend/src/components/admin/account/AccountTestModal.vue
+++ b/frontend/src/components/admin/account/AccountTestModal.vue
@@ -232,8 +232,11 @@ const loadAvailableModels = async () => {
if (availableModels.value.length > 0) {
if (props.account.platform === 'gemini') {
const preferred =
+ availableModels.value.find((m) => m.id === 'gemini-2.0-flash') ||
+ availableModels.value.find((m) => m.id === 'gemini-2.5-flash') ||
availableModels.value.find((m) => m.id === 'gemini-2.5-pro') ||
- availableModels.value.find((m) => m.id === 'gemini-3-pro')
+ availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') ||
+ availableModels.value.find((m) => m.id === 'gemini-3-pro-preview')
selectedModelId.value = preferred?.id || availableModels.value[0].id
} else {
// Try to select Sonnet as default, otherwise use first model
diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue
index 8075ba70..7f9bd1ed 100644
--- a/frontend/src/components/keys/UseKeyModal.vue
+++ b/frontend/src/components/keys/UseKeyModal.vue
@@ -443,7 +443,7 @@ $env:ANTHROPIC_AUTH_TOKEN="${apiKey}"`
}
function generateGeminiCliContent(baseUrl: string, apiKey: string): FileConfig {
- const model = 'gemini-2.5-pro'
+ const model = 'gemini-2.0-flash'
const modelComment = t('keys.useKeyModal.gemini.modelComment')
let path: string
let content: string
@@ -548,14 +548,22 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
}
}
const geminiModels = {
- 'gemini-3-pro-high': { name: 'Gemini 3 Pro High' },
- 'gemini-3-pro-low': { name: 'Gemini 3 Pro Low' },
- 'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' },
- 'gemini-3-pro-image': { name: 'Gemini 3 Pro Image' },
- 'gemini-3-flash': { name: 'Gemini 3 Flash' },
- 'gemini-2.5-flash-thinking': { name: 'Gemini 2.5 Flash Thinking' },
+ 'gemini-2.0-flash': { name: 'Gemini 2.0 Flash' },
'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
- 'gemini-2.5-flash-lite': { name: 'Gemini 2.5 Flash Lite' }
+ 'gemini-2.5-pro': { name: 'Gemini 2.5 Pro' },
+ 'gemini-3-flash-preview': { name: 'Gemini 3 Flash Preview' },
+ 'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' }
+ }
+
+ const antigravityGeminiModels = {
+ 'gemini-2.5-flash': { name: 'Gemini 2.5 Flash' },
+ 'gemini-2.5-flash-lite': { name: 'Gemini 2.5 Flash Lite' },
+ 'gemini-2.5-flash-thinking': { name: 'Gemini 2.5 Flash Thinking' },
+ 'gemini-3-flash': { name: 'Gemini 3 Flash' },
+ 'gemini-3-pro-low': { name: 'Gemini 3 Pro Low' },
+ 'gemini-3-pro-high': { name: 'Gemini 3 Pro High' },
+ 'gemini-3-pro-preview': { name: 'Gemini 3 Pro Preview' },
+ 'gemini-3-pro-image': { name: 'Gemini 3 Pro Image' }
}
const claudeModels = {
'claude-opus-4-5-thinking': { name: 'Claude Opus 4.5 Thinking' },
@@ -575,7 +583,7 @@ function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: strin
} else if (platform === 'antigravity-gemini') {
provider[platform].npm = '@ai-sdk/google'
provider[platform].name = 'Antigravity (Gemini)'
- provider[platform].models = geminiModels
+ provider[platform].models = antigravityGeminiModels
} else if (platform === 'openai') {
provider[platform].models = openaiModels
}
diff --git a/frontend/src/composables/useModelWhitelist.ts b/frontend/src/composables/useModelWhitelist.ts
index 79900c6e..d4fa2993 100644
--- a/frontend/src/composables/useModelWhitelist.ts
+++ b/frontend/src/composables/useModelWhitelist.ts
@@ -43,13 +43,13 @@ export const claudeModels = [
// Google Gemini
const geminiModels = [
- 'gemini-2.0-flash', 'gemini-2.0-flash-lite-preview', 'gemini-2.0-flash-exp',
- 'gemini-2.0-pro-exp', 'gemini-2.0-flash-thinking-exp',
- 'gemini-2.5-pro-exp-03-25', 'gemini-2.5-pro-preview-03-25',
- 'gemini-3-pro-preview',
- 'gemini-1.5-pro', 'gemini-1.5-pro-latest',
- 'gemini-1.5-flash', 'gemini-1.5-flash-latest', 'gemini-1.5-flash-8b',
- 'gemini-exp-1206'
+ // Keep in sync with backend curated Gemini lists.
+ // This list is intentionally conservative (models commonly available across OAuth/API key).
+ 'gemini-2.0-flash',
+ 'gemini-2.5-flash',
+ 'gemini-2.5-pro',
+ 'gemini-3-flash-preview',
+ 'gemini-3-pro-preview'
]
// 智谱 GLM
@@ -229,9 +229,8 @@ const openaiPresetMappings = [
const geminiPresetMappings = [
{ label: 'Flash 2.0', from: 'gemini-2.0-flash', to: 'gemini-2.0-flash', color: 'bg-blue-100 text-blue-700 hover:bg-blue-200 dark:bg-blue-900/30 dark:text-blue-400' },
- { label: 'Flash Lite', from: 'gemini-2.0-flash-lite-preview', to: 'gemini-2.0-flash-lite-preview', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
- { label: '1.5 Pro', from: 'gemini-1.5-pro', to: 'gemini-1.5-pro', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' },
- { label: '1.5 Flash', from: 'gemini-1.5-flash', to: 'gemini-1.5-flash', color: 'bg-emerald-100 text-emerald-700 hover:bg-emerald-200 dark:bg-emerald-900/30 dark:text-emerald-400' }
+ { label: '2.5 Flash', from: 'gemini-2.5-flash', to: 'gemini-2.5-flash', color: 'bg-indigo-100 text-indigo-700 hover:bg-indigo-200 dark:bg-indigo-900/30 dark:text-indigo-400' },
+ { label: '2.5 Pro', from: 'gemini-2.5-pro', to: 'gemini-2.5-pro', color: 'bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-400' }
]
// =====================
From 8917a3ea8fa4ffa8943e32513b3cee5528ef516d Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Sat, 17 Jan 2026 00:27:36 +0800
Subject: [PATCH 012/155] fix(gateway): fix golangci-lint findings
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/service/gateway_service.go | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 8b4871c9..fb2d40a3 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -439,7 +439,7 @@ func toPascalCase(value string) string {
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
- builder.WriteString(string(runes))
+ _, _ = builder.WriteString(string(runes))
}
return builder.String()
}
@@ -723,12 +723,8 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
- if _, ok := req["temperature"]; ok {
- delete(req, "temperature")
- }
- if _, ok := req["tool_choice"]; ok {
- delete(req, "tool_choice")
- }
+ delete(req, "temperature")
+ delete(req, "tool_choice")
newBody, err := json.Marshal(req)
if err != nil {
From a7165b0f73b86c750c8037d4d5ec41dfa1f461d8 Mon Sep 17 00:00:00 2001
From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com>
Date: Sat, 17 Jan 2026 01:53:51 +0800
Subject: [PATCH 013/155] fix(group): seed missing default groups at startup in SIMPLE mode
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/repository/ent.go | 13 +++
.../repository/simple_mode_default_groups.go | 82 ++++++++++++++++++
...le_mode_default_groups_integration_test.go | 84 +++++++++++++++++++
3 files changed, 179 insertions(+)
create mode 100644 backend/internal/repository/simple_mode_default_groups.go
create mode 100644 backend/internal/repository/simple_mode_default_groups_integration_test.go
diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go
index 8005f114..d7d574e8 100644
--- a/backend/internal/repository/ent.go
+++ b/backend/internal/repository/ent.go
@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
+
+ // SIMPLE 模式:启动时补齐各平台默认分组。
+ // - anthropic/openai/gemini: 确保存在 -default
+ // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
+ if cfg.RunMode == config.RunModeSimple {
+ seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer seedCancel()
+ if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
+ _ = client.Close()
+ return nil, nil, err
+ }
+ }
+
return client, drv.DB(), nil
}
diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go
new file mode 100644
index 00000000..56309184
--- /dev/null
+++ b/backend/internal/repository/simple_mode_default_groups.go
@@ -0,0 +1,82 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
+ if client == nil {
+ return fmt.Errorf("nil ent client")
+ }
+
+ requiredByPlatform := map[string]int{
+ service.PlatformAnthropic: 1,
+ service.PlatformOpenAI: 1,
+ service.PlatformGemini: 1,
+ service.PlatformAntigravity: 2,
+ }
+
+ for platform, minCount := range requiredByPlatform {
+ count, err := client.Group.Query().
+ Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
+ Count(ctx)
+ if err != nil {
+ return fmt.Errorf("count groups for platform %s: %w", platform, err)
+ }
+
+ if platform == service.PlatformAntigravity {
+ if count < minCount {
+ for i := count; i < minCount; i++ {
+ name := fmt.Sprintf("%s-default-%d", platform, i+1)
+ if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
+ return err
+ }
+ }
+ }
+ continue
+ }
+
+ // Non-antigravity platforms: ensure -default exists.
+ name := platform + "-default"
+ if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
+ exists, err := client.Group.Query().
+ Where(group.NameEQ(name), group.DeletedAtIsNil()).
+ Exist(ctx)
+ if err != nil {
+ return fmt.Errorf("check group exists %s: %w", name, err)
+ }
+ if exists {
+ return nil
+ }
+
+ _, err = client.Group.Create().
+ SetName(name).
+ SetDescription("Auto-created default group").
+ SetPlatform(platform).
+ SetStatus(service.StatusActive).
+ SetSubscriptionType(service.SubscriptionTypeStandard).
+ SetRateMultiplier(1.0).
+ SetIsExclusive(false).
+ Save(ctx)
+ if err != nil {
+ if dbent.IsConstraintError(err) {
+ // Concurrent server startups may race on creation; treat as success.
+ return nil
+ }
+ return fmt.Errorf("create default group %s: %w", name, err)
+ }
+ return nil
+}
diff --git a/backend/internal/repository/simple_mode_default_groups_integration_test.go b/backend/internal/repository/simple_mode_default_groups_integration_test.go
new file mode 100644
index 00000000..3327257b
--- /dev/null
+++ b/backend/internal/repository/simple_mode_default_groups_integration_test.go
@@ -0,0 +1,84 @@
+//go:build integration
+
+package repository
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/ent/group"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ client := tx.Client()
+
+ seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
+
+ assertGroupExists := func(name string) {
+ exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
+ require.NoError(t, err)
+ require.True(t, exists, "expected group %s to exist", name)
+ }
+
+ assertGroupExists(service.PlatformAnthropic + "-default")
+ assertGroupExists(service.PlatformOpenAI + "-default")
+ assertGroupExists(service.PlatformGemini + "-default")
+ assertGroupExists(service.PlatformAntigravity + "-default-1")
+ assertGroupExists(service.PlatformAntigravity + "-default-2")
+}
+
+func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ client := tx.Client()
+
+ seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ // Create and then soft-delete an anthropic default group.
+ g, err := client.Group.Create().
+ SetName(service.PlatformAnthropic + "-default").
+ SetPlatform(service.PlatformAnthropic).
+ SetStatus(service.StatusActive).
+ SetSubscriptionType(service.SubscriptionTypeStandard).
+ SetRateMultiplier(1.0).
+ SetIsExclusive(false).
+ Save(seedCtx)
+ require.NoError(t, err)
+
+ _, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
+ require.NoError(t, err)
+
+ require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
+
+ // New active one should exist.
+ count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
+ require.NoError(t, err)
+ require.Equal(t, 1, count)
+}
+
+func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
+ ctx := context.Background()
+ tx := testEntTx(t)
+ client := tx.Client()
+
+ seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
+ defer cancel()
+
+ mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
+ mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
+
+ require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
+
+ count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
+ require.NoError(t, err)
+ require.GreaterOrEqual(t, count, 2)
+}
From ae21db77ecaa9f3fa05e3efe8b3a6b0c2dc47566 Mon Sep 17 00:00:00 2001
From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com>
Date: Sat, 17 Jan 2026 02:31:16 +0800
Subject: [PATCH 014/155] fix(openai): fall back to prompt_cache_key for sticky sessions
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
opencode requests carry no session_id/conversation_id header, which broke sticky sessions. The session hash is now derived header-first, falling back to prompt_cache_key from the body, and a unit test verifies the priority order (see the sketch below).
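A minimal sketch of the resulting priority chain, assuming a plain map-based signature for illustration (the real implementation is the GenerateSessionHash method in the diff below):

    package main

    import (
        "crypto/sha256"
        "encoding/hex"
        "fmt"
        "strings"
    )

    // sessionKey mirrors the priority added in this patch: session_id header
    // first, then conversation_id header, then prompt_cache_key from the body.
    func sessionKey(headers map[string]string, body map[string]any) string {
        if v := strings.TrimSpace(headers["session_id"]); v != "" {
            return v
        }
        if v := strings.TrimSpace(headers["conversation_id"]); v != "" {
            return v
        }
        if v, ok := body["prompt_cache_key"].(string); ok {
            return strings.TrimSpace(v)
        }
        return ""
    }

    func main() {
        // A typical opencode request: no session headers, only prompt_cache_key.
        key := sessionKey(map[string]string{}, map[string]any{"prompt_cache_key": "ses_aaa"})
        sum := sha256.Sum256([]byte(key))
        fmt.Println(hex.EncodeToString(sum[:])) // the sticky-session hash
    }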
---
.../handler/openai_gateway_handler.go | 4 +-
.../service/openai_gateway_service.go | 24 +++++++++--
.../service/openai_gateway_service_test.go | 43 +++++++++++++++++++
3 files changed, 66 insertions(+), 5 deletions(-)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index c4cfabc3..68e67656 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -186,8 +186,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
- // Generate session hash (from header for OpenAI)
- sessionHash := h.gatewayService.GenerateSessionHash(c)
+ // Generate session hash (header first; fallback to prompt_cache_key)
+ sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
const maxAccountSwitches = 3
switchCount := 0
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index c7d94882..a3c4a239 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -133,12 +133,30 @@ func NewOpenAIGatewayService(
}
}
-// GenerateSessionHash generates session hash from header (OpenAI uses session_id header)
-func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
- sessionID := c.GetHeader("session_id")
+// GenerateSessionHash generates a sticky-session hash for OpenAI requests.
+//
+// Priority:
+// 1. Header: session_id
+// 2. Header: conversation_id
+// 3. Body: prompt_cache_key (opencode)
+func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
+ if c == nil {
+ return ""
+ }
+
+ sessionID := strings.TrimSpace(c.GetHeader("session_id"))
+ if sessionID == "" {
+ sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
+ }
+ if sessionID == "" && reqBody != nil {
+ if v, ok := reqBody["prompt_cache_key"].(string); ok {
+ sessionID = strings.TrimSpace(v)
+ }
+ }
if sessionID == "" {
return ""
}
+
hash := sha256.Sum256([]byte(sessionID))
return hex.EncodeToString(hash[:])
}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 42b88b7d..a34b8045 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -49,6 +49,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts
return out, nil
}
+func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+
+ svc := &OpenAIGatewayService{}
+
+ // 1) session_id header wins
+ c.Request.Header.Set("session_id", "sess-123")
+ c.Request.Header.Set("conversation_id", "conv-456")
+ h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
+ if h1 == "" {
+ t.Fatalf("expected non-empty hash")
+ }
+
+ // 2) conversation_id used when session_id absent
+ c.Request.Header.Del("session_id")
+ h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
+ if h2 == "" {
+ t.Fatalf("expected non-empty hash")
+ }
+ if h1 == h2 {
+ t.Fatalf("expected different hashes for different keys")
+ }
+
+ // 3) prompt_cache_key used when both headers absent
+ c.Request.Header.Del("conversation_id")
+ h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
+ if h3 == "" {
+ t.Fatalf("expected non-empty hash")
+ }
+ if h2 == h3 {
+ t.Fatalf("expected different hashes for different keys")
+ }
+
+ // 4) empty when no signals
+ h4 := svc.GenerateSessionHash(c, map[string]any{})
+ if h4 != "" {
+ t.Fatalf("expected empty hash when no signals")
+ }
+}
+
func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) {
now := time.Now()
resetAt := now.Add(10 * time.Minute)
From a61cc2cb249e32e72fd5d0b41d1e3294bda79d75 Mon Sep 17 00:00:00 2001
From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com>
Date: Sat, 17 Jan 2026 11:00:07 +0800
Subject: [PATCH 015/155] fix(openai): strengthen Codex tool filtering and parameter normalization
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- codex_transform: filter out invalid tools; accept both Responses-style and ChatCompletions-style tool formats
- tool_corrector: add fetch tool-name mappings; normalize bash/edit parameter names (a sketch of the rename rule follows below)
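A compact sketch of the rename rule, stated as a generic helper for illustration only (the shipped code spells out each parameter branch in correctToolParameters; this helper does not exist in the codebase):

    package main

    import "fmt"

    // renameIfAbsent captures the pattern this patch applies per parameter:
    // move the value under the canonical key only when that key is absent,
    // and drop the alias key either way.
    func renameIfAbsent(args map[string]any, from, to string) bool {
        v, ok := args[from]
        if !ok {
            return false
        }
        if _, has := args[to]; !has {
            args[to] = v
        }
        delete(args, from)
        return true
    }

    func main() {
        args := map[string]any{"command": "ls", "work_dir": "/tmp"}
        renameIfAbsent(args, "work_dir", "workdir") // bash: work_dir -> workdir
        fmt.Println(args)                           // map[command:ls workdir:/tmp]

        edit := map[string]any{"path": "a.go", "old_string": "x", "new_string": "y"}
        renameIfAbsent(edit, "path", "filePath") // edit params become camelCase
        renameIfAbsent(edit, "old_string", "oldString")
        renameIfAbsent(edit, "new_string", "newString")
        fmt.Println(edit)
    }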
---
.../service/openai_codex_transform.go | 28 +++++--
.../service/openai_codex_transform_test.go | 31 ++++++++
.../internal/service/openai_tool_corrector.go | 77 +++++++++++++++----
.../service/openai_tool_corrector_test.go | 19 +++--
4 files changed, 125 insertions(+), 30 deletions(-)
diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go
index 264bdf95..48c72593 100644
--- a/backend/internal/service/openai_codex_transform.go
+++ b/backend/internal/service/openai_codex_transform.go
@@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
modified := false
- for idx, tool := range tools {
+ validTools := make([]any, 0, len(tools))
+
+ for _, tool := range tools {
toolMap, ok := tool.(map[string]any)
if !ok {
+ // Keep unknown structure as-is to avoid breaking upstream behavior.
+ validTools = append(validTools, tool)
continue
}
toolType, _ := toolMap["type"].(string)
- if strings.TrimSpace(toolType) != "function" {
+ toolType = strings.TrimSpace(toolType)
+ if toolType != "function" {
+ validTools = append(validTools, toolMap)
continue
}
- function, ok := toolMap["function"].(map[string]any)
- if !ok {
+ // OpenAI Responses-style tools use top-level name/parameters.
+ if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" {
+ validTools = append(validTools, toolMap)
+ continue
+ }
+
+ // ChatCompletions-style tools use {type:"function", function:{...}}.
+ functionValue, hasFunction := toolMap["function"]
+ function, ok := functionValue.(map[string]any)
+ if !hasFunction || functionValue == nil || !ok || function == nil {
+ // Drop invalid function tools.
+ modified = true
continue
}
@@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool {
}
}
- tools[idx] = toolMap
+ validTools = append(validTools, toolMap)
}
if modified {
- reqBody["tools"] = tools
+ reqBody["tools"] = validTools
}
return modified
diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go
index 0ff9485a..4cd72ab6 100644
--- a/backend/internal/service/openai_codex_transform_test.go
+++ b/backend/internal/service/openai_codex_transform_test.go
@@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) {
require.False(t, hasID)
}
+func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) {
+ setupCodexCache(t)
+
+ reqBody := map[string]any{
+ "model": "gpt-5.1",
+ "tools": []any{
+ map[string]any{
+ "type": "function",
+ "name": "bash",
+ "description": "desc",
+ "parameters": map[string]any{"type": "object"},
+ },
+ map[string]any{
+ "type": "function",
+ "function": nil,
+ },
+ },
+ }
+
+ applyCodexOAuthTransform(reqBody)
+
+ tools, ok := reqBody["tools"].([]any)
+ require.True(t, ok)
+ require.Len(t, tools, 1)
+
+ first, ok := tools[0].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "function", first["type"])
+ require.Equal(t, "bash", first["name"])
+}
+
func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) {
// 空 input 应保持为空且不触发异常。
setupCodexCache(t)
diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go
index 9c9eab84..f4719275 100644
--- a/backend/internal/service/openai_tool_corrector.go
+++ b/backend/internal/service/openai_tool_corrector.go
@@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{
"executeBash": "bash",
"exec_bash": "bash",
"execBash": "bash",
+
+ // Some clients output generic fetch names.
+ "fetch": "webfetch",
+ "web_fetch": "webfetch",
+ "webFetch": "webfetch",
}
// ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化)
@@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall
// 根据工具名称应用特定的参数修正规则
switch toolName {
case "bash":
- // 移除 workdir 参数(OpenCode 不支持)
- if _, exists := argsMap["workdir"]; exists {
- delete(argsMap, "workdir")
- corrected = true
- log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool")
- }
- if _, exists := argsMap["work_dir"]; exists {
- delete(argsMap, "work_dir")
- corrected = true
- log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool")
+ // OpenCode bash 支持 workdir;有些来源会输出 work_dir。
+ if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir {
+ if workDir, exists := argsMap["work_dir"]; exists {
+ argsMap["workdir"] = workDir
+ delete(argsMap, "work_dir")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool")
+ }
+ } else {
+ if _, exists := argsMap["work_dir"]; exists {
+ delete(argsMap, "work_dir")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool")
+ }
}
case "edit":
- // OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称
- // 这里可以添加参数名称的映射逻辑
- if _, exists := argsMap["file_path"]; !exists {
- if path, exists := argsMap["path"]; exists {
- argsMap["file_path"] = path
+ // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。
+ if _, exists := argsMap["filePath"]; !exists {
+ if filePath, exists := argsMap["file_path"]; exists {
+ argsMap["filePath"] = filePath
+ delete(argsMap, "file_path")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool")
+ } else if filePath, exists := argsMap["path"]; exists {
+ argsMap["filePath"] = filePath
delete(argsMap, "path")
corrected = true
- log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
+ log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool")
+ } else if filePath, exists := argsMap["file"]; exists {
+ argsMap["filePath"] = filePath
+ delete(argsMap, "file")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool")
+ }
+ }
+
+ if _, exists := argsMap["oldString"]; !exists {
+ if oldString, exists := argsMap["old_string"]; exists {
+ argsMap["oldString"] = oldString
+ delete(argsMap, "old_string")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool")
+ }
+ }
+
+ if _, exists := argsMap["newString"]; !exists {
+ if newString, exists := argsMap["new_string"]; exists {
+ argsMap["newString"] = newString
+ delete(argsMap, "new_string")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool")
+ }
+ }
+
+ if _, exists := argsMap["replaceAll"]; !exists {
+ if replaceAll, exists := argsMap["replace_all"]; exists {
+ argsMap["replaceAll"] = replaceAll
+ delete(argsMap, "replace_all")
+ corrected = true
+ log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool")
}
}
}
diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go
index 3e885b4b..ff518ea6 100644
--- a/backend/internal/service/openai_tool_corrector_test.go
+++ b/backend/internal/service/openai_tool_corrector_test.go
@@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) {
expected map[string]bool // key: 期待存在的参数, value: true表示应该存在
}{
{
- name: "remove workdir from bash tool",
+ name: "rename work_dir to workdir in bash tool",
input: `{
"tool_calls": [{
"function": {
"name": "bash",
- "arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
+ "arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}"
}
}]
}`,
expected: map[string]bool{
- "command": true,
- "workdir": false,
+ "command": true,
+ "workdir": true,
+ "work_dir": false,
},
},
{
- name: "rename path to file_path in edit tool",
+ name: "rename snake_case edit params to camelCase",
input: `{
"tool_calls": [{
"function": {
@@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) {
}]
}`,
expected: map[string]bool{
- "file_path": true,
+ "filePath": true,
"path": false,
- "old_string": true,
- "new_string": true,
+ "oldString": true,
+ "old_string": false,
+ "newString": true,
+ "new_string": false,
},
},
}
From bc1d7edc58a09b0ef0abb5297a679eadeb6d74a4 Mon Sep 17 00:00:00 2001
From: ianshaw
Date: Sat, 17 Jan 2026 17:54:33 +0800
Subject: [PATCH 016/155] fix(ops): unify error classification between request-errors and SLA
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Fix the inconsistency between the request-errors endpoint and the Dashboard Overview SLA calculation:
- The errors view now excludes only business-limited errors (insufficient balance, concurrency limits, etc.)
- Upstream 429/529 errors are now included in the errors view, matching the SLA calculation
- The excluded view now shows only business-limited errors
This ensures that the request-errors endpoint and the Dashboard's error_count_sla use the same filter logic (a Go rendering of the predicate follows below).
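For illustration, the unified predicate restated in Go (the shipped filter is the SQL in buildOpsErrorLogsWhere; this function is hypothetical):

    package main

    import "fmt"

    // viewIncludes restates the SQL filter as a Go predicate: "errors" keeps
    // every row that is not business-limited, so upstream 429/529 now count
    // toward errors just as they do in the SLA numerator; "excluded" keeps only
    // business-limited rows; "all" keeps everything; unknown views behave like
    // the default "errors" view.
    func viewIncludes(view string, businessLimited bool) bool {
        switch view {
        case "excluded":
            return businessLimited
        case "all":
            return true
        default:
            return !businessLimited
        }
    }

    func main() {
        fmt.Println(viewIncludes("errors", false))  // true, even for a 429 row
        fmt.Println(viewIncludes("excluded", true)) // true
    }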
---
backend/internal/repository/ops_repo.go | 7 +++----
1 file changed, 3 insertions(+), 4 deletions(-)
diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go
index 613c5bd5..b04154b7 100644
--- a/backend/internal/repository/ops_repo.go
+++ b/backend/internal/repository/ops_repo.go
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
}
// View filter: errors vs excluded vs all.
- // Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
+ // Excluded = business-limited errors (quota/concurrency/billing).
+ // Upstream 429/529 are included in errors view to match SLA calculation.
view := ""
if filter != nil {
view = strings.ToLower(strings.TrimSpace(filter.View))
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
switch view {
case "", "errors":
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
- clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
case "excluded":
- clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
+ clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
case "all":
// no-op
default:
// treat unknown as default 'errors'
clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
- clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
}
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
From 32c47b1509287dbec5d4289bf0d23c5ea3a85f49 Mon Sep 17 00:00:00 2001
From: cyhhao
Date: Sat, 17 Jan 2026 18:16:34 +0800
Subject: [PATCH 017/155] fix(gateway): satisfy golangci-lint checks
---
backend/internal/service/gateway_service.go | 10 +++-------
1 file changed, 3 insertions(+), 7 deletions(-)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index aa811bf5..ff143eee 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -420,7 +420,7 @@ func toPascalCase(value string) string {
}
runes := []rune(lower)
runes[0] = unicode.ToUpper(runes[0])
- builder.WriteString(string(runes))
+ _, _ = builder.WriteString(string(runes))
}
return builder.String()
}
@@ -704,12 +704,8 @@ func normalizeClaudeOAuthRequestBody(body []byte, modelID string, opts claudeOAu
}
}
- if _, ok := req["temperature"]; ok {
- delete(req, "temperature")
- }
- if _, ok := req["tool_choice"]; ok {
- delete(req, "tool_choice")
- }
+ delete(req, "temperature")
+ delete(req, "tool_choice")
newBody, err := json.Marshal(req)
if err != nil {
From ef5a41057fa7127aba012f5bbdd044ea11dc0b05 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 18 Jan 2026 10:52:18 +0800
Subject: [PATCH 018/155] feat(usage): add cleanup tasks and statistics filtering
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/cmd/server/wire.go | 7 +
backend/cmd/server/wire_gen.go | 13 +-
backend/go.mod | 1 +
backend/go.sum | 1 +
backend/internal/config/config.go | 49 ++
backend/internal/config/config_test.go | 570 ++++++++++++++++++
.../admin/admin_basic_handlers_test.go | 262 ++++++++
.../handler/admin/admin_helpers_test.go | 134 ++++
.../handler/admin/admin_service_stub_test.go | 290 +++++++++
.../handler/admin/dashboard_handler.go | 28 +-
.../admin/usage_cleanup_handler_test.go | 377 ++++++++++++
.../internal/handler/admin/usage_handler.go | 192 +++++-
backend/internal/handler/dto/mappers.go | 30 +
backend/internal/handler/dto/types.go | 27 +
.../repository/dashboard_aggregation_repo.go | 69 +++
.../internal/repository/usage_cleanup_repo.go | 363 +++++++++++
.../repository/usage_cleanup_repo_test.go | 440 ++++++++++++++
backend/internal/repository/usage_log_repo.go | 14 +-
.../usage_log_repo_integration_test.go | 14 +-
backend/internal/repository/wire.go | 1 +
backend/internal/server/api_contract_test.go | 4 +-
backend/internal/server/routes/admin.go | 3 +
.../internal/service/account_usage_service.go | 8 +-
.../service/dashboard_aggregation_service.go | 59 +-
.../dashboard_aggregation_service_test.go | 4 +
backend/internal/service/dashboard_service.go | 8 +-
.../service/dashboard_service_test.go | 4 +
backend/internal/service/ratelimit_service.go | 4 +-
backend/internal/service/usage_cleanup.go | 74 +++
.../internal/service/usage_cleanup_service.go | 400 ++++++++++++
.../service/usage_cleanup_service_test.go | 420 +++++++++++++
backend/internal/service/wire.go | 8 +
.../042_add_usage_cleanup_tasks.sql | 21 +
.../043_add_usage_cleanup_cancel_audit.sql | 10 +
config.yaml | 21 +
deploy/config.example.yaml | 21 +
frontend/src/api/admin/dashboard.ts | 2 +
frontend/src/api/admin/usage.ts | 82 ++-
.../admin/usage/UsageCleanupDialog.vue | 339 +++++++++++
.../components/admin/usage/UsageFilters.vue | 25 +-
frontend/src/i18n/locales/en.ts | 38 +-
frontend/src/i18n/locales/zh.ts | 38 +-
frontend/src/types/index.ts | 29 +
frontend/src/views/admin/UsageView.vue | 20 +-
44 files changed, 4478 insertions(+), 46 deletions(-)
create mode 100644 backend/internal/handler/admin/admin_basic_handlers_test.go
create mode 100644 backend/internal/handler/admin/admin_helpers_test.go
create mode 100644 backend/internal/handler/admin/admin_service_stub_test.go
create mode 100644 backend/internal/handler/admin/usage_cleanup_handler_test.go
create mode 100644 backend/internal/repository/usage_cleanup_repo.go
create mode 100644 backend/internal/repository/usage_cleanup_repo_test.go
create mode 100644 backend/internal/service/usage_cleanup.go
create mode 100644 backend/internal/service/usage_cleanup_service.go
create mode 100644 backend/internal/service/usage_cleanup_service_test.go
create mode 100644 backend/migrations/042_add_usage_cleanup_tasks.sql
create mode 100644 backend/migrations/043_add_usage_cleanup_cancel_audit.sql
create mode 100644 frontend/src/components/admin/usage/UsageCleanupDialog.vue
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 0a5f9744..5ef04a66 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -70,6 +70,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -123,6 +124,12 @@ func provideCleanup(
}
return nil
}},
+ {"UsageCleanupService", func() error {
+ if usageCleanup != nil {
+ usageCleanup.Stop()
+ }
+ return nil
+ }},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 27404b02..509cf13a 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -153,7 +153,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
- adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
+ usageCleanupRepository := repository.NewUsageCleanupRepository(db)
+ usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
+ adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
@@ -175,7 +177,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
- v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+ v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
Server: httpServer,
Cleanup: v,
@@ -208,6 +210,7 @@ func provideCleanup(
schedulerSnapshot *service.SchedulerSnapshotService,
tokenRefresh *service.TokenRefreshService,
accountExpiry *service.AccountExpiryService,
+ usageCleanup *service.UsageCleanupService,
pricing *service.PricingService,
emailQueue *service.EmailQueueService,
billingCache *service.BillingCacheService,
@@ -260,6 +263,12 @@ func provideCleanup(
}
return nil
}},
+ {"UsageCleanupService", func() error {
+ if usageCleanup != nil {
+ usageCleanup.Stop()
+ }
+ return nil
+ }},
{"TokenRefreshService", func() error {
tokenRefresh.Stop()
return nil
diff --git a/backend/go.mod b/backend/go.mod
index 4ac6ba14..9ebae69e 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -31,6 +31,7 @@ require (
ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect
dario.cat/mergo v1.0.2 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
+ github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/agext/levenshtein v1.2.3 // indirect
github.com/andybalholm/brotli v1.2.0 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 415e73a7..4496603d 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -141,6 +141,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
+github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk=
github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 5dc6ad19..d616e44b 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -55,6 +55,7 @@ type Config struct {
APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"`
Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"`
DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
+ UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"`
Concurrency ConcurrencyConfig `mapstructure:"concurrency"`
TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"`
RunMode string `mapstructure:"run_mode" yaml:"run_mode"`
@@ -489,6 +490,20 @@ type DashboardAggregationRetentionConfig struct {
DailyDays int `mapstructure:"daily_days"`
}
+// UsageCleanupConfig 使用记录清理任务配置
+type UsageCleanupConfig struct {
+ // Enabled: 是否启用清理任务执行器
+ Enabled bool `mapstructure:"enabled"`
+ // MaxRangeDays: 单次任务允许的最大时间跨度(天)
+ MaxRangeDays int `mapstructure:"max_range_days"`
+ // BatchSize: 单批删除数量
+ BatchSize int `mapstructure:"batch_size"`
+ // WorkerIntervalSeconds: 后台任务轮询间隔(秒)
+ WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
+ // TaskTimeoutSeconds: 单次任务最大执行时长(秒)
+ TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
+}
+
func NormalizeRunMode(value string) string {
normalized := strings.ToLower(strings.TrimSpace(value))
switch normalized {
@@ -749,6 +764,13 @@ func setDefaults() {
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
+ // Usage cleanup task
+ viper.SetDefault("usage_cleanup.enabled", true)
+ viper.SetDefault("usage_cleanup.max_range_days", 31)
+ viper.SetDefault("usage_cleanup.batch_size", 5000)
+ viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
+ viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
+
// Gateway
viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久
viper.SetDefault("gateway.log_upstream_error_body", true)
@@ -985,6 +1007,33 @@ func (c *Config) Validate() error {
return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative")
}
}
+ if c.UsageCleanup.Enabled {
+ if c.UsageCleanup.MaxRangeDays <= 0 {
+ return fmt.Errorf("usage_cleanup.max_range_days must be positive")
+ }
+ if c.UsageCleanup.BatchSize <= 0 {
+ return fmt.Errorf("usage_cleanup.batch_size must be positive")
+ }
+ if c.UsageCleanup.WorkerIntervalSeconds <= 0 {
+ return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive")
+ }
+ if c.UsageCleanup.TaskTimeoutSeconds <= 0 {
+ return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive")
+ }
+ } else {
+ if c.UsageCleanup.MaxRangeDays < 0 {
+ return fmt.Errorf("usage_cleanup.max_range_days must be non-negative")
+ }
+ if c.UsageCleanup.BatchSize < 0 {
+ return fmt.Errorf("usage_cleanup.batch_size must be non-negative")
+ }
+ if c.UsageCleanup.WorkerIntervalSeconds < 0 {
+ return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative")
+ }
+ if c.UsageCleanup.TaskTimeoutSeconds < 0 {
+ return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative")
+ }
+ }
if c.Gateway.MaxBodySize <= 0 {
return fmt.Errorf("gateway.max_body_size must be positive")
}
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index 4637989e..f734619f 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) {
t.Fatalf("Validate() expected backfill_max_days error, got: %v", err)
}
}
+
+func TestLoadDefaultUsageCleanupConfig(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ if !cfg.UsageCleanup.Enabled {
+ t.Fatalf("UsageCleanup.Enabled = false, want true")
+ }
+ if cfg.UsageCleanup.MaxRangeDays != 31 {
+ t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays)
+ }
+ if cfg.UsageCleanup.BatchSize != 5000 {
+ t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize)
+ }
+ if cfg.UsageCleanup.WorkerIntervalSeconds != 10 {
+ t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds)
+ }
+ if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 {
+ t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds)
+ }
+}
+
+func TestValidateUsageCleanupConfigEnabled(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.UsageCleanup.Enabled = true
+ cfg.UsageCleanup.MaxRangeDays = 0
+ err = cfg.Validate()
+ if err == nil {
+ t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil")
+ }
+ if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") {
+ t.Fatalf("Validate() expected max_range_days error, got: %v", err)
+ }
+}
+
+func TestValidateUsageCleanupConfigDisabled(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.UsageCleanup.Enabled = false
+ cfg.UsageCleanup.BatchSize = -1
+ err = cfg.Validate()
+ if err == nil {
+ t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil")
+ }
+ if !strings.Contains(err.Error(), "usage_cleanup.batch_size") {
+ t.Fatalf("Validate() expected batch_size error, got: %v", err)
+ }
+}
+
+func TestConfigAddressHelpers(t *testing.T) {
+ server := ServerConfig{Host: "127.0.0.1", Port: 9000}
+ if server.Address() != "127.0.0.1:9000" {
+ t.Fatalf("ServerConfig.Address() = %q", server.Address())
+ }
+
+ dbCfg := DatabaseConfig{
+ Host: "localhost",
+ Port: 5432,
+ User: "postgres",
+ Password: "",
+ DBName: "sub2api",
+ SSLMode: "disable",
+ }
+ // DSN() must omit the password clause when Password is empty.
+ if strings.Contains(dbCfg.DSN(), "password=") {
+ t.Fatalf("DatabaseConfig.DSN() should not include password when empty")
+ }
+
+ dbCfg.Password = "secret"
+ if !strings.Contains(dbCfg.DSN(), "password=secret") {
+ t.Fatalf("DatabaseConfig.DSN() missing password")
+ }
+
+ dbCfg.Password = ""
+ if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") {
+ t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty")
+ }
+
+ if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") {
+ t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone")
+ }
+ if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") {
+ t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone")
+ }
+
+ redis := RedisConfig{Host: "redis", Port: 6379}
+ if redis.Address() != "redis:6379" {
+ t.Fatalf("RedisConfig.Address() = %q", redis.Address())
+ }
+}
+
+func TestNormalizeStringSlice(t *testing.T) {
+ values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"})
+ if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" {
+ t.Fatalf("normalizeStringSlice() unexpected result: %#v", values)
+ }
+ if normalizeStringSlice(nil) != nil {
+ t.Fatalf("normalizeStringSlice(nil) expected nil slice")
+ }
+}
+
+func TestGetServerAddressFromEnv(t *testing.T) {
+ t.Setenv("SERVER_HOST", "127.0.0.1")
+ t.Setenv("SERVER_PORT", "9090")
+
+ address := GetServerAddress()
+ if address != "127.0.0.1:9090" {
+ t.Fatalf("GetServerAddress() = %q", address)
+ }
+}
+
+func TestValidateAbsoluteHTTPURL(t *testing.T) {
+ if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err)
+ }
+ if err := ValidateAbsoluteHTTPURL(""); err == nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url")
+ }
+ if err := ValidateAbsoluteHTTPURL("/relative"); err == nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url")
+ }
+ if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme")
+ }
+ if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment")
+ }
+}
+
+func TestValidateFrontendRedirectURL(t *testing.T) {
+ if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil {
+ t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err)
+ }
+ if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil {
+ t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err)
+ }
+ if err := ValidateFrontendRedirectURL("example.com/path"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url")
+ }
+ if err := ValidateFrontendRedirectURL("//evil.com"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject // prefix")
+ }
+ if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme")
+ }
+}
+
+func TestWarnIfInsecureURL(t *testing.T) {
+ warnIfInsecureURL("test", "http://example.com")
+ warnIfInsecureURL("test", "bad://url")
+}
+
+func TestGenerateJWTSecretDefaultLength(t *testing.T) {
+ secret, err := generateJWTSecret(0)
+ if err != nil {
+ t.Fatalf("generateJWTSecret error: %v", err)
+ }
+ if len(secret) == 0 {
+ t.Fatalf("generateJWTSecret returned empty string")
+ }
+}
+
+func TestValidateOpsCleanupScheduleRequired(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+ cfg.Ops.Cleanup.Enabled = true
+ cfg.Ops.Cleanup.Schedule = ""
+ err = cfg.Validate()
+ if err == nil {
+ t.Fatalf("Validate() expected error for ops.cleanup.schedule")
+ }
+ if !strings.Contains(err.Error(), "ops.cleanup.schedule") {
+ t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err)
+ }
+}
+
+func TestValidateConcurrencyPingInterval(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+ cfg.Concurrency.PingInterval = 3
+ err = cfg.Validate()
+ if err == nil {
+ t.Fatalf("Validate() expected error for concurrency.ping_interval")
+ }
+ if !strings.Contains(err.Error(), "concurrency.ping_interval") {
+ t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err)
+ }
+}
+
+func TestProvideConfig(t *testing.T) {
+ viper.Reset()
+ if _, err := ProvideConfig(); err != nil {
+ t.Fatalf("ProvideConfig() error: %v", err)
+ }
+}
+
+func TestValidateConfigWithLinuxDoEnabled(t *testing.T) {
+ viper.Reset()
+
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+
+ cfg.Security.CSP.Enabled = true
+ cfg.Security.CSP.Policy = "default-src 'self'"
+
+ cfg.LinuxDo.Enabled = true
+ cfg.LinuxDo.ClientID = "client"
+ cfg.LinuxDo.ClientSecret = "secret"
+ cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize"
+ cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token"
+ cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo"
+ cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback"
+ cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback"
+ cfg.LinuxDo.TokenAuthMethod = "client_secret_post"
+
+ if err := cfg.Validate(); err != nil {
+ t.Fatalf("Validate() unexpected error: %v", err)
+ }
+}
+
+func TestValidateJWTSecretStrength(t *testing.T) {
+ if !isWeakJWTSecret("change-me-in-production") {
+ t.Fatalf("isWeakJWTSecret should detect weak secret")
+ }
+ if isWeakJWTSecret("StrongSecretValue") {
+ t.Fatalf("isWeakJWTSecret should accept strong secret")
+ }
+}
+
+func TestGenerateJWTSecretWithLength(t *testing.T) {
+ secret, err := generateJWTSecret(16)
+ if err != nil {
+ t.Fatalf("generateJWTSecret error: %v", err)
+ }
+ if len(secret) == 0 {
+ t.Fatalf("generateJWTSecret returned empty string")
+ }
+}
+
+func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) {
+ if err := ValidateAbsoluteHTTPURL("https://"); err == nil {
+ t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host")
+ }
+}
+
+func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) {
+ if err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars")
+ }
+ if err := ValidateFrontendRedirectURL("http://"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject missing host")
+ }
+ if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil {
+ t.Fatalf("ValidateFrontendRedirectURL should reject mailto")
+ }
+}
+
+func TestWarnIfInsecureURLHTTPS(t *testing.T) {
+ warnIfInsecureURL("secure", "https://example.com")
+}
+
+func TestValidateConfigErrors(t *testing.T) {
+ buildValid := func(t *testing.T) *Config {
+ t.Helper()
+ viper.Reset()
+ cfg, err := Load()
+ if err != nil {
+ t.Fatalf("Load() error: %v", err)
+ }
+ return cfg
+ }
+
+ cases := []struct {
+ name string
+ mutate func(*Config)
+ wantErr string
+ }{
+ {
+ name: "jwt expire hour positive",
+ mutate: func(c *Config) { c.JWT.ExpireHour = 0 },
+ wantErr: "jwt.expire_hour must be positive",
+ },
+ {
+ name: "jwt expire hour max",
+ mutate: func(c *Config) { c.JWT.ExpireHour = 200 },
+ wantErr: "jwt.expire_hour must be <= 168",
+ },
+ {
+ name: "csp policy required",
+ mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" },
+ wantErr: "security.csp.policy",
+ },
+ {
+ name: "linuxdo client id required",
+ mutate: func(c *Config) {
+ c.LinuxDo.Enabled = true
+ c.LinuxDo.ClientID = ""
+ },
+ wantErr: "linuxdo_connect.client_id",
+ },
+ {
+ name: "linuxdo token auth method",
+ mutate: func(c *Config) {
+ c.LinuxDo.Enabled = true
+ c.LinuxDo.ClientID = "client"
+ c.LinuxDo.ClientSecret = "secret"
+ c.LinuxDo.AuthorizeURL = "https://example.com/authorize"
+ c.LinuxDo.TokenURL = "https://example.com/token"
+ c.LinuxDo.UserInfoURL = "https://example.com/userinfo"
+ c.LinuxDo.RedirectURL = "https://example.com/callback"
+ c.LinuxDo.FrontendRedirectURL = "/auth/callback"
+ c.LinuxDo.TokenAuthMethod = "invalid"
+ },
+ wantErr: "linuxdo_connect.token_auth_method",
+ },
+ {
+ name: "billing circuit breaker threshold",
+ mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 },
+ wantErr: "billing.circuit_breaker.failure_threshold",
+ },
+ {
+ name: "billing circuit breaker reset",
+ mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 },
+ wantErr: "billing.circuit_breaker.reset_timeout_seconds",
+ },
+ {
+ name: "billing circuit breaker half open",
+ mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 },
+ wantErr: "billing.circuit_breaker.half_open_requests",
+ },
+ {
+ name: "database max open conns",
+ mutate: func(c *Config) { c.Database.MaxOpenConns = 0 },
+ wantErr: "database.max_open_conns",
+ },
+ {
+ name: "database max lifetime",
+ mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 },
+ wantErr: "database.conn_max_lifetime_minutes",
+ },
+ {
+ name: "database idle exceeds open",
+ mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 },
+ wantErr: "database.max_idle_conns cannot exceed",
+ },
+ {
+ name: "redis dial timeout",
+ mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 },
+ wantErr: "redis.dial_timeout_seconds",
+ },
+ {
+ name: "redis read timeout",
+ mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 },
+ wantErr: "redis.read_timeout_seconds",
+ },
+ {
+ name: "redis write timeout",
+ mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 },
+ wantErr: "redis.write_timeout_seconds",
+ },
+ {
+ name: "redis pool size",
+ mutate: func(c *Config) { c.Redis.PoolSize = 0 },
+ wantErr: "redis.pool_size",
+ },
+ {
+ name: "redis idle exceeds pool",
+ mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 },
+ wantErr: "redis.min_idle_conns cannot exceed",
+ },
+ {
+ name: "dashboard cache disabled negative",
+ mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 },
+ wantErr: "dashboard_cache.stats_ttl_seconds",
+ },
+ {
+ name: "dashboard cache fresh ttl positive",
+ mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 },
+ wantErr: "dashboard_cache.stats_fresh_ttl_seconds",
+ },
+ {
+ name: "dashboard aggregation enabled interval",
+ mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 },
+ wantErr: "dashboard_aggregation.interval_seconds",
+ },
+ {
+ name: "dashboard aggregation backfill positive",
+ mutate: func(c *Config) {
+ c.DashboardAgg.Enabled = true
+ c.DashboardAgg.BackfillEnabled = true
+ c.DashboardAgg.BackfillMaxDays = 0
+ },
+ wantErr: "dashboard_aggregation.backfill_max_days",
+ },
+ {
+ name: "dashboard aggregation retention",
+ mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 },
+ wantErr: "dashboard_aggregation.retention.usage_logs_days",
+ },
+ {
+ name: "dashboard aggregation disabled interval",
+ mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 },
+ wantErr: "dashboard_aggregation.interval_seconds",
+ },
+ {
+ name: "usage cleanup max range",
+ mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 },
+ wantErr: "usage_cleanup.max_range_days",
+ },
+ {
+ name: "usage cleanup worker interval",
+ mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 },
+ wantErr: "usage_cleanup.worker_interval_seconds",
+ },
+ {
+ name: "usage cleanup batch size",
+ mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 },
+ wantErr: "usage_cleanup.batch_size",
+ },
+ {
+ name: "usage cleanup disabled negative",
+ mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 },
+ wantErr: "usage_cleanup.batch_size",
+ },
+ {
+ name: "gateway max body size",
+ mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 },
+ wantErr: "gateway.max_body_size",
+ },
+ {
+ name: "gateway max idle conns",
+ mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 },
+ wantErr: "gateway.max_idle_conns",
+ },
+ {
+ name: "gateway max idle conns per host",
+ mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 },
+ wantErr: "gateway.max_idle_conns_per_host",
+ },
+ {
+ name: "gateway idle timeout",
+ mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 },
+ wantErr: "gateway.idle_conn_timeout_seconds",
+ },
+ {
+ name: "gateway max upstream clients",
+ mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 },
+ wantErr: "gateway.max_upstream_clients",
+ },
+ {
+ name: "gateway client idle ttl",
+ mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 },
+ wantErr: "gateway.client_idle_ttl_seconds",
+ },
+ {
+ name: "gateway concurrency slot ttl",
+ mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 },
+ wantErr: "gateway.concurrency_slot_ttl_minutes",
+ },
+ {
+ name: "gateway max conns per host",
+ mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 },
+ wantErr: "gateway.max_conns_per_host",
+ },
+ {
+ name: "gateway connection isolation",
+ mutate: func(c *Config) { c.Gateway.ConnectionPoolIsolation = "invalid" },
+ wantErr: "gateway.connection_pool_isolation",
+ },
+ {
+ name: "gateway stream keepalive range",
+ mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 },
+ wantErr: "gateway.stream_keepalive_interval",
+ },
+ {
+ name: "gateway stream data interval range",
+ mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 },
+ wantErr: "gateway.stream_data_interval_timeout",
+ },
+ {
+ name: "gateway stream data interval negative",
+ mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 },
+ wantErr: "gateway.stream_data_interval_timeout must be non-negative",
+ },
+ {
+ name: "gateway max line size",
+ mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 },
+ wantErr: "gateway.max_line_size must be at least",
+ },
+ {
+ name: "gateway max line size negative",
+ mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 },
+ wantErr: "gateway.max_line_size must be non-negative",
+ },
+ {
+ name: "gateway scheduling sticky waiting",
+ mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 },
+ wantErr: "gateway.scheduling.sticky_session_max_waiting",
+ },
+ {
+ name: "gateway scheduling outbox poll",
+ mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 },
+ wantErr: "gateway.scheduling.outbox_poll_interval_seconds",
+ },
+ {
+ name: "gateway scheduling outbox failures",
+ mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 },
+ wantErr: "gateway.scheduling.outbox_lag_rebuild_failures",
+ },
+ {
+ name: "gateway outbox lag rebuild",
+ mutate: func(c *Config) {
+ c.Gateway.Scheduling.OutboxLagWarnSeconds = 10
+ c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5
+ },
+ wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds",
+ },
+ {
+ name: "ops metrics collector ttl",
+ mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 },
+ wantErr: "ops.metrics_collector_cache.ttl",
+ },
+ {
+ name: "ops cleanup retention",
+ mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 },
+ wantErr: "ops.cleanup.error_log_retention_days",
+ },
+ {
+ name: "ops cleanup minute retention",
+ mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 },
+ wantErr: "ops.cleanup.minute_metrics_retention_days",
+ },
+ }
+
+ for _, tt := range cases {
+ t.Run(tt.name, func(t *testing.T) {
+ cfg := buildValid(t)
+ tt.mutate(cfg)
+ err := cfg.Validate()
+ if err == nil || !strings.Contains(err.Error(), tt.wantErr) {
+ t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr)
+ }
+ })
+ }
+}
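
Note: the table above relies on buildValid returning a fresh config per subtest, so one case's mutation never leaks into the next. A minimal sketch of that isolation property, written as if it sat in the same config package with the imports these tests already use:

func TestBuildValidIsolation(t *testing.T) {
	viper.Reset()
	first, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	first.UsageCleanup.BatchSize = -1 // mutate one instance

	viper.Reset()
	second, err := Load()
	if err != nil {
		t.Fatalf("Load() error: %v", err)
	}
	// A fresh Load() restores the default (5000), proving the earlier
	// mutation did not leak through shared state.
	if second.UsageCleanup.BatchSize != 5000 {
		t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", second.UsageCleanup.BatchSize)
	}
}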
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
new file mode 100644
index 00000000..e0f731e1
--- /dev/null
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -0,0 +1,262 @@
+package admin
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupAdminRouter() (*gin.Engine, *stubAdminService) {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ adminSvc := newStubAdminService()
+
+ userHandler := NewUserHandler(adminSvc)
+ groupHandler := NewGroupHandler(adminSvc)
+ proxyHandler := NewProxyHandler(adminSvc)
+ redeemHandler := NewRedeemHandler(adminSvc)
+
+ router.GET("/api/v1/admin/users", userHandler.List)
+ router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
+ router.POST("/api/v1/admin/users", userHandler.Create)
+ router.PUT("/api/v1/admin/users/:id", userHandler.Update)
+ router.DELETE("/api/v1/admin/users/:id", userHandler.Delete)
+ router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance)
+ router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys)
+ router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage)
+
+ router.GET("/api/v1/admin/groups", groupHandler.List)
+ router.GET("/api/v1/admin/groups/all", groupHandler.GetAll)
+ router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID)
+ router.POST("/api/v1/admin/groups", groupHandler.Create)
+ router.PUT("/api/v1/admin/groups/:id", groupHandler.Update)
+ router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete)
+ router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats)
+ router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys)
+
+ router.GET("/api/v1/admin/proxies", proxyHandler.List)
+ router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll)
+ router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID)
+ router.POST("/api/v1/admin/proxies", proxyHandler.Create)
+ router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update)
+ router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete)
+ router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete)
+ router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test)
+ router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats)
+ router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts)
+
+ router.GET("/api/v1/admin/redeem-codes", redeemHandler.List)
+ router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID)
+ router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate)
+ router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete)
+ router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete)
+ router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire)
+ router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats)
+
+ return router, adminSvc
+}
+
+func TestUserHandlerEndpoints(t *testing.T) {
+ router, _ := setupAdminRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2}
+ body, _ := json.Marshal(createBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ updateBody := map[string]any{"email": "updated@example.com"}
+ body, _ = json.Marshal(updateBody)
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestGroupHandlerEndpoints(t *testing.T) {
+ router, _ := setupAdminRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ body, _ = json.Marshal(map[string]any{"name": "update"})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestProxyHandlerEndpoints(t *testing.T) {
+ router, _ := setupAdminRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ body, _ = json.Marshal(map[string]any{"name": "proxy2"})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
+
+func TestRedeemHandlerEndpoints(t *testing.T) {
+ router, _ := setupAdminRouter()
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10})
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusOK, rec.Code)
+}
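
Note: every endpoint test above repeats the recorder/request/ServeHTTP sequence. If the duplication keeps growing, a small same-package helper concentrates it; a sketch under the file's existing imports, with doJSON being a hypothetical name:

// doJSON is a hypothetical helper: it fires one JSON request at the
// router and returns the recorder for status/body assertions.
func doJSON(t *testing.T, router *gin.Engine, method, path, body string) *httptest.ResponseRecorder {
	t.Helper()
	req := httptest.NewRequest(method, path, bytes.NewBufferString(body))
	if body != "" {
		req.Header.Set("Content-Type", "application/json")
	}
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

Usage would collapse each step to one line, e.g. require.Equal(t, http.StatusOK, doJSON(t, router, http.MethodGet, "/api/v1/admin/users", "").Code).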
diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go
new file mode 100644
index 00000000..863c755c
--- /dev/null
+++ b/backend/internal/handler/admin/admin_helpers_test.go
@@ -0,0 +1,134 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "net/netip"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestParseTimeRange(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil)
+ c.Request = req
+
+ start, end := parseTimeRange(c)
+ require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start)
+ require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end)
+
+ req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil)
+ c.Request = req
+ start, end = parseTimeRange(c)
+ require.False(t, start.IsZero())
+ require.False(t, end.IsZero())
+}
+
+func TestParseOpsViewParam(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil)
+ require.Equal(t, opsListViewExcluded, parseOpsViewParam(c))
+
+ c2, _ := gin.CreateTestContext(w)
+ c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil)
+ require.Equal(t, opsListViewAll, parseOpsViewParam(c2))
+
+ c3, _ := gin.CreateTestContext(w)
+ c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil)
+ require.Equal(t, opsListViewErrors, parseOpsViewParam(c3))
+
+ require.Equal(t, "", parseOpsViewParam(nil))
+}
+
+func TestParseOpsDuration(t *testing.T) {
+ dur, ok := parseOpsDuration("1h")
+ require.True(t, ok)
+ require.Equal(t, time.Hour, dur)
+
+ _, ok = parseOpsDuration("invalid")
+ require.False(t, ok)
+}
+
+func TestParseOpsTimeRange(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ now := time.Now().UTC()
+ startStr := now.Add(-time.Hour).Format(time.RFC3339)
+ endStr := now.Format(time.RFC3339)
+ c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil)
+ start, end, err := parseOpsTimeRange(c, "1h")
+ require.NoError(t, err)
+ require.True(t, start.Before(end))
+
+ c2, _ := gin.CreateTestContext(w)
+ c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil)
+ _, _, err = parseOpsTimeRange(c2, "1h")
+ require.Error(t, err)
+}
+
+func TestParseOpsRealtimeWindow(t *testing.T) {
+ dur, label, ok := parseOpsRealtimeWindow("5m")
+ require.True(t, ok)
+ require.Equal(t, 5*time.Minute, dur)
+ require.Equal(t, "5min", label)
+
+ _, _, ok = parseOpsRealtimeWindow("invalid")
+ require.False(t, ok)
+}
+
+func TestPickThroughputBucketSeconds(t *testing.T) {
+ require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute))
+ require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour))
+ require.Equal(t, 3600, pickThroughputBucketSeconds(48*time.Hour))
+}
+
+func TestParseOpsQueryMode(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil)
+ require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c))
+ require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil))
+}
+
+func TestOpsAlertRuleValidation(t *testing.T) {
+ raw := map[string]json.RawMessage{
+ "name": json.RawMessage(`"High error rate"`),
+ "metric_type": json.RawMessage(`"error_rate"`),
+ "operator": json.RawMessage(`">"`),
+ "threshold": json.RawMessage(`90`),
+ }
+
+ validated, err := validateOpsAlertRulePayload(raw)
+ require.NoError(t, err)
+ require.Equal(t, "High error rate", validated.Name)
+
+ _, err = validateOpsAlertRulePayload(map[string]json.RawMessage{})
+ require.Error(t, err)
+
+ require.True(t, isPercentOrRateMetric("error_rate"))
+ require.False(t, isPercentOrRateMetric("concurrency_queue_depth"))
+}
+
+func TestOpsWSHelpers(t *testing.T) {
+ prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid")
+ require.Len(t, prefixes, 1)
+ require.Len(t, invalid, 1)
+
+ host := hostWithoutPort("example.com:443")
+ require.Equal(t, "example.com", host)
+
+ addr := netip.MustParseAddr("10.0.0.1")
+ require.True(t, isAddrInTrustedProxies(addr, prefixes))
+ require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes))
+}
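
Note: the three helpers covered by TestOpsWSHelpers compose into one trust decision. A sketch of the composition, same package, assuming the signatures implied by the test; isTrustedRemote is an illustrative wrapper, not code from this patch:

// isTrustedRemote chains the helpers above: parse the configured CIDR
// list, strip the port from the remote address, check membership.
func isTrustedRemote(remoteAddr, trustedCSV string) bool {
	prefixes, _ := parseTrustedProxyList(trustedCSV)
	addr, err := netip.ParseAddr(hostWithoutPort(remoteAddr))
	if err != nil {
		return false
	}
	return isAddrInTrustedProxies(addr, prefixes)
}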
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
new file mode 100644
index 00000000..457d52fc
--- /dev/null
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -0,0 +1,290 @@
+package admin
+
+import (
+ "context"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type stubAdminService struct {
+ users []service.User
+ apiKeys []service.APIKey
+ groups []service.Group
+ accounts []service.Account
+ proxies []service.Proxy
+ proxyCounts []service.ProxyWithAccountCount
+ redeems []service.RedeemCode
+}
+
+func newStubAdminService() *stubAdminService {
+ now := time.Now().UTC()
+ user := service.User{
+ ID: 1,
+ Email: "user@example.com",
+ Role: service.RoleUser,
+ Status: service.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ apiKey := service.APIKey{
+ ID: 10,
+ UserID: user.ID,
+ Key: "sk-test",
+ Name: "test",
+ Status: service.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ group := service.Group{
+ ID: 2,
+ Name: "group",
+ Platform: service.PlatformAnthropic,
+ Status: service.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ account := service.Account{
+ ID: 3,
+ Name: "account",
+ Platform: service.PlatformAnthropic,
+ Type: service.AccountTypeOAuth,
+ Status: service.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ proxy := service.Proxy{
+ ID: 4,
+ Name: "proxy",
+ Protocol: "http",
+ Host: "127.0.0.1",
+ Port: 8080,
+ Status: service.StatusActive,
+ CreatedAt: now,
+ UpdatedAt: now,
+ }
+ redeem := service.RedeemCode{
+ ID: 5,
+ Code: "R-TEST",
+ Type: service.RedeemTypeBalance,
+ Value: 10,
+ Status: service.StatusUnused,
+ CreatedAt: now,
+ }
+ return &stubAdminService{
+ users: []service.User{user},
+ apiKeys: []service.APIKey{apiKey},
+ groups: []service.Group{group},
+ accounts: []service.Account{account},
+ proxies: []service.Proxy{proxy},
+ proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}},
+ redeems: []service.RedeemCode{redeem},
+ }
+}
+
+func (s *stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) {
+ return s.users, int64(len(s.users)), nil
+}
+
+func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) {
+ for i := range s.users {
+ if s.users[i].ID == id {
+ return &s.users[i], nil
+ }
+ }
+ user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive}
+ return &user, nil
+}
+
+func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) {
+ user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive}
+ return &user, nil
+}
+
+func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) {
+ user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive}
+ return &user, nil
+}
+
+func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error {
+ return nil
+}
+
+func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) {
+ user := service.User{ID: userID, Balance: balance, Status: service.StatusActive}
+ return &user, nil
+}
+
+func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) {
+ return s.apiKeys, int64(len(s.apiKeys)), nil
+}
+
+func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) {
+ return map[string]any{"user_id": userID}, nil
+}
+
+func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) {
+ return s.groups, int64(len(s.groups)), nil
+}
+
+func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) {
+ return s.groups, nil
+}
+
+func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
+ return s.groups, nil
+}
+
+func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) {
+ group := service.Group{ID: id, Name: "group", Status: service.StatusActive}
+ return &group, nil
+}
+
+func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) {
+ group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive}
+ return &group, nil
+}
+
+func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) {
+ group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive}
+ return &group, nil
+}
+
+func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error {
+ return nil
+}
+
+func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) {
+ return s.apiKeys, int64(len(s.apiKeys)), nil
+}
+
+func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) {
+ return s.accounts, int64(len(s.accounts)), nil
+}
+
+func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) {
+ account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
+ return &account, nil
+}
+
+func (s *stubAdminService) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) {
+ out := make([]*service.Account, 0, len(ids))
+ for _, id := range ids {
+ account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
+ out = append(out, &account)
+ }
+ return out, nil
+}
+
+func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) {
+ account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive}
+ return &account, nil
+}
+
+func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) {
+ account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive}
+ return &account, nil
+}
+
+func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error {
+ return nil
+}
+
+func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) {
+ account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
+ return &account, nil
+}
+
+func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) {
+ account := service.Account{ID: id, Name: "account", Status: service.StatusActive}
+ return &account, nil
+}
+
+func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) {
+ account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable}
+ return &account, nil
+}
+
+func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
+ return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
+}
+
+func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) {
+ return s.proxies, int64(len(s.proxies)), nil
+}
+
+func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) {
+ return s.proxyCounts, int64(len(s.proxyCounts)), nil
+}
+
+func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) {
+ return s.proxies, nil
+}
+
+func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
+ return s.proxyCounts, nil
+}
+
+func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) {
+ proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive}
+ return &proxy, nil
+}
+
+func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) {
+ proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive}
+ return &proxy, nil
+}
+
+func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) {
+ proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive}
+ return &proxy, nil
+}
+
+func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error {
+ return nil
+}
+
+func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) {
+ return &service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil
+}
+
+func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
+ return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil
+}
+
+func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
+ return false, nil
+}
+
+func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) {
+ return &service.ProxyTestResult{Success: true, Message: "ok"}, nil
+}
+
+func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) {
+ return s.redeems, int64(len(s.redeems)), nil
+}
+
+func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
+ code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused}
+ return &code, nil
+}
+
+func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) {
+ return s.redeems, nil
+}
+
+func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error {
+ return nil
+}
+
+func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) {
+ return int64(len(ids)), nil
+}
+
+func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) {
+ code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed}
+ return &code, nil
+}
+
+// Ensure stub implements interface.
+var _ service.AdminService = (*stubAdminService)(nil)
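
Note: the closing var _ assertion is the load-bearing line of this stub file: if service.AdminService gains a method, the test package stops compiling instead of failing at runtime. The cleanup tests below apply the same guard to their repository stub:

var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)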
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 3f07403d..18365186 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
-// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
+// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
@@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
var userID, apiKeyID, accountID, groupID int64
var model string
var stream *bool
+ var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
stream = &streamVal
}
}
+ if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
+ if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
+ bt := int8(v)
+ billingType = &bt
+ } else {
+ response.BadRequest(c, "Invalid billing_type")
+ return
+ }
+ }
- trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
+ trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
-// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
+// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
var userID, apiKeyID, accountID, groupID int64
var stream *bool
+ var billingType *int8
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
@@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
stream = &streamVal
}
}
+ if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
+ if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
+ bt := int8(v)
+ billingType = &bt
+ } else {
+ response.BadRequest(c, "Invalid billing_type")
+ return
+ }
+ }
- stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
+ stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
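
Note: the new billing_type branch accepts any value strconv.ParseInt can fit into 8 bits and rejects the rest with a 400 before touching the service layer. A runnable sketch of just that contract; the inline handler mirrors the hunk's parsing logic and is not the project's handler:

package main

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"strconv"
)

func main() {
	// Stub handler reproducing only the billing_type validation above.
	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if s := r.URL.Query().Get("billing_type"); s != "" {
			if _, err := strconv.ParseInt(s, 10, 8); err != nil {
				w.WriteHeader(http.StatusBadRequest)
				return
			}
		}
		w.WriteHeader(http.StatusOK)
	})

	for _, q := range []string{"billing_type=1", "billing_type=oops", "billing_type=300"} {
		rec := httptest.NewRecorder()
		h.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/trend?"+q, nil))
		fmt.Println(q, "->", rec.Code) // 1 -> 200; oops -> 400; 300 -> 400 (overflows int8)
	}
}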
diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go
new file mode 100644
index 00000000..d8684c39
--- /dev/null
+++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go
@@ -0,0 +1,377 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "database/sql"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type cleanupRepoStub struct {
+ mu sync.Mutex
+ created []*service.UsageCleanupTask
+ listTasks []service.UsageCleanupTask
+ listResult *pagination.PaginationResult
+ listErr error
+ statusByID map[int64]string
+}
+
+func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
+ if task == nil {
+ return nil
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if task.ID == 0 {
+ task.ID = int64(len(s.created) + 1)
+ }
+ if task.CreatedAt.IsZero() {
+ task.CreatedAt = time.Now().UTC()
+ }
+ task.UpdatedAt = task.CreatedAt
+ clone := *task
+ s.created = append(s.created, &clone)
+ return nil
+}
+
+func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.listTasks, s.listResult, s.listErr
+}
+
+func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
+ return nil, nil
+}
+
+func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.statusByID == nil {
+ return "", sql.ErrNoRows
+ }
+ status, ok := s.statusByID[taskID]
+ if !ok {
+ return "", sql.ErrNoRows
+ }
+ return status, nil
+}
+
+func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
+ return nil
+}
+
+func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ status := s.statusByID[taskID]
+ if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning {
+ return false, nil
+ }
+ s.statusByID[taskID] = service.UsageCleanupStatusCanceled
+ return true, nil
+}
+
+func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
+ return nil
+}
+
+func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
+ return nil
+}
+
+func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
+ return 0, nil
+}
+
+var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil)
+
+func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ if userID > 0 {
+ router.Use(func(c *gin.Context) {
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID})
+ c.Next()
+ })
+ }
+
+ handler := NewUsageHandler(nil, nil, nil, cleanupService)
+ router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask)
+ router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks)
+ router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask)
+ return router
+}
+
+func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 0)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusUnauthorized, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) {
+ router := setupCleanupRouter(nil, 1)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 88)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json"))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 88)
+
+ payload := map[string]any{
+ "start_date": "2024-01-01",
+ "timezone": "UTC",
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 88)
+
+ payload := map[string]any{
+ "start_date": "2024-13-01",
+ "end_date": "2024-01-02",
+ "timezone": "UTC",
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 88)
+
+ payload := map[string]any{
+ "start_date": "2024-01-01",
+ "end_date": "2024-02-40",
+ "timezone": "UTC",
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusBadRequest, recorder.Code)
+}
+
+func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 99)
+
+ payload := map[string]any{
+ "start_date": " 2024-01-01 ",
+ "end_date": "2024-01-02",
+ "timezone": "UTC",
+ "model": "gpt-4",
+ }
+ body, err := json.Marshal(payload)
+ require.NoError(t, err)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp response.Response
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.created, 1)
+ created := repo.created[0]
+ require.Equal(t, int64(99), created.CreatedBy)
+ require.NotNil(t, created.Filters.Model)
+ require.Equal(t, "gpt-4", *created.Filters.Model)
+
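+	// The handler expands the inclusive end_date to the last nanosecond of that day.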
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond)
+ require.True(t, created.Filters.StartTime.Equal(start))
+ require.True(t, created.Filters.EndTime.Equal(end))
+}
+
+func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) {
+ router := setupCleanupRouter(nil, 0)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusServiceUnavailable, recorder.Code)
+}
+
+func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ repo.listTasks = []service.UsageCleanupTask{
+ {
+ ID: 7,
+ Status: service.UsageCleanupStatusSucceeded,
+ CreatedBy: 4,
+ },
+ }
+ repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 1)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusOK, recorder.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ Items []dto.UsageCleanupTask `json:"items"`
+ Total int64 `json:"total"`
+ Page int `json:"page"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Len(t, resp.Data.Items, 1)
+ require.Equal(t, int64(7), resp.Data.Items[0].ID)
+ require.Equal(t, int64(1), resp.Data.Total)
+ require.Equal(t, 1, resp.Data.Page)
+}
+
+func TestUsageHandlerListCleanupTasksError(t *testing.T) {
+ repo := &cleanupRepoStub{listErr: errors.New("boom")}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 1)
+
+ req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil)
+ recorder := httptest.NewRecorder()
+ router.ServeHTTP(recorder, req)
+
+ require.Equal(t, http.StatusInternalServerError, recorder.Code)
+}
+
+func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 0)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 1)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) {
+ repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 1)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusConflict, rec.Code)
+}
+
+func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg)
+ router := setupCleanupRouter(cleanupService, 1)
+
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil)
+ rec := httptest.NewRecorder()
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+}
diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go
index c7b983f1..81aa78e1 100644
--- a/backend/internal/handler/admin/usage_handler.go
+++ b/backend/internal/handler/admin/usage_handler.go
@@ -1,7 +1,10 @@
package admin
import (
+ "log"
+ "net/http"
"strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
@@ -9,6 +12,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
@@ -16,9 +20,10 @@ import (
// UsageHandler handles admin usage-related requests
type UsageHandler struct {
- usageService *service.UsageService
- apiKeyService *service.APIKeyService
- adminService service.AdminService
+ usageService *service.UsageService
+ apiKeyService *service.APIKeyService
+ adminService service.AdminService
+ cleanupService *service.UsageCleanupService
}
// NewUsageHandler creates a new admin usage handler
@@ -26,14 +31,30 @@ func NewUsageHandler(
usageService *service.UsageService,
apiKeyService *service.APIKeyService,
adminService service.AdminService,
+ cleanupService *service.UsageCleanupService,
) *UsageHandler {
return &UsageHandler{
- usageService: usageService,
- apiKeyService: apiKeyService,
- adminService: adminService,
+ usageService: usageService,
+ apiKeyService: apiKeyService,
+ adminService: adminService,
+ cleanupService: cleanupService,
}
}
+// CreateUsageCleanupTaskRequest represents a cleanup task creation request
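+// An illustrative payload (values are examples only):
+// {"start_date": "2024-01-01", "end_date": "2024-01-31", "timezone": "UTC", "model": "gpt-4"}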
+type CreateUsageCleanupTaskRequest struct {
+ StartDate string `json:"start_date"`
+ EndDate string `json:"end_date"`
+ UserID *int64 `json:"user_id"`
+ APIKeyID *int64 `json:"api_key_id"`
+ AccountID *int64 `json:"account_id"`
+ GroupID *int64 `json:"group_id"`
+ Model *string `json:"model"`
+ Stream *bool `json:"stream"`
+ BillingType *int8 `json:"billing_type"`
+ Timezone string `json:"timezone"`
+}
+
// List handles listing all usage records with filters
// GET /api/v1/admin/usage
func (h *UsageHandler) List(c *gin.Context) {
@@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) {
response.Success(c, result)
}
+
+// ListCleanupTasks handles listing usage cleanup tasks
+// GET /api/v1/admin/usage/cleanup-tasks
+func (h *UsageHandler) ListCleanupTasks(c *gin.Context) {
+ if h.cleanupService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
+ return
+ }
+ operator := int64(0)
+ if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
+ operator = subject.UserID
+ }
+ page, pageSize := response.ParsePagination(c)
+ log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize)
+ params := pagination.PaginationParams{Page: page, PageSize: pageSize}
+ tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params)
+ if err != nil {
+ log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err)
+ response.ErrorFrom(c, err)
+ return
+ }
+ out := make([]dto.UsageCleanupTask, 0, len(tasks))
+ for i := range tasks {
+ out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i]))
+ }
+ log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize)
+ response.Paginated(c, out, result.Total, page, pageSize)
+}
+
+// CreateCleanupTask handles creating a usage cleanup task
+// POST /api/v1/admin/usage/cleanup-tasks
+func (h *UsageHandler) CreateCleanupTask(c *gin.Context) {
+ if h.cleanupService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
+ return
+ }
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ var req CreateUsageCleanupTaskRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ req.StartDate = strings.TrimSpace(req.StartDate)
+ req.EndDate = strings.TrimSpace(req.EndDate)
+ if req.StartDate == "" || req.EndDate == "" {
+ response.BadRequest(c, "start_date and end_date are required")
+ return
+ }
+
+ startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone)
+ if err != nil {
+ response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD")
+ return
+ }
+ endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone)
+ if err != nil {
+ response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD")
+ return
+ }
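+	// end_date is inclusive: extend it to the last nanosecond of that day in the requested timezone.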
+ endTime = endTime.Add(24*time.Hour - time.Nanosecond)
+
+ filters := service.UsageCleanupFilters{
+ StartTime: startTime,
+ EndTime: endTime,
+ UserID: req.UserID,
+ APIKeyID: req.APIKeyID,
+ AccountID: req.AccountID,
+ GroupID: req.GroupID,
+ Model: req.Model,
+ Stream: req.Stream,
+ BillingType: req.BillingType,
+ }
+
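+	// Dereference the optional filters into any so the audit log prints values (or <nil>) rather than pointer addresses.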
+ var userID any
+ if filters.UserID != nil {
+ userID = *filters.UserID
+ }
+ var apiKeyID any
+ if filters.APIKeyID != nil {
+ apiKeyID = *filters.APIKeyID
+ }
+ var accountID any
+ if filters.AccountID != nil {
+ accountID = *filters.AccountID
+ }
+ var groupID any
+ if filters.GroupID != nil {
+ groupID = *filters.GroupID
+ }
+ var model any
+ if filters.Model != nil {
+ model = *filters.Model
+ }
+ var stream any
+ if filters.Stream != nil {
+ stream = *filters.Stream
+ }
+ var billingType any
+ if filters.BillingType != nil {
+ billingType = *filters.BillingType
+ }
+
+ log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q",
+ subject.UserID,
+ filters.StartTime.Format(time.RFC3339),
+ filters.EndTime.Format(time.RFC3339),
+ userID,
+ apiKeyID,
+ accountID,
+ groupID,
+ model,
+ stream,
+ billingType,
+ req.Timezone,
+ )
+
+ task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID)
+ if err != nil {
+ log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err)
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status)
+ response.Success(c, dto.UsageCleanupTaskFromService(task))
+}
+
+// CancelCleanupTask handles canceling a usage cleanup task
+// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel
+func (h *UsageHandler) CancelCleanupTask(c *gin.Context) {
+ if h.cleanupService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable")
+ return
+ }
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+ idStr := strings.TrimSpace(c.Param("id"))
+ taskID, err := strconv.ParseInt(idStr, 10, 64)
+ if err != nil || taskID <= 0 {
+ response.BadRequest(c, "Invalid task id")
+ return
+ }
+ log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID)
+ if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil {
+ log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err)
+ response.ErrorFrom(c, err)
+ return
+ }
+ log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID)
+ response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled})
+}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 4d59ddff..f43fac27 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -340,6 +340,36 @@ func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog {
return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true)
}
+func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask {
+ if task == nil {
+ return nil
+ }
+ return &UsageCleanupTask{
+ ID: task.ID,
+ Status: task.Status,
+ Filters: UsageCleanupFilters{
+ StartTime: task.Filters.StartTime,
+ EndTime: task.Filters.EndTime,
+ UserID: task.Filters.UserID,
+ APIKeyID: task.Filters.APIKeyID,
+ AccountID: task.Filters.AccountID,
+ GroupID: task.Filters.GroupID,
+ Model: task.Filters.Model,
+ Stream: task.Filters.Stream,
+ BillingType: task.Filters.BillingType,
+ },
+ CreatedBy: task.CreatedBy,
+ DeletedRows: task.DeletedRows,
+ ErrorMessage: task.ErrorMsg,
+ CanceledBy: task.CanceledBy,
+ CanceledAt: task.CanceledAt,
+ StartedAt: task.StartedAt,
+ FinishedAt: task.FinishedAt,
+ CreatedAt: task.CreatedAt,
+ UpdatedAt: task.UpdatedAt,
+ }
+}
+
func SettingFromService(s *service.Setting) *Setting {
if s == nil {
return nil
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 914f2b23..5fa5a3fd 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -223,6 +223,33 @@ type UsageLog struct {
Subscription *UserSubscription `json:"subscription,omitempty"`
}
+type UsageCleanupFilters struct {
+ StartTime time.Time `json:"start_time"`
+ EndTime time.Time `json:"end_time"`
+ UserID *int64 `json:"user_id,omitempty"`
+ APIKeyID *int64 `json:"api_key_id,omitempty"`
+ AccountID *int64 `json:"account_id,omitempty"`
+ GroupID *int64 `json:"group_id,omitempty"`
+ Model *string `json:"model,omitempty"`
+ Stream *bool `json:"stream,omitempty"`
+ BillingType *int8 `json:"billing_type,omitempty"`
+}
+
+type UsageCleanupTask struct {
+ ID int64 `json:"id"`
+ Status string `json:"status"`
+ Filters UsageCleanupFilters `json:"filters"`
+ CreatedBy int64 `json:"created_by"`
+ DeletedRows int64 `json:"deleted_rows"`
+ ErrorMessage *string `json:"error_message,omitempty"`
+ CanceledBy *int64 `json:"canceled_by,omitempty"`
+ CanceledAt *time.Time `json:"canceled_at,omitempty"`
+ StartedAt *time.Time `json:"started_at,omitempty"`
+ FinishedAt *time.Time `json:"finished_at,omitempty"`
+ CreatedAt time.Time `json:"created_at"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
// AccountSummary is a minimal account info for usage log display.
// It intentionally excludes sensitive fields like Credentials, Proxy, etc.
type AccountSummary struct {
diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go
index 3543e061..59bbd6a3 100644
--- a/backend/internal/repository/dashboard_aggregation_repo.go
+++ b/backend/internal/repository/dashboard_aggregation_repo.go
@@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta
return nil
}
+func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
+ if r == nil || r.sql == nil {
+ return nil
+ }
+ loc := timezone.Location()
+ startLocal := start.In(loc)
+ endLocal := end.In(loc)
+ if !endLocal.After(startLocal) {
+ return nil
+ }
+
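+	// Widen the window outward to whole hour and day buckets so partially covered edge buckets are rebuilt in full.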
+ hourStart := startLocal.Truncate(time.Hour)
+ hourEnd := endLocal.Truncate(time.Hour)
+ if endLocal.After(hourEnd) {
+ hourEnd = hourEnd.Add(time.Hour)
+ }
+
+ dayStart := truncateToDay(startLocal)
+ dayEnd := truncateToDay(endLocal)
+ if endLocal.After(dayEnd) {
+ dayEnd = dayEnd.Add(24 * time.Hour)
+ }
+
+	// Prefer a transaction so the range stays consistent (degrade to non-transactional execution when the executor is not a *sql.DB).
+ if db, ok := r.sql.(*sql.DB); ok {
+ tx, err := db.BeginTx(ctx, nil)
+ if err != nil {
+ return err
+ }
+ txRepo := newDashboardAggregationRepositoryWithSQL(tx)
+ if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
+ _ = tx.Rollback()
+ return err
+ }
+ return tx.Commit()
+ }
+ return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
+}
+
+func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
+	// Delete the buckets in the range first, then rebuild: insert-only increments cannot roll back derived metrics such as active users.
+ if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
+ return err
+ }
+ if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
+ return err
+ }
+ if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
+ return err
+ }
+ if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
+ return err
+ }
+
+ if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
+ return err
+ }
+ if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
+ return err
+ }
+ if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
+ return err
+ }
+ if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
+ return err
+ }
+ return nil
+}
+
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
var ts time.Time
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go
new file mode 100644
index 00000000..b703cc9f
--- /dev/null
+++ b/backend/internal/repository/usage_cleanup_repo.go
@@ -0,0 +1,363 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+)
+
+type usageCleanupRepository struct {
+ sql sqlExecutor
+}
+
+func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
+ return &usageCleanupRepository{sql: sqlDB}
+}
+
+func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
+ if task == nil {
+ return nil
+ }
+ filtersJSON, err := json.Marshal(task.Filters)
+ if err != nil {
+ return fmt.Errorf("marshal cleanup filters: %w", err)
+ }
+ query := `
+ INSERT INTO usage_cleanup_tasks (
+ status,
+ filters,
+ created_by,
+ deleted_rows
+ ) VALUES ($1, $2, $3, $4)
+ RETURNING id, created_at, updated_at
+ `
+ if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
+ var total int64
+ if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
+ return nil, nil, err
+ }
+ if total == 0 {
+ return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
+ }
+
+ query := `
+ SELECT id, status, filters, created_by, deleted_rows, error_message,
+ canceled_by, canceled_at,
+ started_at, finished_at, created_at, updated_at
+ FROM usage_cleanup_tasks
+ ORDER BY created_at DESC
+ LIMIT $1 OFFSET $2
+ `
+ rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
+ if err != nil {
+ return nil, nil, err
+ }
+ defer rows.Close()
+
+ tasks := make([]service.UsageCleanupTask, 0)
+ for rows.Next() {
+ var task service.UsageCleanupTask
+ var filtersJSON []byte
+ var errMsg sql.NullString
+ var canceledBy sql.NullInt64
+ var canceledAt sql.NullTime
+ var startedAt sql.NullTime
+ var finishedAt sql.NullTime
+ if err := rows.Scan(
+ &task.ID,
+ &task.Status,
+ &filtersJSON,
+ &task.CreatedBy,
+ &task.DeletedRows,
+ &errMsg,
+ &canceledBy,
+ &canceledAt,
+ &startedAt,
+ &finishedAt,
+ &task.CreatedAt,
+ &task.UpdatedAt,
+ ); err != nil {
+ return nil, nil, err
+ }
+ if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
+ return nil, nil, fmt.Errorf("parse cleanup filters: %w", err)
+ }
+ if errMsg.Valid {
+ task.ErrorMsg = &errMsg.String
+ }
+ if canceledBy.Valid {
+ v := canceledBy.Int64
+ task.CanceledBy = &v
+ }
+ if canceledAt.Valid {
+ task.CanceledAt = &canceledAt.Time
+ }
+ if startedAt.Valid {
+ task.StartedAt = &startedAt.Time
+ }
+ if finishedAt.Valid {
+ task.FinishedAt = &finishedAt.Time
+ }
+ tasks = append(tasks, task)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, nil, err
+ }
+ return tasks, paginationResultFromTotal(total, params), nil
+}
+
+func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) {
+ if staleRunningAfterSeconds <= 0 {
+ staleRunningAfterSeconds = 1800
+ }
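+	// Atomic claim: FOR UPDATE SKIP LOCKED lets concurrent workers compete without blocking each other,
+	// and "running" tasks whose started_at exceeds the staleness threshold are reclaimed as abandoned.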
+ query := `
+ WITH next AS (
+ SELECT id
+ FROM usage_cleanup_tasks
+ WHERE status = $1
+ OR (
+ status = $2
+ AND started_at IS NOT NULL
+ AND started_at < NOW() - ($3 * interval '1 second')
+ )
+ ORDER BY created_at ASC
+ LIMIT 1
+ FOR UPDATE SKIP LOCKED
+ )
+ UPDATE usage_cleanup_tasks
+ SET status = $4,
+ started_at = NOW(),
+ finished_at = NULL,
+ error_message = NULL,
+ updated_at = NOW()
+ FROM next
+ WHERE usage_cleanup_tasks.id = next.id
+ RETURNING id, status, filters, created_by, deleted_rows, error_message,
+ started_at, finished_at, created_at, updated_at
+ `
+ var task service.UsageCleanupTask
+ var filtersJSON []byte
+ var errMsg sql.NullString
+ var startedAt sql.NullTime
+ var finishedAt sql.NullTime
+ if err := scanSingleRow(
+ ctx,
+ r.sql,
+ query,
+ []any{
+ service.UsageCleanupStatusPending,
+ service.UsageCleanupStatusRunning,
+ staleRunningAfterSeconds,
+ service.UsageCleanupStatusRunning,
+ },
+ &task.ID,
+ &task.Status,
+ &filtersJSON,
+ &task.CreatedBy,
+ &task.DeletedRows,
+ &errMsg,
+ &startedAt,
+ &finishedAt,
+ &task.CreatedAt,
+ &task.UpdatedAt,
+ ); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return nil, nil
+ }
+ return nil, err
+ }
+ if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil {
+ return nil, fmt.Errorf("parse cleanup filters: %w", err)
+ }
+ if errMsg.Valid {
+ task.ErrorMsg = &errMsg.String
+ }
+ if startedAt.Valid {
+ task.StartedAt = &startedAt.Time
+ }
+ if finishedAt.Valid {
+ task.FinishedAt = &finishedAt.Time
+ }
+ return &task, nil
+}
+
+func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
+ var status string
+ if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
+ return "", err
+ }
+ return status, nil
+}
+
+func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
+ query := `
+ UPDATE usage_cleanup_tasks
+ SET deleted_rows = $1,
+ updated_at = NOW()
+ WHERE id = $2
+ `
+ _, err := r.sql.ExecContext(ctx, query, deletedRows, taskID)
+ return err
+}
+
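+// CancelTask marks a pending or running task as canceled. The RETURNING id shows
+// whether the guarded UPDATE matched a row; (false, nil) means the task was
+// missing or already in a terminal state.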
+func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
+ query := `
+ UPDATE usage_cleanup_tasks
+ SET status = $1,
+ canceled_by = $3,
+ canceled_at = NOW(),
+ finished_at = NOW(),
+ error_message = NULL,
+ updated_at = NOW()
+ WHERE id = $2
+ AND status IN ($4, $5)
+ RETURNING id
+ `
+ var id int64
+ err := scanSingleRow(ctx, r.sql, query, []any{
+ service.UsageCleanupStatusCanceled,
+ taskID,
+ canceledBy,
+ service.UsageCleanupStatusPending,
+ service.UsageCleanupStatusRunning,
+ }, &id)
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+ if err != nil {
+ return false, err
+ }
+ return true, nil
+}
+
+func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
+ query := `
+ UPDATE usage_cleanup_tasks
+ SET status = $1,
+ deleted_rows = $2,
+ finished_at = NOW(),
+ updated_at = NOW()
+ WHERE id = $3
+ `
+ _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID)
+ return err
+}
+
+func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
+ query := `
+ UPDATE usage_cleanup_tasks
+ SET status = $1,
+ deleted_rows = $2,
+ error_message = $3,
+ finished_at = NOW(),
+ updated_at = NOW()
+ WHERE id = $4
+ `
+ _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID)
+ return err
+}
+
+func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) {
+ if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
+ return 0, fmt.Errorf("cleanup filters missing time range")
+ }
+ whereClause, args := buildUsageCleanupWhere(filters)
+ if whereClause == "" {
+ return 0, fmt.Errorf("cleanup filters missing time range")
+ }
+ args = append(args, limit)
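+	// Bounded batch delete: a CTE selects the oldest matching ids up to the limit,
+	// and RETURNING lets us count the rows that were actually removed.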
+ query := fmt.Sprintf(`
+ WITH target AS (
+ SELECT id
+ FROM usage_logs
+ WHERE %s
+ ORDER BY created_at ASC, id ASC
+ LIMIT $%d
+ )
+ DELETE FROM usage_logs
+ WHERE id IN (SELECT id FROM target)
+ RETURNING id
+ `, whereClause, len(args))
+
+ rows, err := r.sql.QueryContext(ctx, query, args...)
+ if err != nil {
+ return 0, err
+ }
+ defer rows.Close()
+
+ var deleted int64
+ for rows.Next() {
+ deleted++
+ }
+ if err := rows.Err(); err != nil {
+ return 0, err
+ }
+ return deleted, nil
+}
+
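+// buildUsageCleanupWhere renders the filters as AND-joined conditions with
+// sequential Postgres placeholders, e.g. "created_at >= $1 AND created_at <= $2
+// AND user_id = $3", together with the matching args slice.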
+func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) {
+ conditions := make([]string, 0, 8)
+ args := make([]any, 0, 8)
+ idx := 1
+ if !filters.StartTime.IsZero() {
+ conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx))
+ args = append(args, filters.StartTime)
+ idx++
+ }
+ if !filters.EndTime.IsZero() {
+ conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx))
+ args = append(args, filters.EndTime)
+ idx++
+ }
+ if filters.UserID != nil {
+ conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx))
+ args = append(args, *filters.UserID)
+ idx++
+ }
+ if filters.APIKeyID != nil {
+ conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx))
+ args = append(args, *filters.APIKeyID)
+ idx++
+ }
+ if filters.AccountID != nil {
+ conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx))
+ args = append(args, *filters.AccountID)
+ idx++
+ }
+ if filters.GroupID != nil {
+ conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx))
+ args = append(args, *filters.GroupID)
+ idx++
+ }
+ if filters.Model != nil {
+ model := strings.TrimSpace(*filters.Model)
+ if model != "" {
+ conditions = append(conditions, fmt.Sprintf("model = $%d", idx))
+ args = append(args, model)
+ idx++
+ }
+ }
+ if filters.Stream != nil {
+ conditions = append(conditions, fmt.Sprintf("stream = $%d", idx))
+ args = append(args, *filters.Stream)
+ idx++
+ }
+ if filters.BillingType != nil {
+ conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
+ args = append(args, *filters.BillingType)
+ idx++
+ }
+ return strings.Join(conditions, " AND "), args
+}
diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go
new file mode 100644
index 00000000..e5582709
--- /dev/null
+++ b/backend/internal/repository/usage_cleanup_repo_test.go
@@ -0,0 +1,440 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "testing"
+ "time"
+
+ "github.com/DATA-DOG/go-sqlmock"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+)
+
+func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
+ t.Helper()
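+	// QueryMatcherRegexp matches expected queries as regular expressions, so tests can assert on a fragment of the SQL rather than the exact text.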
+ db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp))
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+ return db, mock
+}
+
+func TestNewUsageCleanupRepository(t *testing.T) {
+ db, _ := newSQLMock(t)
+ repo := NewUsageCleanupRepository(db)
+ require.NotNil(t, repo)
+}
+
+func TestUsageCleanupRepositoryCreateTask(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
+ CreatedBy: 12,
+ }
+ now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
+
+ mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
+ WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now))
+
+ err := repo.CreateTask(context.Background(), task)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), task.ID)
+ require.Equal(t, now, task.CreatedAt)
+ require.Equal(t, now, task.UpdatedAt)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ err := repo.CreateTask(context.Background(), nil)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)},
+ CreatedBy: 1,
+ }
+
+ mock.ExpectQuery("INSERT INTO usage_cleanup_tasks").
+ WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows).
+ WillReturnError(sql.ErrConnDone)
+
+ err := repo.CreateTask(context.Background(), task)
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
+
+ tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.NoError(t, err)
+ require.Empty(t, tasks)
+ require.Equal(t, int64(0), result.Total)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryListTasks(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(2 * time.Hour)
+ filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
+ filtersJSON, err := json.Marshal(filters)
+ require.NoError(t, err)
+
+ createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC)
+ updatedAt := createdAt.Add(time.Minute)
+ rows := sqlmock.NewRows([]string{
+ "id", "status", "filters", "created_by", "deleted_rows", "error_message",
+ "canceled_by", "canceled_at",
+ "started_at", "finished_at", "created_at", "updated_at",
+ }).AddRow(
+ int64(1),
+ service.UsageCleanupStatusSucceeded,
+ filtersJSON,
+ int64(2),
+ int64(9),
+ "error",
+ nil,
+ nil,
+ start,
+ end,
+ createdAt,
+ updatedAt,
+ )
+
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
+ mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
+ WithArgs(20, 0).
+ WillReturnRows(rows)
+
+ tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.NoError(t, err)
+ require.Len(t, tasks, 1)
+ require.Equal(t, int64(1), tasks[0].ID)
+ require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status)
+ require.Equal(t, int64(2), tasks[0].CreatedBy)
+ require.Equal(t, int64(9), tasks[0].DeletedRows)
+ require.NotNil(t, tasks[0].ErrorMsg)
+ require.Equal(t, "error", *tasks[0].ErrorMsg)
+ require.NotNil(t, tasks[0].StartedAt)
+ require.NotNil(t, tasks[0].FinishedAt)
+ require.Equal(t, int64(1), result.Total)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ rows := sqlmock.NewRows([]string{
+ "id", "status", "filters", "created_by", "deleted_rows", "error_message",
+ "canceled_by", "canceled_at",
+ "started_at", "finished_at", "created_at", "updated_at",
+ }).AddRow(
+ int64(1),
+ service.UsageCleanupStatusSucceeded,
+ []byte("not-json"),
+ int64(2),
+ int64(9),
+ nil,
+ nil,
+ nil,
+ nil,
+ nil,
+ time.Now().UTC(),
+ time.Now().UTC(),
+ )
+
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
+ mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
+ WithArgs(20, 0).
+ WillReturnRows(rows)
+
+ _, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
+ WillReturnRows(sqlmock.NewRows([]string{
+ "id", "status", "filters", "created_by", "deleted_rows", "error_message",
+ "started_at", "finished_at", "created_at", "updated_at",
+ }))
+
+ task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
+ require.NoError(t, err)
+ require.Nil(t, task)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
+ filtersJSON, err := json.Marshal(filters)
+ require.NoError(t, err)
+
+ rows := sqlmock.NewRows([]string{
+ "id", "status", "filters", "created_by", "deleted_rows", "error_message",
+ "started_at", "finished_at", "created_at", "updated_at",
+ }).AddRow(
+ int64(4),
+ service.UsageCleanupStatusRunning,
+ filtersJSON,
+ int64(7),
+ int64(0),
+ nil,
+ start,
+ nil,
+ start,
+ start,
+ )
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
+ WillReturnRows(rows)
+
+ task, err := repo.ClaimNextPendingTask(context.Background(), 1800)
+ require.NoError(t, err)
+ require.NotNil(t, task)
+ require.Equal(t, int64(4), task.ID)
+ require.Equal(t, service.UsageCleanupStatusRunning, task.Status)
+ require.Equal(t, int64(7), task.CreatedBy)
+ require.NotNil(t, task.StartedAt)
+ require.Nil(t, task.ErrorMsg)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
+ WillReturnError(sql.ErrConnDone)
+
+ _, err := repo.ClaimNextPendingTask(context.Background(), 1800)
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ rows := sqlmock.NewRows([]string{
+ "id", "status", "filters", "created_by", "deleted_rows", "error_message",
+ "started_at", "finished_at", "created_at", "updated_at",
+ }).AddRow(
+ int64(4),
+ service.UsageCleanupStatusRunning,
+ []byte("invalid"),
+ int64(7),
+ int64(0),
+ nil,
+ nil,
+ nil,
+ time.Now().UTC(),
+ time.Now().UTC(),
+ )
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning).
+ WillReturnRows(rows)
+
+ _, err := repo.ClaimNextPendingTask(context.Background(), 1800)
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectExec("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ err := repo.MarkTaskSucceeded(context.Background(), 9, 12)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectExec("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom")
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
+ WithArgs(int64(9)).
+ WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending))
+
+ status, err := repo.GetTaskStatus(context.Background(), 9)
+ require.NoError(t, err)
+ require.Equal(t, service.UsageCleanupStatusPending, status)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectExec("UPDATE usage_cleanup_tasks").
+ WithArgs(int64(123), int64(8)).
+ WillReturnResult(sqlmock.NewResult(0, 1))
+
+ err := repo.UpdateTaskProgress(context.Background(), 8, 123)
+ require.NoError(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6)))
+
+ ok, err := repo.CancelTask(context.Background(), 6, 9)
+ require.NoError(t, err)
+ require.True(t, ok)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
+ db, _ := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ _, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10)
+ require.Error(t, err)
+}
+
+func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ userID := int64(3)
+ model := " gpt-4 "
+ filters := service.UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ UserID: &userID,
+ Model: &model,
+ }
+
+ mock.ExpectQuery("DELETE FROM usage_logs").
+ WithArgs(start, end, userID, "gpt-4", 2).
+ WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2)))
+
+ deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2)
+ require.NoError(t, err)
+ require.Equal(t, int64(2), deleted)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
+
+ mock.ExpectQuery("DELETE FROM usage_logs").
+ WithArgs(start, end, 5).
+ WillReturnError(sql.ErrConnDone)
+
+ _, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5)
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
+func TestBuildUsageCleanupWhere(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ userID := int64(1)
+ apiKeyID := int64(2)
+ accountID := int64(3)
+ groupID := int64(4)
+ model := " gpt-4 "
+ stream := true
+ billingType := int8(2)
+
+ where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ UserID: &userID,
+ APIKeyID: &apiKeyID,
+ AccountID: &accountID,
+ GroupID: &groupID,
+ Model: &model,
+ Stream: &stream,
+ BillingType: &billingType,
+ })
+
+ require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where)
+ require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args)
+}
+
+func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ model := " "
+
+ where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ Model: &model,
+ })
+
+ require.Equal(t, "created_at >= $1 AND created_at <= $2", where)
+ require.Equal(t, []any{start, end}, args)
+}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index 4a2aaade..963db7ba 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
}
// GetUsageTrendWithFilters returns usage trend data with optional filters
-func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
+func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1456,6 +1456,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
+ if billingType != nil {
+ query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
+ args = append(args, int16(*billingType))
+ }
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
}
// GetModelStatsWithFilters returns model statistics with optional filters
-func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
+func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
// 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
args = append(args, *stream)
}
+ if billingType != nil {
+ query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
+ args = append(args, int16(*billingType))
+ }
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
}
}
- models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
+ models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil)
if err != nil {
models = []ModelStat{}
}
diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go
index 7174be18..eb220f22 100644
--- a/backend/internal/repository/usage_log_repo_integration_test.go
+++ b/backend/internal/repository/usage_log_repo_integration_test.go
@@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
- stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
+ stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 91ef9413..9dc91eca 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet(
NewRedeemCodeRepository,
NewPromoCodeRepository,
NewUsageLogRepository,
+ NewUsageCleanupRepository,
NewDashboardAggregationRepository,
NewSettingRepository,
NewOpsRepository,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 7971c65f..7076f8c5 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -1242,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
+func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
+func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index ff05b32a..050e724d 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
usage.GET("/stats", h.Admin.Usage.Stats)
usage.GET("/search-users", h.Admin.Usage.SearchUsers)
usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys)
+ usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks)
+ usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask)
+ usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask)
}
}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index d9ed5609..f1c07d5e 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
- GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
- GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
+ GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
+ GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -272,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
- stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
+ stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -294,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
- minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
+ minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go
index da5c0e7d..8f7e8144 100644
--- a/backend/internal/service/dashboard_aggregation_service.go
+++ b/backend/internal/service/dashboard_aggregation_service.go
@@ -21,11 +21,15 @@ var (
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
+ errDashboardAggregationRunning = errors.New("aggregation job is already running")
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error
+ // RecomputeRange recomputes the aggregates for the given time range (including derived tables such as active users).
+ // Purpose: after usage_logs rows are bulk-deleted or rolled back, the aggregate tables can be restored to a consistent state.
+ RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
@@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro
return nil
}
+// TriggerRecomputeRange triggers an asynchronous recompute of the given range.
+// Unlike TriggerBackfill, it:
+// - does not depend on backfill_enabled (this is an internal consistency repair)
+// - does not update the watermark (so the normal incremental aggregation cursor is unaffected)
+func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
+ if s == nil || s.repo == nil {
+ return errors.New("聚合服务未初始化")
+ }
+ if !s.cfg.Enabled {
+ return errors.New("聚合服务已禁用")
+ }
+ if !end.After(start) {
+ return errors.New("重新计算时间范围无效")
+ }
+
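+	// Run asynchronously: if the shared aggregation slot is busy, retry a few times
+	// before giving up so a concurrently running scheduled job does not starve the repair.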
+ go func() {
+ const maxRetries = 3
+ for i := 0; i < maxRetries; i++ {
+ ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
+ err := s.recomputeRange(ctx, start, end)
+ cancel()
+ if err == nil {
+ return
+ }
+ if !errors.Is(err, errDashboardAggregationRunning) {
+ log.Printf("[DashboardAggregation] 重新计算失败: %v", err)
+ return
+ }
+ time.Sleep(5 * time.Second)
+ }
+ log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
+ }()
+ return nil
+}
+
func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays
if days <= 0 {
@@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() {
}
}
+func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
+ if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
+ return errDashboardAggregationRunning
+ }
+ defer atomic.StoreInt32(&s.running, 0)
+
+ jobStart := time.Now().UTC()
+ if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
+ return err
+ }
+ log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
+ start.UTC().Format(time.RFC3339),
+ end.UTC().Format(time.RFC3339),
+ time.Since(jobStart).String(),
+ )
+ return nil
+}
+
func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return
@@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() {
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
- return errors.New("聚合作业正在运行")
+ return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
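TriggerRecomputeRange, backfillRange and runScheduledAggregation all serialize on the same running flag, so at most one aggregation job is in flight; a caller that loses the race gets errDashboardAggregationRunning and retries. A distilled sketch of that single-flight pattern (names here are illustrative, not the service's own):

    package main

    import (
        "errors"
        "fmt"
        "sync/atomic"
    )

    var errBusy = errors.New("job already running")

    type singleFlight struct{ running int32 }

    // run executes fn only when no other run is in flight; losers of the
    // CAS race get errBusy and may retry later.
    func (s *singleFlight) run(fn func()) error {
        if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
            return errBusy
        }
        defer atomic.StoreInt32(&s.running, 0)
        fn()
        return nil
    }

    func main() {
        var s singleFlight
        fmt.Println(s.run(func() {})) // <nil>
    }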
diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go
index 2fc22105..a7058985 100644
--- a/backend/internal/service/dashboard_aggregation_service_test.go
+++ b/backend/internal/service/dashboard_aggregation_service_test.go
@@ -27,6 +27,10 @@ func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s
return s.aggregateErr
}
+func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
+ return s.AggregateRange(ctx, start, end)
+}
+
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil
}
diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go
index a9811919..cd11923e 100644
--- a/backend/internal/service/dashboard_service.go
+++ b/backend/internal/service/dashboard_service.go
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
-func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
- trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
+func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
+ trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
-func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
+func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
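The new billingType parameter follows the same convention as stream: nil means "no filter", a non-nil value is an exact match. A sketch of how such an optional predicate typically composes into SQL (table and column names assumed from the migration and frontend types, not taken from the actual repository code):

    package repository

    import (
        "context"
        "database/sql"
    )

    // countWithBillingType illustrates the nil-pointer filter convention:
    // nil skips the predicate entirely, non-nil appends an exact match.
    func countWithBillingType(ctx context.Context, db *sql.DB, billingType *int8) (int64, error) {
        query := `SELECT COUNT(*) FROM usage_logs WHERE created_at >= NOW() - INTERVAL '1 day'`
        args := []any{}
        if billingType != nil {
            query += ` AND billing_type = $1`
            args = append(args, *billingType)
        }
        var n int64
        err := db.QueryRowContext(ctx, query, args...).Scan(&n)
        return n, err
    }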
diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go
index db3c78c3..59b83e66 100644
--- a/backend/internal/service/dashboard_service_test.go
+++ b/backend/internal/service/dashboard_service_test.go
@@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start
return nil
}
+func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
+ return nil
+}
+
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil {
return time.Time{}, s.err
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 47a04cf5..2d75dd5a 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -190,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return true, err
}
@@ -237,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil)
if err != nil {
return true, err
}
diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go
new file mode 100644
index 00000000..7e3ffbb9
--- /dev/null
+++ b/backend/internal/service/usage_cleanup.go
@@ -0,0 +1,75 @@
+package service
+
+import (
+ "context"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+)
+
+const (
+ UsageCleanupStatusPending = "pending"
+ UsageCleanupStatusRunning = "running"
+ UsageCleanupStatusSucceeded = "succeeded"
+ UsageCleanupStatusFailed = "failed"
+ UsageCleanupStatusCanceled = "canceled"
+)
+
+// UsageCleanupFilters defines the filter conditions of a cleanup task.
+// The time range is required; every other field is optional.
+// The struct is JSON-serialized to store the task parameters.
+//
+// start_time/end_time use the RFC3339 format and are the times as
+// resolved in UTC or in the user's timezone.
+//
+// Notes:
+// - nil means the filter is not set
+// - all filters are exact matches
+type UsageCleanupFilters struct {
+ StartTime time.Time `json:"start_time"`
+ EndTime time.Time `json:"end_time"`
+ UserID *int64 `json:"user_id,omitempty"`
+ APIKeyID *int64 `json:"api_key_id,omitempty"`
+ AccountID *int64 `json:"account_id,omitempty"`
+ GroupID *int64 `json:"group_id,omitempty"`
+ Model *string `json:"model,omitempty"`
+ Stream *bool `json:"stream,omitempty"`
+ BillingType *int8 `json:"billing_type,omitempty"`
+}
+
+// UsageCleanupTask represents a usage-log cleanup task.
+// Its status is one of pending/running/succeeded/failed/canceled.
+type UsageCleanupTask struct {
+ ID int64
+ Status string
+ Filters UsageCleanupFilters
+ CreatedBy int64
+ DeletedRows int64
+ ErrorMsg *string
+ CanceledBy *int64
+ CanceledAt *time.Time
+ StartedAt *time.Time
+ FinishedAt *time.Time
+ CreatedAt time.Time
+ UpdatedAt time.Time
+}
+
+// UsageCleanupRepository defines the persistence-layer interface for cleanup tasks
+type UsageCleanupRepository interface {
+ CreateTask(ctx context.Context, task *UsageCleanupTask) error
+ ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error)
+ // ClaimNextPendingTask claims the next runnable task:
+ // - pending tasks take priority
+ // - a task running longer than staleRunningAfterSeconds (possibly due to process exit/crash/timeout) may be re-claimed and resumed
+ ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error)
+ // GetTaskStatus returns the task status; sql.ErrNoRows if the task does not exist
+ GetTaskStatus(ctx context.Context, taskID int64) (string, error)
+ // UpdateTaskProgress updates the task progress (deleted_rows) for resumption and display
+ UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error
+ // CancelTask marks the task as canceled (only allowed for pending/running tasks)
+ CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error)
+ MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error
+ MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error
+ DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error)
+}
diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go
new file mode 100644
index 00000000..8ca02cfc
--- /dev/null
+++ b/backend/internal/service/usage_cleanup_service.go
@@ -0,0 +1,400 @@
+package service
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "fmt"
+ "log"
+ "net/http"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+)
+
+const (
+ usageCleanupWorkerName = "usage_cleanup_worker"
+)
+
+// UsageCleanupService creates and executes usage-log cleanup tasks
+type UsageCleanupService struct {
+ repo UsageCleanupRepository
+ timingWheel *TimingWheelService
+ dashboard *DashboardAggregationService
+ cfg *config.Config
+
+ running int32
+ startOnce sync.Once
+ stopOnce sync.Once
+
+ workerCtx context.Context
+ workerCancel context.CancelFunc
+}
+
+func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
+ workerCtx, workerCancel := context.WithCancel(context.Background())
+ return &UsageCleanupService{
+ repo: repo,
+ timingWheel: timingWheel,
+ dashboard: dashboard,
+ cfg: cfg,
+ workerCtx: workerCtx,
+ workerCancel: workerCancel,
+ }
+}
+
+func describeUsageCleanupFilters(filters UsageCleanupFilters) string {
+ var parts []string
+ parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339))
+ parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339))
+ if filters.UserID != nil {
+ parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID))
+ }
+ if filters.APIKeyID != nil {
+ parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID))
+ }
+ if filters.AccountID != nil {
+ parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID))
+ }
+ if filters.GroupID != nil {
+ parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID))
+ }
+ if filters.Model != nil {
+ parts = append(parts, "model="+strings.TrimSpace(*filters.Model))
+ }
+ if filters.Stream != nil {
+ parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream))
+ }
+ if filters.BillingType != nil {
+ parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType))
+ }
+ return strings.Join(parts, " ")
+}
+
+func (s *UsageCleanupService) Start() {
+ if s == nil {
+ return
+ }
+ if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
+ log.Printf("[UsageCleanup] not started (disabled)")
+ return
+ }
+ if s.repo == nil || s.timingWheel == nil {
+ log.Printf("[UsageCleanup] not started (missing deps)")
+ return
+ }
+
+ interval := s.workerInterval()
+ s.startOnce.Do(func() {
+ s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce)
+ log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout())
+ })
+}
+
+func (s *UsageCleanupService) Stop() {
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ if s.workerCancel != nil {
+ s.workerCancel()
+ }
+ if s.timingWheel != nil {
+ s.timingWheel.Cancel(usageCleanupWorkerName)
+ }
+ log.Printf("[UsageCleanup] stopped")
+ })
+}
+
+func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
+ if s == nil || s.repo == nil {
+ return nil, nil, fmt.Errorf("cleanup service not ready")
+ }
+ return s.repo.ListTasks(ctx, params)
+}
+
+func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) {
+ if s == nil || s.repo == nil {
+ return nil, fmt.Errorf("cleanup service not ready")
+ }
+ if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
+ return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
+ }
+ if createdBy <= 0 {
+ return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator")
+ }
+
+ log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters))
+ sanitizeUsageCleanupFilters(&filters)
+ if err := s.validateFilters(filters); err != nil {
+ log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
+ return nil, err
+ }
+
+ task := &UsageCleanupTask{
+ Status: UsageCleanupStatusPending,
+ Filters: filters,
+ CreatedBy: createdBy,
+ }
+ if err := s.repo.CreateTask(ctx, task); err != nil {
+ log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters))
+ return nil, fmt.Errorf("create cleanup task: %w", err)
+ }
+ log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters))
+ go s.runOnce()
+ return task, nil
+}
+
+func (s *UsageCleanupService) runOnce() {
+ if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
+ log.Printf("[UsageCleanup] run_once skipped: already_running=true")
+ return
+ }
+ defer atomic.StoreInt32(&s.running, 0)
+
+ parent := context.Background()
+ if s != nil && s.workerCtx != nil {
+ parent = s.workerCtx
+ }
+ ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
+ defer cancel()
+
+ task, err := s.repo.ClaimNextPendingTask(ctx, int64(s.taskTimeout().Seconds()))
+ if err != nil {
+ log.Printf("[UsageCleanup] claim pending task failed: %v", err)
+ return
+ }
+ if task == nil {
+ log.Printf("[UsageCleanup] run_once done: no_task=true")
+ return
+ }
+
+ log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
+ s.executeTask(ctx, task)
+}
+
+func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
+ if task == nil {
+ return
+ }
+
+ batchSize := s.batchSize()
+ deletedTotal := task.DeletedRows
+ start := time.Now()
+ log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters))
+ var batchNum int
+
+ for {
+ if ctx != nil && ctx.Err() != nil {
+ log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err())
+ return
+ }
+ canceled, err := s.isTaskCanceled(ctx, task.ID)
+ if err != nil {
+ s.markTaskFailed(task.ID, deletedTotal, err)
+ return
+ }
+ if canceled {
+ log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
+ return
+ }
+
+ batchNum++
+ deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize)
+ if err != nil {
+ if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+ // The task was interrupted (e.g. service shutdown or timeout); keep it in running state so a later stale reclaim can resume it.
+ log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err)
+ return
+ }
+ s.markTaskFailed(task.ID, deletedTotal, err)
+ return
+ }
+ deletedTotal += deleted
+ if deleted > 0 {
+ updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil {
+ log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err)
+ }
+ cancel()
+ }
+ if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) {
+ log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal)
+ }
+ if deleted == 0 || deleted < int64(batchSize) {
+ break
+ }
+ }
+
+ updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil {
+ log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err)
+ } else {
+ log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start))
+ }
+
+ if s.dashboard != nil {
+ if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil {
+ log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err)
+ } else {
+ log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339))
+ }
+ }
+}
+
+func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) {
+ msg := strings.TrimSpace(err.Error())
+ if len(msg) > 500 {
+ msg = msg[:500]
+ }
+ log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg)
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+ if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil {
+ log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr)
+ }
+}
+
+func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) {
+ if s == nil || s.repo == nil {
+ return false, fmt.Errorf("cleanup service not ready")
+ }
+ checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ status, err := s.repo.GetTaskStatus(checkCtx, taskID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return false, nil
+ }
+ return false, err
+ }
+ if status == UsageCleanupStatusCanceled {
+ log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID)
+ }
+ return status == UsageCleanupStatusCanceled, nil
+}
+
+func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error {
+ if filters.StartTime.IsZero() || filters.EndTime.IsZero() {
+ return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required")
+ }
+ if filters.EndTime.Before(filters.StartTime) {
+ return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date")
+ }
+ maxDays := s.maxRangeDays()
+ if maxDays > 0 {
+ delta := filters.EndTime.Sub(filters.StartTime)
+ if delta > time.Duration(maxDays)*24*time.Hour {
+ return infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays))
+ }
+ }
+ return nil
+}
+
+func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error {
+ if s == nil || s.repo == nil {
+ return fmt.Errorf("cleanup service not ready")
+ }
+ if s.cfg != nil && !s.cfg.UsageCleanup.Enabled {
+ return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled")
+ }
+ if canceledBy <= 0 {
+ return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller")
+ }
+ status, err := s.repo.GetTaskStatus(ctx, taskID)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found")
+ }
+ return err
+ }
+ log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status)
+ if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
+ return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
+ }
+ ok, err := s.repo.CancelTask(ctx, taskID, canceledBy)
+ if err != nil {
+ return err
+ }
+ if !ok {
+ // the status may have changed concurrently
+ return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status")
+ }
+ log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy)
+ return nil
+}
+
+func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) {
+ if filters == nil {
+ return
+ }
+ if filters.UserID != nil && *filters.UserID <= 0 {
+ filters.UserID = nil
+ }
+ if filters.APIKeyID != nil && *filters.APIKeyID <= 0 {
+ filters.APIKeyID = nil
+ }
+ if filters.AccountID != nil && *filters.AccountID <= 0 {
+ filters.AccountID = nil
+ }
+ if filters.GroupID != nil && *filters.GroupID <= 0 {
+ filters.GroupID = nil
+ }
+ if filters.Model != nil {
+ model := strings.TrimSpace(*filters.Model)
+ if model == "" {
+ filters.Model = nil
+ } else {
+ filters.Model = &model
+ }
+ }
+ if filters.BillingType != nil && *filters.BillingType < 0 {
+ filters.BillingType = nil
+ }
+}
+
+func (s *UsageCleanupService) maxRangeDays() int {
+ if s == nil || s.cfg == nil {
+ return 31
+ }
+ if s.cfg.UsageCleanup.MaxRangeDays > 0 {
+ return s.cfg.UsageCleanup.MaxRangeDays
+ }
+ return 31
+}
+
+func (s *UsageCleanupService) batchSize() int {
+ if s == nil || s.cfg == nil {
+ return 5000
+ }
+ if s.cfg.UsageCleanup.BatchSize > 0 {
+ return s.cfg.UsageCleanup.BatchSize
+ }
+ return 5000
+}
+
+func (s *UsageCleanupService) workerInterval() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 10 * time.Second
+ }
+ if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 {
+ return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second
+ }
+ return 10 * time.Second
+}
+
+func (s *UsageCleanupService) taskTimeout() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 30 * time.Minute
+ }
+ if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second
+ }
+ return 30 * time.Minute
+}
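The deletion loop in executeTask ends on the first short batch: fewer rows than batch_size (including zero) means the filtered set is drained. The control flow, distilled (the deleter below is a stand-in, not the repository):

    package main

    import "fmt"

    // drain deletes up to batch rows per call until a short batch signals the end.
    func drain(deleteBatch func(limit int) int, batch int) (total int) {
        for {
            n := deleteBatch(batch)
            total += n
            if n < batch { // also covers n == 0
                return total
            }
        }
    }

    func main() {
        remaining := 5
        del := func(limit int) int {
            n := limit
            if remaining < n {
                n = remaining
            }
            remaining -= n
            return n
        }
        fmt.Println(drain(del, 2)) // 5, deleted in batches of 2, 2, 1
    }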
diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go
new file mode 100644
index 00000000..37d3eb19
--- /dev/null
+++ b/backend/internal/service/usage_cleanup_service_test.go
@@ -0,0 +1,420 @@
+package service
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+ "net/http"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+type cleanupDeleteResponse struct {
+ deleted int64
+ err error
+}
+
+type cleanupDeleteCall struct {
+ filters UsageCleanupFilters
+ limit int
+}
+
+type cleanupMarkCall struct {
+ taskID int64
+ deletedRows int64
+ errMsg string
+}
+
+type cleanupRepoStub struct {
+ mu sync.Mutex
+ created []*UsageCleanupTask
+ createErr error
+ listTasks []UsageCleanupTask
+ listResult *pagination.PaginationResult
+ listErr error
+ claimQueue []*UsageCleanupTask
+ claimErr error
+ deleteQueue []cleanupDeleteResponse
+ deleteCalls []cleanupDeleteCall
+ markSucceeded []cleanupMarkCall
+ markFailed []cleanupMarkCall
+ statusByID map[int64]string
+ progressCalls []cleanupMarkCall
+ cancelCalls []int64
+}
+
+func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
+ if task == nil {
+ return nil
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.createErr != nil {
+ return s.createErr
+ }
+ if task.ID == 0 {
+ task.ID = int64(len(s.created) + 1)
+ }
+ if task.CreatedAt.IsZero() {
+ task.CreatedAt = time.Now().UTC()
+ }
+ if task.UpdatedAt.IsZero() {
+ task.UpdatedAt = task.CreatedAt
+ }
+ clone := *task
+ s.created = append(s.created, &clone)
+ return nil
+}
+
+func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.listTasks, s.listResult, s.listErr
+}
+
+func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.claimErr != nil {
+ return nil, s.claimErr
+ }
+ if len(s.claimQueue) == 0 {
+ return nil, nil
+ }
+ task := s.claimQueue[0]
+ s.claimQueue = s.claimQueue[1:]
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ s.statusByID[task.ID] = UsageCleanupStatusRunning
+ return task, nil
+}
+
+func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.statusByID == nil {
+ return "", sql.ErrNoRows
+ }
+ status, ok := s.statusByID[taskID]
+ if !ok {
+ return "", sql.ErrNoRows
+ }
+ return status, nil
+}
+
+func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
+ return nil
+}
+
+func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.cancelCalls = append(s.cancelCalls, taskID)
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ status := s.statusByID[taskID]
+ if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning {
+ return false, nil
+ }
+ s.statusByID[taskID] = UsageCleanupStatusCanceled
+ return true, nil
+}
+
+func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ s.statusByID[taskID] = UsageCleanupStatusSucceeded
+ return nil
+}
+
+func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg})
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ s.statusByID[taskID] = UsageCleanupStatusFailed
+ return nil
+}
+
+func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit})
+ if len(s.deleteQueue) == 0 {
+ return 0, nil
+ }
+ resp := s.deleteQueue[0]
+ s.deleteQueue = s.deleteQueue[1:]
+ return resp.deleted, resp.err
+}
+
+func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ userID := int64(-1)
+ apiKeyID := int64(10)
+ model := " gpt-4 "
+ billingType := int8(-2)
+ filters := UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ UserID: &userID,
+ APIKeyID: &apiKeyID,
+ Model: &model,
+ BillingType: &billingType,
+ }
+
+ task, err := svc.CreateTask(context.Background(), filters, 9)
+ require.NoError(t, err)
+ require.Equal(t, UsageCleanupStatusPending, task.Status)
+ require.Nil(t, task.Filters.UserID)
+ require.NotNil(t, task.Filters.APIKeyID)
+ require.Equal(t, apiKeyID, *task.Filters.APIKeyID)
+ require.NotNil(t, task.Filters.Model)
+ require.Equal(t, "gpt-4", *task.Filters.Model)
+ require.Nil(t, task.Filters.BillingType)
+ require.Equal(t, int64(9), task.CreatedBy)
+}
+
+func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ filters := UsageCleanupFilters{
+ StartTime: time.Now(),
+ EndTime: time.Now().Add(24 * time.Hour),
+ }
+ _, err := svc.CreateTask(context.Background(), filters, 0)
+ require.Error(t, err)
+ require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ filters := UsageCleanupFilters{
+ StartTime: time.Now(),
+ EndTime: time.Now().Add(24 * time.Hour),
+ }
+ _, err := svc.CreateTask(context.Background(), filters, 1)
+ require.Error(t, err)
+ require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
+ require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(48 * time.Hour)
+ filters := UsageCleanupFilters{StartTime: start, EndTime: end}
+
+ _, err := svc.CreateTask(context.Background(), filters, 1)
+ require.Error(t, err)
+ require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ _, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1)
+ require.Error(t, err)
+ require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
+ repo := &cleanupRepoStub{createErr: errors.New("db down")}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ filters := UsageCleanupFilters{
+ StartTime: time.Now(),
+ EndTime: time.Now().Add(24 * time.Hour),
+ }
+ _, err := svc.CreateTask(context.Background(), filters, 1)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "create cleanup task")
+}
+
+func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{
+ claimQueue: []*UsageCleanupTask{
+ {ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
+ },
+ deleteQueue: []cleanupDeleteResponse{
+ {deleted: 2},
+ {deleted: 2},
+ {deleted: 1},
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ svc.runOnce()
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.deleteCalls, 3)
+ require.Len(t, repo.markSucceeded, 1)
+ require.Empty(t, repo.markFailed)
+ require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
+ require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
+}
+
+func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
+ repo := &cleanupRepoStub{claimErr: errors.New("claim failed")}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ svc.runOnce()
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Empty(t, repo.markSucceeded)
+ require.Empty(t, repo.markFailed)
+}
+
+func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ svc.running = 1
+ svc.runOnce()
+}
+
+func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
+ longMsg := strings.Repeat("x", 600)
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {err: errors.New(longMsg)},
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 11,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now(),
+ EndTime: time.Now().Add(24 * time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.markFailed, 1)
+ require.Equal(t, int64(11), repo.markFailed[0].taskID)
+ require.Equal(t, 500, len(repo.markFailed[0].errMsg))
+}
+
+func TestUsageCleanupServiceListTasks(t *testing.T) {
+ repo := &cleanupRepoStub{
+ listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
+ listResult: &pagination.PaginationResult{
+ Total: 2,
+ Page: 1,
+ PageSize: 20,
+ Pages: 1,
+ },
+ }
+ svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
+
+ tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.NoError(t, err)
+ require.Len(t, tasks, 2)
+ require.Equal(t, int64(2), result.Total)
+}
+
+func TestUsageCleanupServiceListTasksNotReady(t *testing.T) {
+ var nilSvc *UsageCleanupService
+ _, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.Error(t, err)
+
+ svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
+ _, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.Error(t, err)
+}
+
+func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) {
+ var nilSvc *UsageCleanupService
+ require.Equal(t, 31, nilSvc.maxRangeDays())
+ require.Equal(t, 5000, nilSvc.batchSize())
+ require.Equal(t, 10*time.Second, nilSvc.workerInterval())
+ require.Equal(t, 30*time.Minute, nilSvc.taskTimeout())
+ nilSvc.Start()
+ nilSvc.Stop()
+
+ repo := &cleanupRepoStub{}
+ cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
+ svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled)
+ svcDisabled.Start()
+ svcDisabled.Stop()
+
+ timingWheel, err := NewTimingWheelService()
+ require.NoError(t, err)
+
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}}
+ svc := NewUsageCleanupService(repo, timingWheel, nil, cfg)
+ require.Equal(t, 5*time.Second, svc.workerInterval())
+ svc.Start()
+ svc.Stop()
+
+ cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback)
+ require.Equal(t, 31, svcFallback.maxRangeDays())
+ require.Equal(t, 5000, svcFallback.batchSize())
+ require.Equal(t, 10*time.Second, svcFallback.workerInterval())
+
+ svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback)
+ svcMissingDeps.Start()
+}
+
+func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
+ model := " "
+ apiKeyID := int64(-5)
+ accountID := int64(-1)
+ groupID := int64(-2)
+ filters := UsageCleanupFilters{
+ UserID: &apiKeyID,
+ APIKeyID: &apiKeyID,
+ AccountID: &accountID,
+ GroupID: &groupID,
+ Model: &model,
+ }
+
+ sanitizeUsageCleanupFilters(&filters)
+ require.Nil(t, filters.UserID)
+ require.Nil(t, filters.APIKeyID)
+ require.Nil(t, filters.AccountID)
+ require.Nil(t, filters.GroupID)
+ require.Nil(t, filters.Model)
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index acc0a5fb..0b9bc20c 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -57,6 +57,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim
return svc
}
+// ProvideUsageCleanupService creates and starts the usage-log cleanup task service
+func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService {
+ svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg)
+ svc.Start()
+ return svc
+}
+
// ProvideAccountExpiryService creates and starts AccountExpiryService.
func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService {
svc := NewAccountExpiryService(accountRepo, time.Minute)
@@ -248,6 +255,7 @@ var ProviderSet = wire.NewSet(
ProvideAccountExpiryService,
ProvideTimingWheelService,
ProvideDashboardAggregationService,
+ ProvideUsageCleanupService,
ProvideDeferredService,
NewAntigravityQuotaFetcher,
NewUserAttributeService,
diff --git a/backend/migrations/042_add_usage_cleanup_tasks.sql b/backend/migrations/042_add_usage_cleanup_tasks.sql
new file mode 100644
index 00000000..ce4be91f
--- /dev/null
+++ b/backend/migrations/042_add_usage_cleanup_tasks.sql
@@ -0,0 +1,21 @@
+-- 042_add_usage_cleanup_tasks.sql
+-- Usage-log cleanup tasks table
+
+CREATE TABLE IF NOT EXISTS usage_cleanup_tasks (
+ id BIGSERIAL PRIMARY KEY,
+ status VARCHAR(20) NOT NULL,
+ filters JSONB NOT NULL,
+ created_by BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT,
+ deleted_rows BIGINT NOT NULL DEFAULT 0,
+ error_message TEXT,
+ started_at TIMESTAMPTZ,
+ finished_at TIMESTAMPTZ,
+ created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
+);
+
+CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_status_created_at
+ ON usage_cleanup_tasks(status, created_at DESC);
+
+CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_created_at
+ ON usage_cleanup_tasks(created_at DESC);
diff --git a/backend/migrations/043_add_usage_cleanup_cancel_audit.sql b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql
new file mode 100644
index 00000000..42ca6696
--- /dev/null
+++ b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql
@@ -0,0 +1,10 @@
+-- 043_add_usage_cleanup_cancel_audit.sql
+-- Cancellation audit columns for usage_cleanup_tasks
+
+ALTER TABLE usage_cleanup_tasks
+ ADD COLUMN IF NOT EXISTS canceled_by BIGINT REFERENCES users(id) ON DELETE SET NULL,
+ ADD COLUMN IF NOT EXISTS canceled_at TIMESTAMPTZ;
+
+CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_canceled_at
+ ON usage_cleanup_tasks(canceled_at DESC);
+
diff --git a/config.yaml b/config.yaml
index 424ce9eb..5e7513fb 100644
--- a/config.yaml
+++ b/config.yaml
@@ -251,6 +251,27 @@ dashboard_aggregation:
# 日聚合保留天数
daily_days: 730
+# =============================================================================
+# Usage Cleanup Task Configuration
+# 使用记录清理任务配置(重启生效)
+# =============================================================================
+usage_cleanup:
+ # Enable cleanup task worker
+ # 启用清理任务执行器
+ enabled: true
+ # Max date range (days) per task
+ # 单次任务最大时间跨度(天)
+ max_range_days: 31
+ # Batch delete size
+ # 单批删除数量
+ batch_size: 5000
+ # Worker interval (seconds)
+ # 执行器轮询间隔(秒)
+ worker_interval_seconds: 10
+ # Task execution timeout (seconds)
+ # 单次任务最大执行时长(秒)
+ task_timeout_seconds: 1800
+
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置
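The service reads this block through a config struct; judging from the accessors in usage_cleanup_service.go it looks roughly like the sketch below (the real definition lives in internal/config and is not part of this patch, so field names and tags are inferred):

    package config

    // UsageCleanupConfig mirrors the usage_cleanup block above.
    // Field names are inferred from the service's accessors; treat as a sketch.
    type UsageCleanupConfig struct {
        Enabled               bool `yaml:"enabled"`
        MaxRangeDays          int  `yaml:"max_range_days"`
        BatchSize             int  `yaml:"batch_size"`
        WorkerIntervalSeconds int  `yaml:"worker_interval_seconds"`
        TaskTimeoutSeconds    int  `yaml:"task_timeout_seconds"`
    }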
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index 9e85d1ff..1f4aa266 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -292,6 +292,27 @@ dashboard_aggregation:
# 日聚合保留天数
daily_days: 730
+# =============================================================================
+# Usage Cleanup Task Configuration
+# 使用记录清理任务配置(重启生效)
+# =============================================================================
+usage_cleanup:
+ # Enable cleanup task worker
+ # 启用清理任务执行器
+ enabled: true
+ # Max date range (days) per task
+ # 单次任务最大时间跨度(天)
+ max_range_days: 31
+ # Batch delete size
+ # 单批删除数量
+ batch_size: 5000
+ # Worker interval (seconds)
+ # 执行器轮询间隔(秒)
+ worker_interval_seconds: 10
+ # Task execution timeout (seconds)
+ # 单次任务最大执行时长(秒)
+ task_timeout_seconds: 1800
+
# =============================================================================
# Concurrency Wait Configuration
# 并发等待配置
diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts
index 9b338788..ae48bec2 100644
--- a/frontend/src/api/admin/dashboard.ts
+++ b/frontend/src/api/admin/dashboard.ts
@@ -50,6 +50,7 @@ export interface TrendParams {
account_id?: number
group_id?: number
stream?: boolean
+ billing_type?: number | null
}
export interface TrendResponse {
@@ -78,6 +79,7 @@ export interface ModelStatsParams {
account_id?: number
group_id?: number
stream?: boolean
+ billing_type?: number | null
}
export interface ModelStatsResponse {
diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts
index dd85fc24..c271a2d0 100644
--- a/frontend/src/api/admin/usage.ts
+++ b/frontend/src/api/admin/usage.ts
@@ -31,6 +31,46 @@ export interface SimpleApiKey {
user_id: number
}
+export interface UsageCleanupFilters {
+ start_time: string
+ end_time: string
+ user_id?: number
+ api_key_id?: number
+ account_id?: number
+ group_id?: number
+ model?: string | null
+ stream?: boolean | null
+ billing_type?: number | null
+}
+
+export interface UsageCleanupTask {
+ id: number
+ status: string
+ filters: UsageCleanupFilters
+ created_by: number
+ deleted_rows: number
+ error_message?: string | null
+ canceled_by?: number | null
+ canceled_at?: string | null
+ started_at?: string | null
+ finished_at?: string | null
+ created_at: string
+ updated_at: string
+}
+
+export interface CreateUsageCleanupTaskRequest {
+ start_date: string
+ end_date: string
+ user_id?: number
+ api_key_id?: number
+ account_id?: number
+ group_id?: number
+ model?: string | null
+ stream?: boolean | null
+ billing_type?: number | null
+ timezone?: string
+}
+
export interface AdminUsageQueryParams extends UsageQueryParams {
user_id?: number
}
@@ -108,11 +148,51 @@ export async function searchApiKeys(userId?: number, keyword?: string): Promise<
return data
}
+/**
+ * List usage cleanup tasks (admin only)
+ * @param params - Query parameters for pagination
+ * @returns Paginated list of cleanup tasks
+ */
+export async function listCleanupTasks(
+ params: { page?: number; page_size?: number },
+ options?: { signal?: AbortSignal }
+): Promise<PaginatedResponse<UsageCleanupTask>> {
+ const { data } = await apiClient.get<PaginatedResponse<UsageCleanupTask>>('/admin/usage/cleanup-tasks', {
+ params,
+ signal: options?.signal
+ })
+ return data
+}
+
+/**
+ * Create a usage cleanup task (admin only)
+ * @param payload - Cleanup task parameters
+ * @returns Created cleanup task
+ */
+export async function createCleanupTask(payload: CreateUsageCleanupTaskRequest): Promise<UsageCleanupTask> {
+ const { data } = await apiClient.post<UsageCleanupTask>('/admin/usage/cleanup-tasks', payload)
+ return data
+}
+
+/**
+ * Cancel a usage cleanup task (admin only)
+ * @param taskId - Task ID to cancel
+ */
+export async function cancelCleanupTask(taskId: number): Promise<{ id: number; status: string }> {
+ const { data } = await apiClient.post<{ id: number; status: string }>(
+ `/admin/usage/cleanup-tasks/${taskId}/cancel`
+ )
+ return data
+}
+
export const adminUsageAPI = {
list,
getStats,
searchUsers,
- searchApiKeys
+ searchApiKeys,
+ listCleanupTasks,
+ createCleanupTask,
+ cancelCleanupTask
}
export default adminUsageAPI
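Any admin-authenticated client can drive the same endpoints the wrapper above calls. A hypothetical Go client for the create call (base URL, port and token are placeholders; the payload keys come from CreateUsageCleanupTaskRequest):

    package main

    import (
        "bytes"
        "encoding/json"
        "fmt"
        "net/http"
    )

    func main() {
        payload := map[string]any{
            "start_date": "2026-01-01",
            "end_date":   "2026-01-02",
            "model":      "claude-sonnet-4-5", // optional, like the other filters
        }
        body, _ := json.Marshal(payload)
        req, _ := http.NewRequest(http.MethodPost,
            "http://localhost:8080/admin/usage/cleanup-tasks", bytes.NewReader(body))
        req.Header.Set("Content-Type", "application/json")
        req.Header.Set("Authorization", "Bearer <admin-token>") // placeholder
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()
        fmt.Println(resp.Status)
    }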
diff --git a/frontend/src/components/admin/usage/UsageCleanupDialog.vue b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
new file mode 100644
index 00000000..4cd562e8
--- /dev/null
+++ b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
@@ -0,0 +1,339 @@
+<!--
+ Only fragments of this 339-line component survived extraction; the
+ recoverable structure: a cleanup dialog with a warning banner
+ (admin.usage.cleanup.warning), a recent-tasks section
+ (admin.usage.cleanup.recentTasks) with loading and empty states, one row per
+ task showing statusLabel(task.status), #task.id, formatDateTime(task.created_at),
+ the range via formatRange(task), task.deleted_rows.toLocaleString() and
+ task.error_message, plus submit/cancel actions wired to the
+ admin.usage.cleanup i18n keys added below.
+-->
diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue
index 0926d83c..b17e0fdc 100644
--- a/frontend/src/components/admin/usage/UsageFilters.vue
+++ b/frontend/src/components/admin/usage/UsageFilters.vue
@@ -127,6 +127,12 @@
+<!-- (markup not recovered: a billing-type select is added to the filter bar,
+ bound to the billingTypeOptions defined in the script below) -->
@@ -147,10 +153,13 @@
-<!-- (markup not recovered: the original export button) -->
+<!-- (markup not recovered: the export button is now accompanied by a cleanup
+ button, rendered when showActions is true and emitting the new 'cleanup' event) -->
@@ -174,16 +183,20 @@ interface Props {
exporting: boolean
startDate: string
endDate: string
+ showActions?: boolean
}
-const props = defineProps<Props>()
+const props = withDefaults(defineProps<Props>(), {
+ showActions: true
+})
const emit = defineEmits([
'update:modelValue',
'update:startDate',
'update:endDate',
'change',
'reset',
- 'export'
+ 'export',
+ 'cleanup'
])
const { t } = useI18n()
@@ -221,6 +234,12 @@ const streamTypeOptions = ref([
{ value: false, label: t('usage.sync') }
])
+const billingTypeOptions = ref([
+ { value: null, label: t('admin.usage.allBillingTypes') },
+ { value: 0, label: t('admin.usage.billingTypeBalance') },
+ { value: 1, label: t('admin.usage.billingTypeSubscription') }
+])
+
const emitChange = () => emit('change')
const updateStartDate = (value: string) => {
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts
index e4fe1bd1..2a000d0b 100644
--- a/frontend/src/i18n/locales/en.ts
+++ b/frontend/src/i18n/locales/en.ts
@@ -1893,7 +1893,43 @@ export default {
cacheCreationTokens: 'Cache Creation Tokens',
cacheReadTokens: 'Cache Read Tokens',
failedToLoad: 'Failed to load usage records',
- ipAddress: 'IP'
+ billingType: 'Billing Type',
+ allBillingTypes: 'All Billing Types',
+ billingTypeBalance: 'Balance',
+ billingTypeSubscription: 'Subscription',
+ ipAddress: 'IP',
+ cleanup: {
+ button: 'Cleanup',
+ title: 'Cleanup Usage Records',
+ warning: 'Cleanup is irreversible and will affect historical stats.',
+ submit: 'Submit Cleanup',
+ submitting: 'Submitting...',
+ confirmTitle: 'Confirm Cleanup',
+ confirmMessage: 'Are you sure you want to submit this cleanup task? This action cannot be undone.',
+ confirmSubmit: 'Confirm Cleanup',
+ cancel: 'Cancel',
+ cancelConfirmTitle: 'Confirm Cancel',
+ cancelConfirmMessage: 'Are you sure you want to cancel this cleanup task?',
+ cancelConfirm: 'Confirm Cancel',
+ cancelSuccess: 'Cleanup task canceled',
+ cancelFailed: 'Failed to cancel cleanup task',
+ recentTasks: 'Recent Cleanup Tasks',
+ loadingTasks: 'Loading tasks...',
+ noTasks: 'No cleanup tasks yet',
+ range: 'Range',
+ deletedRows: 'Deleted',
+ missingRange: 'Please select a date range',
+ submitSuccess: 'Cleanup task created',
+ submitFailed: 'Failed to create cleanup task',
+ loadFailed: 'Failed to load cleanup tasks',
+ status: {
+ pending: 'Pending',
+ running: 'Running',
+ succeeded: 'Succeeded',
+ failed: 'Failed',
+ canceled: 'Canceled'
+ }
+ }
},
// Ops Monitoring
diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts
index 35242c69..0c27f7a3 100644
--- a/frontend/src/i18n/locales/zh.ts
+++ b/frontend/src/i18n/locales/zh.ts
@@ -2041,7 +2041,43 @@ export default {
cacheCreationTokens: '缓存创建 Token',
cacheReadTokens: '缓存读取 Token',
failedToLoad: '加载使用记录失败',
- ipAddress: 'IP'
+ billingType: '计费类型',
+ allBillingTypes: '全部计费类型',
+ billingTypeBalance: '钱包余额',
+ billingTypeSubscription: '订阅套餐',
+ ipAddress: 'IP',
+ cleanup: {
+ button: '清理',
+ title: '清理使用记录',
+ warning: '清理不可恢复,且会影响历史统计回看。',
+ submit: '提交清理',
+ submitting: '提交中...',
+ confirmTitle: '确认清理',
+ confirmMessage: '确定要提交清理任务吗?清理不可恢复。',
+ confirmSubmit: '确认清理',
+ cancel: '取消任务',
+ cancelConfirmTitle: '确认取消',
+ cancelConfirmMessage: '确定要取消该清理任务吗?',
+ cancelConfirm: '确认取消',
+ cancelSuccess: '清理任务已取消',
+ cancelFailed: '取消清理任务失败',
+ recentTasks: '最近清理任务',
+ loadingTasks: '正在加载任务...',
+ noTasks: '暂无清理任务',
+ range: '时间范围',
+ deletedRows: '删除数量',
+ missingRange: '请选择时间范围',
+ submitSuccess: '清理任务已创建',
+ submitFailed: '创建清理任务失败',
+ loadFailed: '加载清理任务失败',
+ status: {
+ pending: '待执行',
+ running: '执行中',
+ succeeded: '已完成',
+ failed: '失败',
+ canceled: '已取消'
+ }
+ }
},
// Ops Monitoring
diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts
index 523033c2..1bb6e5d6 100644
--- a/frontend/src/types/index.ts
+++ b/frontend/src/types/index.ts
@@ -618,6 +618,7 @@ export interface UsageLog {
actual_cost: number
rate_multiplier: number
account_rate_multiplier?: number | null
+ billing_type: number
stream: boolean
duration_ms: number
@@ -642,6 +643,33 @@ export interface UsageLog {
subscription?: UserSubscription
}
+export interface UsageCleanupFilters {
+ start_time: string
+ end_time: string
+ user_id?: number
+ api_key_id?: number
+ account_id?: number
+ group_id?: number
+ model?: string | null
+ stream?: boolean | null
+ billing_type?: number | null
+}
+
+export interface UsageCleanupTask {
+ id: number
+ status: string
+ filters: UsageCleanupFilters
+ created_by: number
+ deleted_rows: number
+ error_message?: string | null
+ canceled_by?: number | null
+ canceled_at?: string | null
+ started_at?: string | null
+ finished_at?: string | null
+ created_at: string
+ updated_at: string
+}
+
export interface RedeemCode {
id: number
code: string
@@ -865,6 +893,7 @@ export interface UsageQueryParams {
group_id?: number
model?: string
stream?: boolean
+ billing_type?: number | null
start_date?: string
end_date?: string
}
diff --git a/frontend/src/views/admin/UsageView.vue b/frontend/src/views/admin/UsageView.vue
index 6f62f59e..40b63ec3 100644
--- a/frontend/src/views/admin/UsageView.vue
+++ b/frontend/src/views/admin/UsageView.vue
@@ -17,12 +17,19 @@
-<!-- (markup not recovered: the original UsageFilters usage) -->
+<!-- (markup not recovered: UsageFilters now receives the billing-type filter
+ and a cleanup handler, and a UsageCleanupDialog component is mounted alongside) -->
diff --git a/frontend/src/views/admin/SubscriptionsView.vue b/frontend/src/views/admin/SubscriptionsView.vue
index 7b38b455..d5a47788 100644
--- a/frontend/src/views/admin/SubscriptionsView.vue
+++ b/frontend/src/views/admin/SubscriptionsView.vue
@@ -85,6 +85,57 @@
+<!-- (markup not recovered: a column-settings dropdown is added next to the
+ table header, with an email/username toggle for the user column
+ (admin.subscriptions.columns.user) and visibility checkboxes for the
+ toggleable columns; state handling lives in the script below) -->
- {{
- row.user?.email || t('admin.redeem.userPrefix', { id: row.user_id })
- }}
+
+ {{ userColumnMode === 'email'
+ ? (row.user?.email || t('admin.redeem.userPrefix', { id: row.user_id }))
+ : (row.user?.username || '-')
+ }}
+
@@ -545,8 +602,43 @@ import Icon from '@/components/icons/Icon.vue'
const { t } = useI18n()
const appStore = useAppStore()
-const columns = computed(() => [
- { key: 'user', label: t('admin.subscriptions.columns.user'), sortable: true },
+// User column display mode: 'email' or 'username'
+const userColumnMode = ref<'email' | 'username'>('email')
+const USER_COLUMN_MODE_KEY = 'subscription-user-column-mode'
+
+const loadUserColumnMode = () => {
+ try {
+ const saved = localStorage.getItem(USER_COLUMN_MODE_KEY)
+ if (saved === 'email' || saved === 'username') {
+ userColumnMode.value = saved
+ }
+ } catch (e) {
+ console.error('Failed to load user column mode:', e)
+ }
+}
+
+const saveUserColumnMode = () => {
+ try {
+ localStorage.setItem(USER_COLUMN_MODE_KEY, userColumnMode.value)
+ } catch (e) {
+ console.error('Failed to save user column mode:', e)
+ }
+}
+
+const setUserColumnMode = (mode: 'email' | 'username') => {
+ userColumnMode.value = mode
+ saveUserColumnMode()
+}
+
+// All available columns
+const allColumns = computed(() => [
+ {
+ key: 'user',
+ label: userColumnMode.value === 'email'
+ ? t('admin.subscriptions.columns.user')
+ : t('admin.users.columns.username'),
+ sortable: true
+ },
{ key: 'group', label: t('admin.subscriptions.columns.group'), sortable: true },
{ key: 'usage', label: t('admin.subscriptions.columns.usage'), sortable: false },
{ key: 'expires_at', label: t('admin.subscriptions.columns.expires'), sortable: true },
@@ -554,6 +646,69 @@ const columns = computed(() => [
{ key: 'actions', label: t('admin.subscriptions.columns.actions'), sortable: false }
])
+// Columns that can be toggled (exclude user and actions which are always visible)
+const toggleableColumns = computed(() =>
+ allColumns.value.filter(col => col.key !== 'user' && col.key !== 'actions')
+)
+
+// Hidden columns set
+const hiddenColumns = reactive<Set<string>>(new Set())
+
+// Default hidden columns
+const DEFAULT_HIDDEN_COLUMNS: string[] = []
+
+// localStorage key
+const HIDDEN_COLUMNS_KEY = 'subscription-hidden-columns'
+
+// Load saved column settings
+const loadSavedColumns = () => {
+ try {
+ const saved = localStorage.getItem(HIDDEN_COLUMNS_KEY)
+ if (saved) {
+ const parsed = JSON.parse(saved) as string[]
+ parsed.forEach(key => hiddenColumns.add(key))
+ } else {
+ DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
+ }
+ } catch (e) {
+ console.error('Failed to load saved columns:', e)
+ DEFAULT_HIDDEN_COLUMNS.forEach(key => hiddenColumns.add(key))
+ }
+}
+
+// Save column settings to localStorage
+const saveColumnsToStorage = () => {
+ try {
+ localStorage.setItem(HIDDEN_COLUMNS_KEY, JSON.stringify([...hiddenColumns]))
+ } catch (e) {
+ console.error('Failed to save columns:', e)
+ }
+}
+
+// Toggle column visibility
+const toggleColumn = (key: string) => {
+ if (hiddenColumns.has(key)) {
+ hiddenColumns.delete(key)
+ } else {
+ hiddenColumns.add(key)
+ }
+ saveColumnsToStorage()
+}
+
+// Check if column is visible
+const isColumnVisible = (key: string) => !hiddenColumns.has(key)
+
+// Filtered columns for display
+const columns = computed(() =>
+ allColumns.value.filter(col =>
+ col.key === 'user' || col.key === 'actions' || !hiddenColumns.has(col.key)
+ )
+)
+
+// Column dropdown state
+const showColumnDropdown = ref(false)
+const columnDropdownRef = ref<HTMLElement | null>(null)
+
// Filter options
const statusOptions = computed(() => [
{ value: '', label: t('admin.subscriptions.allStatus') },
@@ -949,14 +1104,19 @@ const formatResetTime = (windowStart: string, period: 'daily' | 'weekly' | 'mont
}
}
-// Handle click outside to close user dropdown
+// Handle click outside to close dropdowns
const handleClickOutside = (event: MouseEvent) => {
const target = event.target as HTMLElement
if (!target.closest('[data-assign-user-search]')) showUserDropdown.value = false
if (!target.closest('[data-filter-user-search]')) showFilterUserDropdown.value = false
+ if (columnDropdownRef.value && !columnDropdownRef.value.contains(target)) {
+ showColumnDropdown.value = false
+ }
}
onMounted(() => {
+ loadUserColumnMode()
+ loadSavedColumns()
loadSubscriptions()
loadGroups()
document.addEventListener('click', handleClickOutside)
From bf7b79f2f037a20b930fe5d5e9760f190ef0ce6b Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 18 Jan 2026 11:58:53 +0800
Subject: [PATCH 022/155] fix(database): optimize the task status update query
 by using a table alias for readability
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
backend/internal/repository/usage_cleanup_repo.go | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go
index b703cc9f..b6dfa42a 100644
--- a/backend/internal/repository/usage_cleanup_repo.go
+++ b/backend/internal/repository/usage_cleanup_repo.go
@@ -136,16 +136,16 @@ func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, stale
LIMIT 1
FOR UPDATE SKIP LOCKED
)
- UPDATE usage_cleanup_tasks
+ UPDATE usage_cleanup_tasks AS tasks
SET status = $4,
started_at = NOW(),
finished_at = NULL,
error_message = NULL,
updated_at = NOW()
FROM next
- WHERE usage_cleanup_tasks.id = next.id
- RETURNING id, status, filters, created_by, deleted_rows, error_message,
- started_at, finished_at, created_at, updated_at
+ WHERE tasks.id = next.id
+ RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message,
+ tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at
`
var task service.UsageCleanupTask
var filtersJSON []byte
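The CTE plus FOR UPDATE SKIP LOCKED is what makes claiming safe under concurrent workers: two workers scanning at the same time lock disjoint rows, so each task is handed to exactly one of them. A self-contained sketch of the same claim pattern over plain database/sql (column list simplified):

    package repository

    import (
        "context"
        "database/sql"
        "errors"
    )

    // claimOne picks one claimable row with SKIP LOCKED so concurrent workers
    // never grab the same task, flips its status in the same statement, and
    // returns the claimed id (0 when nothing is claimable).
    func claimOne(ctx context.Context, db *sql.DB) (int64, error) {
        const q = `
    WITH next AS (
        SELECT id FROM usage_cleanup_tasks
        WHERE status = 'pending'
        ORDER BY created_at
        LIMIT 1
        FOR UPDATE SKIP LOCKED
    )
    UPDATE usage_cleanup_tasks AS tasks
    SET status = 'running', started_at = NOW(), updated_at = NOW()
    FROM next
    WHERE tasks.id = next.id
    RETURNING tasks.id`
        var id int64
        err := db.QueryRowContext(ctx, q).Scan(&id)
        if errors.Is(err, sql.ErrNoRows) {
            return 0, nil
        }
        return id, err
    }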
From bd18f4b8ef2d1bbb713c362a5efbe20d4bc4fbc8 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 18 Jan 2026 14:18:28 +0800
Subject: [PATCH 023/155] feat(cleanup tasks): introduce Ent storage and add
 logging and tests
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add a usage_cleanup_task Ent schema and repository implementation, with sorted pagination for cleanup tasks.
Add full-chain logging for cleanup tasks, dashboard recompute triggering, and UI filter adjustments.
Round out the repository/service unit tests and introduce the sqlite test dependency.
---
backend/cmd/server/wire_gen.go | 2 +-
backend/ent/client.go | 159 ++-
backend/ent/ent.go | 2 +
backend/ent/hook/hook.go | 12 +
backend/ent/intercept/intercept.go | 30 +
backend/ent/migrate/schema.go | 42 +
backend/ent/mutation.go | 1086 +++++++++++++++
backend/ent/predicate/predicate.go | 3 +
backend/ent/runtime/runtime.go | 38 +
backend/ent/schema/mixins/soft_delete.go | 5 +-
backend/ent/schema/usage_cleanup_task.go | 75 ++
backend/ent/tx.go | 3 +
backend/ent/usagecleanuptask.go | 236 ++++
.../ent/usagecleanuptask/usagecleanuptask.go | 137 ++
backend/ent/usagecleanuptask/where.go | 620 +++++++++
backend/ent/usagecleanuptask_create.go | 1190 +++++++++++++++++
backend/ent/usagecleanuptask_delete.go | 88 ++
backend/ent/usagecleanuptask_query.go | 564 ++++++++
backend/ent/usagecleanuptask_update.go | 702 ++++++++++
backend/go.mod | 8 +-
backend/go.sum | 15 +
.../admin/usage_cleanup_handler_test.go | 2 +-
.../internal/repository/usage_cleanup_repo.go | 234 +++-
.../repository/usage_cleanup_repo_ent_test.go | 251 ++++
.../repository/usage_cleanup_repo_test.go | 44 +-
.../service/dashboard_aggregation_service.go | 2 +-
.../internal/service/usage_cleanup_service.go | 18 +-
.../service/usage_cleanup_service_test.go | 397 +++++-
.../admin/usage/UsageCleanupDialog.vue | 2 +-
29 files changed, 5920 insertions(+), 47 deletions(-)
create mode 100644 backend/ent/schema/usage_cleanup_task.go
create mode 100644 backend/ent/usagecleanuptask.go
create mode 100644 backend/ent/usagecleanuptask/usagecleanuptask.go
create mode 100644 backend/ent/usagecleanuptask/where.go
create mode 100644 backend/ent/usagecleanuptask_create.go
create mode 100644 backend/ent/usagecleanuptask_delete.go
create mode 100644 backend/ent/usagecleanuptask_query.go
create mode 100644 backend/ent/usagecleanuptask_update.go
create mode 100644 backend/internal/repository/usage_cleanup_repo_ent_test.go
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 509cf13a..e5bfa515 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -153,7 +153,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
systemHandler := handler.ProvideSystemHandler(updateService)
adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
- usageCleanupRepository := repository.NewUsageCleanupRepository(db)
+ usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
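The new constructor signature implies the repository now holds two backends: the generated Ent client for typed CRUD, ordering, and pagination, plus the raw *sql.DB for the hand-written SKIP LOCKED claim query above. A hedged sketch of the shape this suggests (struct and interface names are assumptions, not the patch's actual code):

    func NewUsageCleanupRepository(client *ent.Client, db *sql.DB) service.UsageCleanupRepository {
        return &usageCleanupRepository{
            client: client, // Ent: create/list/sort/paginate cleanup tasks
            db:     db,     // database/sql: FOR UPDATE SKIP LOCKED claiming
        }
    }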
diff --git a/backend/ent/client.go b/backend/ent/client.go
index 35cf644f..f6c13e84 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -24,6 +24,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -57,6 +58,8 @@ type Client struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
+ // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
+ UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -89,6 +92,7 @@ func (c *Client) init() {
c.Proxy = NewProxyClient(c.config)
c.RedeemCode = NewRedeemCodeClient(c.config)
c.Setting = NewSettingClient(c.config)
+ c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
c.UsageLog = NewUsageLogClient(c.config)
c.User = NewUserClient(c.config)
c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
@@ -196,6 +200,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -230,6 +235,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
Proxy: NewProxyClient(cfg),
RedeemCode: NewRedeemCodeClient(cfg),
Setting: NewSettingClient(cfg),
+ UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
UsageLog: NewUsageLogClient(cfg),
User: NewUserClient(cfg),
UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -266,8 +272,9 @@ func (c *Client) Close() error {
func (c *Client) Use(hooks ...Hook) {
for _, n := range []interface{ Use(...Hook) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
- c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
+ c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
+ c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.UserSubscription,
} {
n.Use(hooks...)
}
@@ -278,8 +285,9 @@ func (c *Client) Use(hooks ...Hook) {
func (c *Client) Intercept(interceptors ...Interceptor) {
for _, n := range []interface{ Intercept(...Interceptor) }{
c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
- c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
- c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
+ c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
+ c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+ c.UserSubscription,
} {
n.Intercept(interceptors...)
}
@@ -306,6 +314,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
return c.RedeemCode.mutate(ctx, m)
case *SettingMutation:
return c.Setting.mutate(ctx, m)
+ case *UsageCleanupTaskMutation:
+ return c.UsageCleanupTask.mutate(ctx, m)
case *UsageLogMutation:
return c.UsageLog.mutate(ctx, m)
case *UserMutation:
@@ -1847,6 +1857,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
}
}
+// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
+type UsageCleanupTaskClient struct {
+ config
+}
+
+// NewUsageCleanupTaskClient returns a client for the UsageCleanupTask from the given config.
+func NewUsageCleanupTaskClient(c config) *UsageCleanupTaskClient {
+ return &UsageCleanupTaskClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `usagecleanuptask.Hooks(f(g(h())))`.
+func (c *UsageCleanupTaskClient) Use(hooks ...Hook) {
+ c.hooks.UsageCleanupTask = append(c.hooks.UsageCleanupTask, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `usagecleanuptask.Intercept(f(g(h())))`.
+func (c *UsageCleanupTaskClient) Intercept(interceptors ...Interceptor) {
+ c.inters.UsageCleanupTask = append(c.inters.UsageCleanupTask, interceptors...)
+}
+
+// Create returns a builder for creating a UsageCleanupTask entity.
+func (c *UsageCleanupTaskClient) Create() *UsageCleanupTaskCreate {
+ mutation := newUsageCleanupTaskMutation(c.config, OpCreate)
+ return &UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of UsageCleanupTask entities.
+func (c *UsageCleanupTaskClient) CreateBulk(builders ...*UsageCleanupTaskCreate) *UsageCleanupTaskCreateBulk {
+ return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *UsageCleanupTaskClient) MapCreateBulk(slice any, setFunc func(*UsageCleanupTaskCreate, int)) *UsageCleanupTaskCreateBulk {
+ rv := reflect.ValueOf(slice)
+ if rv.Kind() != reflect.Slice {
+ return &UsageCleanupTaskCreateBulk{err: fmt.Errorf("calling to UsageCleanupTaskClient.MapCreateBulk with wrong type %T, need slice", slice)}
+ }
+ builders := make([]*UsageCleanupTaskCreate, rv.Len())
+ for i := 0; i < rv.Len(); i++ {
+ builders[i] = c.Create()
+ setFunc(builders[i], i)
+ }
+ return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Update() *UsageCleanupTaskUpdate {
+ mutation := newUsageCleanupTaskMutation(c.config, OpUpdate)
+ return &UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *UsageCleanupTaskClient) UpdateOne(_m *UsageCleanupTask) *UsageCleanupTaskUpdateOne {
+ mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTask(_m))
+ return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *UsageCleanupTaskClient) UpdateOneID(id int64) *UsageCleanupTaskUpdateOne {
+ mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTaskID(id))
+ return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Delete() *UsageCleanupTaskDelete {
+ mutation := newUsageCleanupTaskMutation(c.config, OpDelete)
+ return &UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *UsageCleanupTaskClient) DeleteOne(_m *UsageCleanupTask) *UsageCleanupTaskDeleteOne {
+ return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *UsageCleanupTaskClient) DeleteOneID(id int64) *UsageCleanupTaskDeleteOne {
+ builder := c.Delete().Where(usagecleanuptask.ID(id))
+ builder.mutation.id = &id
+ builder.mutation.op = OpDeleteOne
+ return &UsageCleanupTaskDeleteOne{builder}
+}
+
+// Query returns a query builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Query() *UsageCleanupTaskQuery {
+ return &UsageCleanupTaskQuery{
+ config: c.config,
+ ctx: &QueryContext{Type: TypeUsageCleanupTask},
+ inters: c.Interceptors(),
+ }
+}
+
+// Get returns a UsageCleanupTask entity by its id.
+func (c *UsageCleanupTaskClient) Get(ctx context.Context, id int64) (*UsageCleanupTask, error) {
+ return c.Query().Where(usagecleanuptask.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *UsageCleanupTaskClient) GetX(ctx context.Context, id int64) *UsageCleanupTask {
+ obj, err := c.Get(ctx, id)
+ if err != nil {
+ panic(err)
+ }
+ return obj
+}
+
+// Hooks returns the client hooks.
+func (c *UsageCleanupTaskClient) Hooks() []Hook {
+ return c.hooks.UsageCleanupTask
+}
+
+// Interceptors returns the client interceptors.
+func (c *UsageCleanupTaskClient) Interceptors() []Interceptor {
+ return c.inters.UsageCleanupTask
+}
+
+func (c *UsageCleanupTaskClient) mutate(ctx context.Context, m *UsageCleanupTaskMutation) (Value, error) {
+ switch m.Op() {
+ case OpCreate:
+ return (&UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdate:
+ return (&UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpUpdateOne:
+ return (&UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+ case OpDelete, OpDeleteOne:
+ return (&UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+ default:
+ return nil, fmt.Errorf("ent: unknown UsageCleanupTask mutation op: %q", m.Op())
+ }
+}
+
// UsageLogClient is a client for the UsageLog schema.
type UsageLogClient struct {
config
@@ -2974,13 +3117,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
type (
hooks struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Hook
+ RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
}
inters struct {
APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
- RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
- UserAttributeValue, UserSubscription []ent.Interceptor
+ RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+ UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
}
)
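With the client registered, callers get the usual generated builder API. A hedged usage sketch — the field setters follow from the schema fields, but the create builder itself lives in usagecleanuptask_create.go (added in this patch, content not shown):

    // Assumes context, encoding/json, and the generated ent package are imported;
    // adminID is an illustrative int64.
    task, err := client.UsageCleanupTask.
        Create().
        SetStatus("pending").
        SetFilters(json.RawMessage(`{"before": "2026-01-01"}`)).
        SetCreatedBy(adminID).
        Save(ctx)
    if err != nil {
        return err
    }
    _ = task.ID // int64 primary key assigned by the database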
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 410375a7..4bcc2642 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -21,6 +21,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -96,6 +97,7 @@ func checkColumn(t, c string) error {
proxy.Table: proxy.ValidColumn,
redeemcode.Table: redeemcode.ValidColumn,
setting.Table: setting.ValidColumn,
+ usagecleanuptask.Table: usagecleanuptask.ValidColumn,
usagelog.Table: usagelog.ValidColumn,
user.Table: user.ValidColumn,
userallowedgroup.Table: userallowedgroup.ValidColumn,
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 532b0d2c..edd84f5e 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -117,6 +117,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
}
+// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
+// function as UsageCleanupTask mutator.
+type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f UsageCleanupTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+ if mv, ok := m.(*ent.UsageCleanupTaskMutation); ok {
+ return f(ctx, mv)
+ }
+ return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageCleanupTaskMutation", m)
+}
+
// The UsageLogFunc type is an adapter to allow the use of ordinary
// function as UsageLog mutator.
type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)
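The adapter lets an ordinary function serve as a mutation hook. A hedged sketch that logs every UsageCleanupTask mutation (context, log, ent, and hook imports assumed):

    client.UsageCleanupTask.Use(func(next ent.Mutator) ent.Mutator {
        return hook.UsageCleanupTaskFunc(func(ctx context.Context, m *ent.UsageCleanupTaskMutation) (ent.Value, error) {
            log.Printf("usage_cleanup_task mutation: %s", m.Op())
            return next.Mutate(ctx, m) // continue the hook chain
        })
    })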
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 765d39b4..f18c0624 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -18,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -325,6 +326,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
}
+// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
+type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f UsageCleanupTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+ if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
+ return f(ctx, q)
+ }
+ return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
+}
+
+// The TraverseUsageCleanupTask type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseUsageCleanupTask func(context.Context, *ent.UsageCleanupTaskQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseUsageCleanupTask) Intercept(next ent.Querier) ent.Querier {
+ return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseUsageCleanupTask) Traverse(ctx context.Context, q ent.Query) error {
+ if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
+ return f(ctx, q)
+ }
+ return fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
+}
+
// The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
@@ -508,6 +536,8 @@ func NewQuery(q ent.Query) (Query, error) {
return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
case *ent.SettingQuery:
return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
+ case *ent.UsageCleanupTaskQuery:
+ return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
case *ent.UsageLogQuery:
return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
case *ent.UserQuery:
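Analogously on the read path, a traverser observes every cleanup-task query before it executes. A hedged sketch (imports assumed as above):

    client.UsageCleanupTask.Intercept(
        intercept.TraverseUsageCleanupTask(func(ctx context.Context, q *ent.UsageCleanupTaskQuery) error {
            log.Println("traversing a UsageCleanupTask query")
            return nil // a non-nil error would abort the query
        }),
    )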
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index b377804f..d1f05186 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -434,6 +434,44 @@ var (
Columns: SettingsColumns,
PrimaryKey: []*schema.Column{SettingsColumns[0]},
}
+ // UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
+ UsageCleanupTasksColumns = []*schema.Column{
+ {Name: "id", Type: field.TypeInt64, Increment: true},
+ {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+ {Name: "status", Type: field.TypeString, Size: 20},
+ {Name: "filters", Type: field.TypeJSON},
+ {Name: "created_by", Type: field.TypeInt64},
+ {Name: "deleted_rows", Type: field.TypeInt64, Default: 0},
+ {Name: "error_message", Type: field.TypeString, Nullable: true},
+ {Name: "canceled_by", Type: field.TypeInt64, Nullable: true},
+ {Name: "canceled_at", Type: field.TypeTime, Nullable: true},
+ {Name: "started_at", Type: field.TypeTime, Nullable: true},
+ {Name: "finished_at", Type: field.TypeTime, Nullable: true},
+ }
+ // UsageCleanupTasksTable holds the schema information for the "usage_cleanup_tasks" table.
+ UsageCleanupTasksTable = &schema.Table{
+ Name: "usage_cleanup_tasks",
+ Columns: UsageCleanupTasksColumns,
+ PrimaryKey: []*schema.Column{UsageCleanupTasksColumns[0]},
+ Indexes: []*schema.Index{
+ {
+ Name: "usagecleanuptask_status_created_at",
+ Unique: false,
+ Columns: []*schema.Column{UsageCleanupTasksColumns[3], UsageCleanupTasksColumns[1]},
+ },
+ {
+ Name: "usagecleanuptask_created_at",
+ Unique: false,
+ Columns: []*schema.Column{UsageCleanupTasksColumns[1]},
+ },
+ {
+ Name: "usagecleanuptask_canceled_at",
+ Unique: false,
+ Columns: []*schema.Column{UsageCleanupTasksColumns[9]},
+ },
+ },
+ }
// UsageLogsColumns holds the columns for the "usage_logs" table.
UsageLogsColumns = []*schema.Column{
{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -805,6 +843,7 @@ var (
ProxiesTable,
RedeemCodesTable,
SettingsTable,
+ UsageCleanupTasksTable,
UsageLogsTable,
UsersTable,
UserAllowedGroupsTable,
@@ -851,6 +890,9 @@ func init() {
SettingsTable.Annotation = &entsql.Annotation{
Table: "settings",
}
+ UsageCleanupTasksTable.Annotation = &entsql.Annotation{
+ Table: "usage_cleanup_tasks",
+ }
UsageLogsTable.ForeignKeys[0].RefTable = APIKeysTable
UsageLogsTable.ForeignKeys[1].RefTable = AccountsTable
UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable
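For reference, the table definition above should materialize on PostgreSQL as roughly the following DDL; this is a sketch, since ent/Atlas controls the exact rendering (e.g. whether filters becomes jsonb):

    // Hedged reconstruction of the generated migration, not authoritative.
    const usageCleanupTasksDDL = `
    CREATE TABLE usage_cleanup_tasks (
        id            BIGSERIAL PRIMARY KEY,
        created_at    TIMESTAMPTZ NOT NULL,
        updated_at    TIMESTAMPTZ NOT NULL,
        status        VARCHAR(20) NOT NULL,
        filters       JSONB NOT NULL,
        created_by    BIGINT NOT NULL,
        deleted_rows  BIGINT NOT NULL DEFAULT 0,
        error_message TEXT NULL,
        canceled_by   BIGINT NULL,
        canceled_at   TIMESTAMPTZ NULL,
        started_at    TIMESTAMPTZ NULL,
        finished_at   TIMESTAMPTZ NULL
    );
    CREATE INDEX usagecleanuptask_status_created_at ON usage_cleanup_tasks (status, created_at);
    CREATE INDEX usagecleanuptask_created_at ON usage_cleanup_tasks (created_at);
    CREATE INDEX usagecleanuptask_canceled_at ON usage_cleanup_tasks (canceled_at);`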
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index cd2fe8e0..9b330616 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -4,6 +4,7 @@ package ent
import (
"context"
+ "encoding/json"
"errors"
"fmt"
"sync"
@@ -21,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/proxy"
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -47,6 +49,7 @@ const (
TypeProxy = "Proxy"
TypeRedeemCode = "RedeemCode"
TypeSetting = "Setting"
+ TypeUsageCleanupTask = "UsageCleanupTask"
TypeUsageLog = "UsageLog"
TypeUser = "User"
TypeUserAllowedGroup = "UserAllowedGroup"
@@ -10370,6 +10373,1089 @@ func (m *SettingMutation) ResetEdge(name string) error {
return fmt.Errorf("unknown Setting edge %s", name)
}
+// UsageCleanupTaskMutation represents an operation that mutates the UsageCleanupTask nodes in the graph.
+type UsageCleanupTaskMutation struct {
+ config
+ op Op
+ typ string
+ id *int64
+ created_at *time.Time
+ updated_at *time.Time
+ status *string
+ filters *json.RawMessage
+ appendfilters json.RawMessage
+ created_by *int64
+ addcreated_by *int64
+ deleted_rows *int64
+ adddeleted_rows *int64
+ error_message *string
+ canceled_by *int64
+ addcanceled_by *int64
+ canceled_at *time.Time
+ started_at *time.Time
+ finished_at *time.Time
+ clearedFields map[string]struct{}
+ done bool
+ oldValue func(context.Context) (*UsageCleanupTask, error)
+ predicates []predicate.UsageCleanupTask
+}
+
+var _ ent.Mutation = (*UsageCleanupTaskMutation)(nil)
+
+// usagecleanuptaskOption allows management of the mutation configuration using functional options.
+type usagecleanuptaskOption func(*UsageCleanupTaskMutation)
+
+// newUsageCleanupTaskMutation creates new mutation for the UsageCleanupTask entity.
+func newUsageCleanupTaskMutation(c config, op Op, opts ...usagecleanuptaskOption) *UsageCleanupTaskMutation {
+ m := &UsageCleanupTaskMutation{
+ config: c,
+ op: op,
+ typ: TypeUsageCleanupTask,
+ clearedFields: make(map[string]struct{}),
+ }
+ for _, opt := range opts {
+ opt(m)
+ }
+ return m
+}
+
+// withUsageCleanupTaskID sets the ID field of the mutation.
+func withUsageCleanupTaskID(id int64) usagecleanuptaskOption {
+ return func(m *UsageCleanupTaskMutation) {
+ var (
+ err error
+ once sync.Once
+ value *UsageCleanupTask
+ )
+ m.oldValue = func(ctx context.Context) (*UsageCleanupTask, error) {
+ once.Do(func() {
+ if m.done {
+ err = errors.New("querying old values post mutation is not allowed")
+ } else {
+ value, err = m.Client().UsageCleanupTask.Get(ctx, id)
+ }
+ })
+ return value, err
+ }
+ m.id = &id
+ }
+}
+
+// withUsageCleanupTask sets the old UsageCleanupTask of the mutation.
+func withUsageCleanupTask(node *UsageCleanupTask) usagecleanuptaskOption {
+ return func(m *UsageCleanupTaskMutation) {
+ m.oldValue = func(context.Context) (*UsageCleanupTask, error) {
+ return node, nil
+ }
+ m.id = &node.ID
+ }
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m UsageCleanupTaskMutation) Client() *Client {
+ client := &Client{config: m.config}
+ client.init()
+ return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m UsageCleanupTaskMutation) Tx() (*Tx, error) {
+ if _, ok := m.driver.(*txDriver); !ok {
+ return nil, errors.New("ent: mutation is not running in a transaction")
+ }
+ tx := &Tx{config: m.config}
+ tx.init()
+ return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *UsageCleanupTaskMutation) ID() (id int64, exists bool) {
+ if m.id == nil {
+ return
+ }
+ return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or deleted by the mutation.
+func (m *UsageCleanupTaskMutation) IDs(ctx context.Context) ([]int64, error) {
+ switch {
+ case m.op.Is(OpUpdateOne | OpDeleteOne):
+ id, exists := m.ID()
+ if exists {
+ return []int64{id}, nil
+ }
+ fallthrough
+ case m.op.Is(OpUpdate | OpDelete):
+ return m.Client().UsageCleanupTask.Query().Where(m.predicates...).IDs(ctx)
+ default:
+ return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+ }
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *UsageCleanupTaskMutation) SetCreatedAt(t time.Time) {
+ m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) CreatedAt() (r time.Time, exists bool) {
+ v := m.created_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+ }
+ return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *UsageCleanupTaskMutation) ResetCreatedAt() {
+ m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *UsageCleanupTaskMutation) SetUpdatedAt(t time.Time) {
+ m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) UpdatedAt() (r time.Time, exists bool) {
+ v := m.updated_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+ }
+ return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *UsageCleanupTaskMutation) ResetUpdatedAt() {
+ m.updated_at = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *UsageCleanupTaskMutation) SetStatus(s string) {
+ m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *UsageCleanupTaskMutation) Status() (r string, exists bool) {
+ v := m.status
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldStatus(ctx context.Context) (v string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStatus requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+ }
+ return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *UsageCleanupTaskMutation) ResetStatus() {
+ m.status = nil
+}
+
+// SetFilters sets the "filters" field.
+func (m *UsageCleanupTaskMutation) SetFilters(jm json.RawMessage) {
+ m.filters = &jm
+ m.appendfilters = nil
+}
+
+// Filters returns the value of the "filters" field in the mutation.
+func (m *UsageCleanupTaskMutation) Filters() (r json.RawMessage, exists bool) {
+ v := m.filters
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFilters returns the old "filters" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldFilters(ctx context.Context) (v json.RawMessage, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFilters is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFilters requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFilters: %w", err)
+ }
+ return oldValue.Filters, nil
+}
+
+// AppendFilters adds jm to the "filters" field.
+func (m *UsageCleanupTaskMutation) AppendFilters(jm json.RawMessage) {
+ m.appendfilters = append(m.appendfilters, jm...)
+}
+
+// AppendedFilters returns the list of values that were appended to the "filters" field in this mutation.
+func (m *UsageCleanupTaskMutation) AppendedFilters() (json.RawMessage, bool) {
+ if len(m.appendfilters) == 0 {
+ return nil, false
+ }
+ return m.appendfilters, true
+}
+
+// ResetFilters resets all changes to the "filters" field.
+func (m *UsageCleanupTaskMutation) ResetFilters() {
+ m.filters = nil
+ m.appendfilters = nil
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (m *UsageCleanupTaskMutation) SetCreatedBy(i int64) {
+ m.created_by = &i
+ m.addcreated_by = nil
+}
+
+// CreatedBy returns the value of the "created_by" field in the mutation.
+func (m *UsageCleanupTaskMutation) CreatedBy() (r int64, exists bool) {
+ v := m.created_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCreatedBy returns the old "created_by" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCreatedBy(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCreatedBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err)
+ }
+ return oldValue.CreatedBy, nil
+}
+
+// AddCreatedBy adds i to the "created_by" field.
+func (m *UsageCleanupTaskMutation) AddCreatedBy(i int64) {
+ if m.addcreated_by != nil {
+ *m.addcreated_by += i
+ } else {
+ m.addcreated_by = &i
+ }
+}
+
+// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedCreatedBy() (r int64, exists bool) {
+ v := m.addcreated_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetCreatedBy resets all changes to the "created_by" field.
+func (m *UsageCleanupTaskMutation) ResetCreatedBy() {
+ m.created_by = nil
+ m.addcreated_by = nil
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) SetDeletedRows(i int64) {
+ m.deleted_rows = &i
+ m.adddeleted_rows = nil
+}
+
+// DeletedRows returns the value of the "deleted_rows" field in the mutation.
+func (m *UsageCleanupTaskMutation) DeletedRows() (r int64, exists bool) {
+ v := m.deleted_rows
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldDeletedRows returns the old "deleted_rows" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldDeletedRows(ctx context.Context) (v int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldDeletedRows is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldDeletedRows requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldDeletedRows: %w", err)
+ }
+ return oldValue.DeletedRows, nil
+}
+
+// AddDeletedRows adds i to the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) AddDeletedRows(i int64) {
+ if m.adddeleted_rows != nil {
+ *m.adddeleted_rows += i
+ } else {
+ m.adddeleted_rows = &i
+ }
+}
+
+// AddedDeletedRows returns the value that was added to the "deleted_rows" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedDeletedRows() (r int64, exists bool) {
+ v := m.adddeleted_rows
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetDeletedRows resets all changes to the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) ResetDeletedRows() {
+ m.deleted_rows = nil
+ m.adddeleted_rows = nil
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (m *UsageCleanupTaskMutation) SetErrorMessage(s string) {
+ m.error_message = &s
+}
+
+// ErrorMessage returns the value of the "error_message" field in the mutation.
+func (m *UsageCleanupTaskMutation) ErrorMessage() (r string, exists bool) {
+ v := m.error_message
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldErrorMessage returns the old "error_message" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldErrorMessage(ctx context.Context) (v *string, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldErrorMessage is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldErrorMessage requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldErrorMessage: %w", err)
+ }
+ return oldValue.ErrorMessage, nil
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (m *UsageCleanupTaskMutation) ClearErrorMessage() {
+ m.error_message = nil
+ m.clearedFields[usagecleanuptask.FieldErrorMessage] = struct{}{}
+}
+
+// ErrorMessageCleared returns if the "error_message" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) ErrorMessageCleared() bool {
+ _, ok := m.clearedFields[usagecleanuptask.FieldErrorMessage]
+ return ok
+}
+
+// ResetErrorMessage resets all changes to the "error_message" field.
+func (m *UsageCleanupTaskMutation) ResetErrorMessage() {
+ m.error_message = nil
+ delete(m.clearedFields, usagecleanuptask.FieldErrorMessage)
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) SetCanceledBy(i int64) {
+ m.canceled_by = &i
+ m.addcanceled_by = nil
+}
+
+// CanceledBy returns the value of the "canceled_by" field in the mutation.
+func (m *UsageCleanupTaskMutation) CanceledBy() (r int64, exists bool) {
+ v := m.canceled_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCanceledBy returns the old "canceled_by" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCanceledBy(ctx context.Context) (v *int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCanceledBy is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCanceledBy requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCanceledBy: %w", err)
+ }
+ return oldValue.CanceledBy, nil
+}
+
+// AddCanceledBy adds i to the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) AddCanceledBy(i int64) {
+ if m.addcanceled_by != nil {
+ *m.addcanceled_by += i
+ } else {
+ m.addcanceled_by = &i
+ }
+}
+
+// AddedCanceledBy returns the value that was added to the "canceled_by" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedCanceledBy() (r int64, exists bool) {
+ v := m.addcanceled_by
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) ClearCanceledBy() {
+ m.canceled_by = nil
+ m.addcanceled_by = nil
+ m.clearedFields[usagecleanuptask.FieldCanceledBy] = struct{}{}
+}
+
+// CanceledByCleared returns if the "canceled_by" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) CanceledByCleared() bool {
+ _, ok := m.clearedFields[usagecleanuptask.FieldCanceledBy]
+ return ok
+}
+
+// ResetCanceledBy resets all changes to the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) ResetCanceledBy() {
+ m.canceled_by = nil
+ m.addcanceled_by = nil
+ delete(m.clearedFields, usagecleanuptask.FieldCanceledBy)
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) SetCanceledAt(t time.Time) {
+ m.canceled_at = &t
+}
+
+// CanceledAt returns the value of the "canceled_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) CanceledAt() (r time.Time, exists bool) {
+ v := m.canceled_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldCanceledAt returns the old "canceled_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCanceledAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldCanceledAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldCanceledAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldCanceledAt: %w", err)
+ }
+ return oldValue.CanceledAt, nil
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) ClearCanceledAt() {
+ m.canceled_at = nil
+ m.clearedFields[usagecleanuptask.FieldCanceledAt] = struct{}{}
+}
+
+// CanceledAtCleared returns if the "canceled_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) CanceledAtCleared() bool {
+ _, ok := m.clearedFields[usagecleanuptask.FieldCanceledAt]
+ return ok
+}
+
+// ResetCanceledAt resets all changes to the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) ResetCanceledAt() {
+ m.canceled_at = nil
+ delete(m.clearedFields, usagecleanuptask.FieldCanceledAt)
+}
+
+// SetStartedAt sets the "started_at" field.
+func (m *UsageCleanupTaskMutation) SetStartedAt(t time.Time) {
+ m.started_at = &t
+}
+
+// StartedAt returns the value of the "started_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) StartedAt() (r time.Time, exists bool) {
+ v := m.started_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldStartedAt returns the old "started_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldStartedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldStartedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldStartedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldStartedAt: %w", err)
+ }
+ return oldValue.StartedAt, nil
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (m *UsageCleanupTaskMutation) ClearStartedAt() {
+ m.started_at = nil
+ m.clearedFields[usagecleanuptask.FieldStartedAt] = struct{}{}
+}
+
+// StartedAtCleared returns if the "started_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) StartedAtCleared() bool {
+ _, ok := m.clearedFields[usagecleanuptask.FieldStartedAt]
+ return ok
+}
+
+// ResetStartedAt resets all changes to the "started_at" field.
+func (m *UsageCleanupTaskMutation) ResetStartedAt() {
+ m.started_at = nil
+ delete(m.clearedFields, usagecleanuptask.FieldStartedAt)
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (m *UsageCleanupTaskMutation) SetFinishedAt(t time.Time) {
+ m.finished_at = &t
+}
+
+// FinishedAt returns the value of the "finished_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) FinishedAt() (r time.Time, exists bool) {
+ v := m.finished_at
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldFinishedAt returns the old "finished_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldFinishedAt(ctx context.Context) (v *time.Time, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldFinishedAt is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldFinishedAt requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldFinishedAt: %w", err)
+ }
+ return oldValue.FinishedAt, nil
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (m *UsageCleanupTaskMutation) ClearFinishedAt() {
+ m.finished_at = nil
+ m.clearedFields[usagecleanuptask.FieldFinishedAt] = struct{}{}
+}
+
+// FinishedAtCleared returns if the "finished_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) FinishedAtCleared() bool {
+ _, ok := m.clearedFields[usagecleanuptask.FieldFinishedAt]
+ return ok
+}
+
+// ResetFinishedAt resets all changes to the "finished_at" field.
+func (m *UsageCleanupTaskMutation) ResetFinishedAt() {
+ m.finished_at = nil
+ delete(m.clearedFields, usagecleanuptask.FieldFinishedAt)
+}
+
+// Where appends a list predicates to the UsageCleanupTaskMutation builder.
+func (m *UsageCleanupTaskMutation) Where(ps ...predicate.UsageCleanupTask) {
+ m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the UsageCleanupTaskMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *UsageCleanupTaskMutation) WhereP(ps ...func(*sql.Selector)) {
+ p := make([]predicate.UsageCleanupTask, len(ps))
+ for i := range ps {
+ p[i] = ps[i]
+ }
+ m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *UsageCleanupTaskMutation) Op() Op {
+ return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *UsageCleanupTaskMutation) SetOp(op Op) {
+ m.op = op
+}
+
+// Type returns the node type of this mutation (UsageCleanupTask).
+func (m *UsageCleanupTaskMutation) Type() string {
+ return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *UsageCleanupTaskMutation) Fields() []string {
+ fields := make([]string, 0, 11)
+ if m.created_at != nil {
+ fields = append(fields, usagecleanuptask.FieldCreatedAt)
+ }
+ if m.updated_at != nil {
+ fields = append(fields, usagecleanuptask.FieldUpdatedAt)
+ }
+ if m.status != nil {
+ fields = append(fields, usagecleanuptask.FieldStatus)
+ }
+ if m.filters != nil {
+ fields = append(fields, usagecleanuptask.FieldFilters)
+ }
+ if m.created_by != nil {
+ fields = append(fields, usagecleanuptask.FieldCreatedBy)
+ }
+ if m.deleted_rows != nil {
+ fields = append(fields, usagecleanuptask.FieldDeletedRows)
+ }
+ if m.error_message != nil {
+ fields = append(fields, usagecleanuptask.FieldErrorMessage)
+ }
+ if m.canceled_by != nil {
+ fields = append(fields, usagecleanuptask.FieldCanceledBy)
+ }
+ if m.canceled_at != nil {
+ fields = append(fields, usagecleanuptask.FieldCanceledAt)
+ }
+ if m.started_at != nil {
+ fields = append(fields, usagecleanuptask.FieldStartedAt)
+ }
+ if m.finished_at != nil {
+ fields = append(fields, usagecleanuptask.FieldFinishedAt)
+ }
+ return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *UsageCleanupTaskMutation) Field(name string) (ent.Value, bool) {
+ switch name {
+ case usagecleanuptask.FieldCreatedAt:
+ return m.CreatedAt()
+ case usagecleanuptask.FieldUpdatedAt:
+ return m.UpdatedAt()
+ case usagecleanuptask.FieldStatus:
+ return m.Status()
+ case usagecleanuptask.FieldFilters:
+ return m.Filters()
+ case usagecleanuptask.FieldCreatedBy:
+ return m.CreatedBy()
+ case usagecleanuptask.FieldDeletedRows:
+ return m.DeletedRows()
+ case usagecleanuptask.FieldErrorMessage:
+ return m.ErrorMessage()
+ case usagecleanuptask.FieldCanceledBy:
+ return m.CanceledBy()
+ case usagecleanuptask.FieldCanceledAt:
+ return m.CanceledAt()
+ case usagecleanuptask.FieldStartedAt:
+ return m.StartedAt()
+ case usagecleanuptask.FieldFinishedAt:
+ return m.FinishedAt()
+ }
+ return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *UsageCleanupTaskMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+ switch name {
+ case usagecleanuptask.FieldCreatedAt:
+ return m.OldCreatedAt(ctx)
+ case usagecleanuptask.FieldUpdatedAt:
+ return m.OldUpdatedAt(ctx)
+ case usagecleanuptask.FieldStatus:
+ return m.OldStatus(ctx)
+ case usagecleanuptask.FieldFilters:
+ return m.OldFilters(ctx)
+ case usagecleanuptask.FieldCreatedBy:
+ return m.OldCreatedBy(ctx)
+ case usagecleanuptask.FieldDeletedRows:
+ return m.OldDeletedRows(ctx)
+ case usagecleanuptask.FieldErrorMessage:
+ return m.OldErrorMessage(ctx)
+ case usagecleanuptask.FieldCanceledBy:
+ return m.OldCanceledBy(ctx)
+ case usagecleanuptask.FieldCanceledAt:
+ return m.OldCanceledAt(ctx)
+ case usagecleanuptask.FieldStartedAt:
+ return m.OldStartedAt(ctx)
+ case usagecleanuptask.FieldFinishedAt:
+ return m.OldFinishedAt(ctx)
+ }
+ return nil, fmt.Errorf("unknown UsageCleanupTask field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *UsageCleanupTaskMutation) SetField(name string, value ent.Value) error {
+ switch name {
+ case usagecleanuptask.FieldCreatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedAt(v)
+ return nil
+ case usagecleanuptask.FieldUpdatedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetUpdatedAt(v)
+ return nil
+ case usagecleanuptask.FieldStatus:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStatus(v)
+ return nil
+ case usagecleanuptask.FieldFilters:
+ v, ok := value.(json.RawMessage)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFilters(v)
+ return nil
+ case usagecleanuptask.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCreatedBy(v)
+ return nil
+ case usagecleanuptask.FieldDeletedRows:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetDeletedRows(v)
+ return nil
+ case usagecleanuptask.FieldErrorMessage:
+ v, ok := value.(string)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetErrorMessage(v)
+ return nil
+ case usagecleanuptask.FieldCanceledBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCanceledBy(v)
+ return nil
+ case usagecleanuptask.FieldCanceledAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetCanceledAt(v)
+ return nil
+ case usagecleanuptask.FieldStartedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetStartedAt(v)
+ return nil
+ case usagecleanuptask.FieldFinishedAt:
+ v, ok := value.(time.Time)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetFinishedAt(v)
+ return nil
+ }
+ return fmt.Errorf("unknown UsageCleanupTask field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *UsageCleanupTaskMutation) AddedFields() []string {
+ var fields []string
+ if m.addcreated_by != nil {
+ fields = append(fields, usagecleanuptask.FieldCreatedBy)
+ }
+ if m.adddeleted_rows != nil {
+ fields = append(fields, usagecleanuptask.FieldDeletedRows)
+ }
+ if m.addcanceled_by != nil {
+ fields = append(fields, usagecleanuptask.FieldCanceledBy)
+ }
+ return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *UsageCleanupTaskMutation) AddedField(name string) (ent.Value, bool) {
+ switch name {
+ case usagecleanuptask.FieldCreatedBy:
+ return m.AddedCreatedBy()
+ case usagecleanuptask.FieldDeletedRows:
+ return m.AddedDeletedRows()
+ case usagecleanuptask.FieldCanceledBy:
+ return m.AddedCanceledBy()
+ }
+ return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *UsageCleanupTaskMutation) AddField(name string, value ent.Value) error {
+ switch name {
+ case usagecleanuptask.FieldCreatedBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCreatedBy(v)
+ return nil
+ case usagecleanuptask.FieldDeletedRows:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddDeletedRows(v)
+ return nil
+ case usagecleanuptask.FieldCanceledBy:
+ v, ok := value.(int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddCanceledBy(v)
+ return nil
+ }
+ return fmt.Errorf("unknown UsageCleanupTask numeric field %s", name)
+}
+
+// ClearedFields returns all nullable fields that were cleared during this
+// mutation.
+func (m *UsageCleanupTaskMutation) ClearedFields() []string {
+ var fields []string
+ if m.FieldCleared(usagecleanuptask.FieldErrorMessage) {
+ fields = append(fields, usagecleanuptask.FieldErrorMessage)
+ }
+ if m.FieldCleared(usagecleanuptask.FieldCanceledBy) {
+ fields = append(fields, usagecleanuptask.FieldCanceledBy)
+ }
+ if m.FieldCleared(usagecleanuptask.FieldCanceledAt) {
+ fields = append(fields, usagecleanuptask.FieldCanceledAt)
+ }
+ if m.FieldCleared(usagecleanuptask.FieldStartedAt) {
+ fields = append(fields, usagecleanuptask.FieldStartedAt)
+ }
+ if m.FieldCleared(usagecleanuptask.FieldFinishedAt) {
+ fields = append(fields, usagecleanuptask.FieldFinishedAt)
+ }
+ return fields
+}
+
+// FieldCleared returns a boolean indicating if a field with the given name was
+// cleared in this mutation.
+func (m *UsageCleanupTaskMutation) FieldCleared(name string) bool {
+ _, ok := m.clearedFields[name]
+ return ok
+}
+
+// ClearField clears the value of the field with the given name. It returns an
+// error if the field is not defined in the schema.
+func (m *UsageCleanupTaskMutation) ClearField(name string) error {
+ switch name {
+ case usagecleanuptask.FieldErrorMessage:
+ m.ClearErrorMessage()
+ return nil
+ case usagecleanuptask.FieldCanceledBy:
+ m.ClearCanceledBy()
+ return nil
+ case usagecleanuptask.FieldCanceledAt:
+ m.ClearCanceledAt()
+ return nil
+ case usagecleanuptask.FieldStartedAt:
+ m.ClearStartedAt()
+ return nil
+ case usagecleanuptask.FieldFinishedAt:
+ m.ClearFinishedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown UsageCleanupTask nullable field %s", name)
+}
+
+// ResetField resets all changes in the mutation for the field with the given name.
+// It returns an error if the field is not defined in the schema.
+func (m *UsageCleanupTaskMutation) ResetField(name string) error {
+ switch name {
+ case usagecleanuptask.FieldCreatedAt:
+ m.ResetCreatedAt()
+ return nil
+ case usagecleanuptask.FieldUpdatedAt:
+ m.ResetUpdatedAt()
+ return nil
+ case usagecleanuptask.FieldStatus:
+ m.ResetStatus()
+ return nil
+ case usagecleanuptask.FieldFilters:
+ m.ResetFilters()
+ return nil
+ case usagecleanuptask.FieldCreatedBy:
+ m.ResetCreatedBy()
+ return nil
+ case usagecleanuptask.FieldDeletedRows:
+ m.ResetDeletedRows()
+ return nil
+ case usagecleanuptask.FieldErrorMessage:
+ m.ResetErrorMessage()
+ return nil
+ case usagecleanuptask.FieldCanceledBy:
+ m.ResetCanceledBy()
+ return nil
+ case usagecleanuptask.FieldCanceledAt:
+ m.ResetCanceledAt()
+ return nil
+ case usagecleanuptask.FieldStartedAt:
+ m.ResetStartedAt()
+ return nil
+ case usagecleanuptask.FieldFinishedAt:
+ m.ResetFinishedAt()
+ return nil
+ }
+ return fmt.Errorf("unknown UsageCleanupTask field %s", name)
+}
+
+// AddedEdges returns all edge names that were set/added in this mutation.
+func (m *UsageCleanupTaskMutation) AddedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// AddedIDs returns all IDs (to other nodes) that were added for the given edge
+// name in this mutation.
+func (m *UsageCleanupTaskMutation) AddedIDs(name string) []ent.Value {
+ return nil
+}
+
+// RemovedEdges returns all edge names that were removed in this mutation.
+func (m *UsageCleanupTaskMutation) RemovedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with
+// the given name in this mutation.
+func (m *UsageCleanupTaskMutation) RemovedIDs(name string) []ent.Value {
+ return nil
+}
+
+// ClearedEdges returns all edge names that were cleared in this mutation.
+func (m *UsageCleanupTaskMutation) ClearedEdges() []string {
+ edges := make([]string, 0, 0)
+ return edges
+}
+
+// EdgeCleared returns a boolean which indicates if the edge with the given name
+// was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) EdgeCleared(name string) bool {
+ return false
+}
+
+// ClearEdge clears the value of the edge with the given name. It returns an error
+// if that edge is not defined in the schema.
+func (m *UsageCleanupTaskMutation) ClearEdge(name string) error {
+ return fmt.Errorf("unknown UsageCleanupTask unique edge %s", name)
+}
+
+// ResetEdge resets all changes to the edge with the given name in this mutation.
+// It returns an error if the edge is not defined in the schema.
+func (m *UsageCleanupTaskMutation) ResetEdge(name string) error {
+ return fmt.Errorf("unknown UsageCleanupTask edge %s", name)
+}
+
// UsageLogMutation represents an operation that mutates the UsageLog nodes in the graph.
type UsageLogMutation struct {
config
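Most of the mutation type above is boilerplate for ent's generic field API (Fields, Field, SetField, AddedFields, and so on), which is what lets hooks and interceptors inspect changes without per-type switches. A hedged sketch of the pattern, usable inside a generic hook where m is any ent.Mutation:

    if v, ok := m.Field(usagecleanuptask.FieldStatus); ok {
        log.Printf("cleanup task status changing to %v", v)
    }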
diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go
index 7a443c5d..785cb4e6 100644
--- a/backend/ent/predicate/predicate.go
+++ b/backend/ent/predicate/predicate.go
@@ -33,6 +33,9 @@ type RedeemCode func(*sql.Selector)
// Setting is the predicate function for setting builders.
type Setting func(*sql.Selector)
+// UsageCleanupTask is the predicate function for usagecleanuptask builders.
+type UsageCleanupTask func(*sql.Selector)
+
// UsageLog is the predicate function for usagelog builders.
type UsageLog func(*sql.Selector)
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index 0cb10775..1e3f4cbe 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -15,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
"github.com/Wei-Shaw/sub2api/ent/schema"
"github.com/Wei-Shaw/sub2api/ent/setting"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/ent/usagelog"
"github.com/Wei-Shaw/sub2api/ent/user"
"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -495,6 +496,43 @@ func init() {
setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time)
// setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time)
+ usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin()
+ usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields()
+ _ = usagecleanuptaskMixinFields0
+ usagecleanuptaskFields := schema.UsageCleanupTask{}.Fields()
+ _ = usagecleanuptaskFields
+ // usagecleanuptaskDescCreatedAt is the schema descriptor for created_at field.
+ usagecleanuptaskDescCreatedAt := usagecleanuptaskMixinFields0[0].Descriptor()
+ // usagecleanuptask.DefaultCreatedAt holds the default value on creation for the created_at field.
+ usagecleanuptask.DefaultCreatedAt = usagecleanuptaskDescCreatedAt.Default.(func() time.Time)
+ // usagecleanuptaskDescUpdatedAt is the schema descriptor for updated_at field.
+ usagecleanuptaskDescUpdatedAt := usagecleanuptaskMixinFields0[1].Descriptor()
+ // usagecleanuptask.DefaultUpdatedAt holds the default value on creation for the updated_at field.
+ usagecleanuptask.DefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.Default.(func() time.Time)
+ // usagecleanuptask.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field.
+ usagecleanuptask.UpdateDefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.UpdateDefault.(func() time.Time)
+ // usagecleanuptaskDescStatus is the schema descriptor for status field.
+ usagecleanuptaskDescStatus := usagecleanuptaskFields[0].Descriptor()
+ // usagecleanuptask.StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ usagecleanuptask.StatusValidator = func() func(string) error {
+ validators := usagecleanuptaskDescStatus.Validators
+ fns := [...]func(string) error{
+ validators[0].(func(string) error),
+ validators[1].(func(string) error),
+ }
+ return func(status string) error {
+ for _, fn := range fns {
+ if err := fn(status); err != nil {
+ return err
+ }
+ }
+ return nil
+ }
+ }()
+ // usagecleanuptaskDescDeletedRows is the schema descriptor for deleted_rows field.
+ usagecleanuptaskDescDeletedRows := usagecleanuptaskFields[3].Descriptor()
+ // usagecleanuptask.DefaultDeletedRows holds the default value on creation for the deleted_rows field.
+ usagecleanuptask.DefaultDeletedRows = usagecleanuptaskDescDeletedRows.Default.(int64)
usagelogFields := schema.UsageLog{}.Fields()
_ = usagelogFields
// usagelogDescRequestID is the schema descriptor for request_id field.
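
Reviewer note (not part of the generated file): the StatusValidator assembled above chains two func(string) error values in builder order, first the MaxLen(20) check contributed by the schema, then validateUsageCleanupStatus, and returns the first failure. A minimal sketch, assuming the generated packages compile as in this patch; note that the validator variable is only populated by this init, so the runtime package must be imported:

    package main

    import (
        "fmt"
        "strings"

        "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
        // Importing ent/runtime runs the init above, which assigns
        // usagecleanuptask.StatusValidator; without it the var is nil.
        _ "github.com/Wei-Shaw/sub2api/ent/runtime"
    )

    func main() {
        v := usagecleanuptask.StatusValidator
        fmt.Println(v("pending"))               // <nil>: passes both validators
        fmt.Println(v("bogus"))                 // invalid usage cleanup status: bogus
        fmt.Println(v(strings.Repeat("x", 21))) // the MaxLen(20) validator fails first
    }
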
diff --git a/backend/ent/schema/mixins/soft_delete.go b/backend/ent/schema/mixins/soft_delete.go
index 9571bc9c..461c7348 100644
--- a/backend/ent/schema/mixins/soft_delete.go
+++ b/backend/ent/schema/mixins/soft_delete.go
@@ -12,7 +12,6 @@ import (
"entgo.io/ent/dialect/sql"
"entgo.io/ent/schema/field"
"entgo.io/ent/schema/mixin"
- dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/intercept"
)
@@ -113,7 +112,9 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
SetOp(ent.Op)
SetDeletedAt(time.Time)
WhereP(...func(*sql.Selector))
- Client() *dbent.Client
+ Client() interface {
+ Mutate(context.Context, ent.Mutation) (ent.Value, error)
+ }
})
if !ok {
return nil, fmt.Errorf("unexpected mutation type %T", m)
diff --git a/backend/ent/schema/usage_cleanup_task.go b/backend/ent/schema/usage_cleanup_task.go
new file mode 100644
index 00000000..753e6410
--- /dev/null
+++ b/backend/ent/schema/usage_cleanup_task.go
@@ -0,0 +1,75 @@
+package schema
+
+import (
+ "encoding/json"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/entsql"
+ "entgo.io/ent/schema"
+ "entgo.io/ent/schema/field"
+ "entgo.io/ent/schema/index"
+)
+
+// UsageCleanupTask defines the schema for usage-record cleanup tasks.
+type UsageCleanupTask struct {
+ ent.Schema
+}
+
+func (UsageCleanupTask) Annotations() []schema.Annotation {
+ return []schema.Annotation{
+ entsql.Annotation{Table: "usage_cleanup_tasks"},
+ }
+}
+
+func (UsageCleanupTask) Mixin() []ent.Mixin {
+ return []ent.Mixin{
+ mixins.TimeMixin{},
+ }
+}
+
+func (UsageCleanupTask) Fields() []ent.Field {
+ return []ent.Field{
+ field.String("status").
+ MaxLen(20).
+ Validate(validateUsageCleanupStatus),
+ field.JSON("filters", json.RawMessage{}),
+ field.Int64("created_by"),
+ field.Int64("deleted_rows").
+ Default(0),
+ field.String("error_message").
+ Optional().
+ Nillable(),
+ field.Int64("canceled_by").
+ Optional().
+ Nillable(),
+ field.Time("canceled_at").
+ Optional().
+ Nillable(),
+ field.Time("started_at").
+ Optional().
+ Nillable(),
+ field.Time("finished_at").
+ Optional().
+ Nillable(),
+ }
+}
+
+func (UsageCleanupTask) Indexes() []ent.Index {
+ return []ent.Index{
+ index.Fields("status", "created_at"),
+ index.Fields("created_at"),
+ index.Fields("canceled_at"),
+ }
+}
+
+func validateUsageCleanupStatus(status string) error {
+ switch status {
+ case "pending", "running", "succeeded", "failed", "canceled":
+ return nil
+ default:
+ return fmt.Errorf("invalid usage cleanup status: %s", status)
+ }
+}
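
For reviewers, a sketch of how the schema above surfaces through the generated builders (all setters appear in usagecleanuptask_create.go later in this patch; the filters payload shape and the helper below are assumptions, not part of the patch, and the UsageCleanupTask field on *ent.Client mirrors the Tx wiring shown above):

    package tasks // illustrative package

    import (
        "context"
        "encoding/json"

        "github.com/Wei-Shaw/sub2api/ent"
    )

    func createCleanupTask(ctx context.Context, client *ent.Client, adminID int64) (*ent.UsageCleanupTask, error) {
        return client.UsageCleanupTask.
            Create().
            SetStatus("pending"). // checked by validateUsageCleanupStatus before save
            SetFilters(json.RawMessage(`{"before":"2026-01-01T00:00:00Z"}`)). // hypothetical filter shape
            SetCreatedBy(adminID).
            Save(ctx) // deleted_rows defaults to 0; timestamps come from TimeMixin
    }
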
diff --git a/backend/ent/tx.go b/backend/ent/tx.go
index 56df121a..7ff16ec8 100644
--- a/backend/ent/tx.go
+++ b/backend/ent/tx.go
@@ -32,6 +32,8 @@ type Tx struct {
RedeemCode *RedeemCodeClient
// Setting is the client for interacting with the Setting builders.
Setting *SettingClient
+ // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
+ UsageCleanupTask *UsageCleanupTaskClient
// UsageLog is the client for interacting with the UsageLog builders.
UsageLog *UsageLogClient
// User is the client for interacting with the User builders.
@@ -184,6 +186,7 @@ func (tx *Tx) init() {
tx.Proxy = NewProxyClient(tx.config)
tx.RedeemCode = NewRedeemCodeClient(tx.config)
tx.Setting = NewSettingClient(tx.config)
+ tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config)
tx.UsageLog = NewUsageLogClient(tx.config)
tx.User = NewUserClient(tx.config)
tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config)
diff --git a/backend/ent/usagecleanuptask.go b/backend/ent/usagecleanuptask.go
new file mode 100644
index 00000000..e3a17b5a
--- /dev/null
+++ b/backend/ent/usagecleanuptask.go
@@ -0,0 +1,236 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+ "time"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTask is the model entity for the UsageCleanupTask schema.
+type UsageCleanupTask struct {
+ config `json:"-"`
+ // ID of the ent.
+ ID int64 `json:"id,omitempty"`
+ // CreatedAt holds the value of the "created_at" field.
+ CreatedAt time.Time `json:"created_at,omitempty"`
+ // UpdatedAt holds the value of the "updated_at" field.
+ UpdatedAt time.Time `json:"updated_at,omitempty"`
+ // Status holds the value of the "status" field.
+ Status string `json:"status,omitempty"`
+ // Filters holds the value of the "filters" field.
+ Filters json.RawMessage `json:"filters,omitempty"`
+ // CreatedBy holds the value of the "created_by" field.
+ CreatedBy int64 `json:"created_by,omitempty"`
+ // DeletedRows holds the value of the "deleted_rows" field.
+ DeletedRows int64 `json:"deleted_rows,omitempty"`
+ // ErrorMessage holds the value of the "error_message" field.
+ ErrorMessage *string `json:"error_message,omitempty"`
+ // CanceledBy holds the value of the "canceled_by" field.
+ CanceledBy *int64 `json:"canceled_by,omitempty"`
+ // CanceledAt holds the value of the "canceled_at" field.
+ CanceledAt *time.Time `json:"canceled_at,omitempty"`
+ // StartedAt holds the value of the "started_at" field.
+ StartedAt *time.Time `json:"started_at,omitempty"`
+ // FinishedAt holds the value of the "finished_at" field.
+ FinishedAt *time.Time `json:"finished_at,omitempty"`
+ selectValues sql.SelectValues
+}
+
+// scanValues returns the types for scanning values from sql.Rows.
+func (*UsageCleanupTask) scanValues(columns []string) ([]any, error) {
+ values := make([]any, len(columns))
+ for i := range columns {
+ switch columns[i] {
+ case usagecleanuptask.FieldFilters:
+ values[i] = new([]byte)
+ case usagecleanuptask.FieldID, usagecleanuptask.FieldCreatedBy, usagecleanuptask.FieldDeletedRows, usagecleanuptask.FieldCanceledBy:
+ values[i] = new(sql.NullInt64)
+ case usagecleanuptask.FieldStatus, usagecleanuptask.FieldErrorMessage:
+ values[i] = new(sql.NullString)
+ case usagecleanuptask.FieldCreatedAt, usagecleanuptask.FieldUpdatedAt, usagecleanuptask.FieldCanceledAt, usagecleanuptask.FieldStartedAt, usagecleanuptask.FieldFinishedAt:
+ values[i] = new(sql.NullTime)
+ default:
+ values[i] = new(sql.UnknownType)
+ }
+ }
+ return values, nil
+}
+
+// assignValues assigns the values that were returned from sql.Rows (after scanning)
+// to the UsageCleanupTask fields.
+func (_m *UsageCleanupTask) assignValues(columns []string, values []any) error {
+ if m, n := len(values), len(columns); m < n {
+ return fmt.Errorf("mismatch number of scan values: %d != %d", m, n)
+ }
+ for i := range columns {
+ switch columns[i] {
+ case usagecleanuptask.FieldID:
+ value, ok := values[i].(*sql.NullInt64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field id", value)
+ }
+ _m.ID = int64(value.Int64)
+ case usagecleanuptask.FieldCreatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field created_at", values[i])
+ } else if value.Valid {
+ _m.CreatedAt = value.Time
+ }
+ case usagecleanuptask.FieldUpdatedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field updated_at", values[i])
+ } else if value.Valid {
+ _m.UpdatedAt = value.Time
+ }
+ case usagecleanuptask.FieldStatus:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field status", values[i])
+ } else if value.Valid {
+ _m.Status = value.String
+ }
+ case usagecleanuptask.FieldFilters:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field filters", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.Filters); err != nil {
+ return fmt.Errorf("unmarshal field filters: %w", err)
+ }
+ }
+ case usagecleanuptask.FieldCreatedBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field created_by", values[i])
+ } else if value.Valid {
+ _m.CreatedBy = value.Int64
+ }
+ case usagecleanuptask.FieldDeletedRows:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field deleted_rows", values[i])
+ } else if value.Valid {
+ _m.DeletedRows = value.Int64
+ }
+ case usagecleanuptask.FieldErrorMessage:
+ if value, ok := values[i].(*sql.NullString); !ok {
+ return fmt.Errorf("unexpected type %T for field error_message", values[i])
+ } else if value.Valid {
+ _m.ErrorMessage = new(string)
+ *_m.ErrorMessage = value.String
+ }
+ case usagecleanuptask.FieldCanceledBy:
+ if value, ok := values[i].(*sql.NullInt64); !ok {
+ return fmt.Errorf("unexpected type %T for field canceled_by", values[i])
+ } else if value.Valid {
+ _m.CanceledBy = new(int64)
+ *_m.CanceledBy = value.Int64
+ }
+ case usagecleanuptask.FieldCanceledAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field canceled_at", values[i])
+ } else if value.Valid {
+ _m.CanceledAt = new(time.Time)
+ *_m.CanceledAt = value.Time
+ }
+ case usagecleanuptask.FieldStartedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field started_at", values[i])
+ } else if value.Valid {
+ _m.StartedAt = new(time.Time)
+ *_m.StartedAt = value.Time
+ }
+ case usagecleanuptask.FieldFinishedAt:
+ if value, ok := values[i].(*sql.NullTime); !ok {
+ return fmt.Errorf("unexpected type %T for field finished_at", values[i])
+ } else if value.Valid {
+ _m.FinishedAt = new(time.Time)
+ *_m.FinishedAt = value.Time
+ }
+ default:
+ _m.selectValues.Set(columns[i], values[i])
+ }
+ }
+ return nil
+}
+
+// Value returns the ent.Value that was dynamically selected and assigned to the UsageCleanupTask.
+// This includes values selected through modifiers, order, etc.
+func (_m *UsageCleanupTask) Value(name string) (ent.Value, error) {
+ return _m.selectValues.Get(name)
+}
+
+// Update returns a builder for updating this UsageCleanupTask.
+// Note that you need to call UsageCleanupTask.Unwrap() before calling this method if this UsageCleanupTask
+// was returned from a transaction, and the transaction was committed or rolled back.
+func (_m *UsageCleanupTask) Update() *UsageCleanupTaskUpdateOne {
+ return NewUsageCleanupTaskClient(_m.config).UpdateOne(_m)
+}
+
+// Unwrap unwraps the UsageCleanupTask entity that was returned from a transaction after it was closed,
+// so that all future queries will be executed through the driver which created the transaction.
+func (_m *UsageCleanupTask) Unwrap() *UsageCleanupTask {
+ _tx, ok := _m.config.driver.(*txDriver)
+ if !ok {
+ panic("ent: UsageCleanupTask is not a transactional entity")
+ }
+ _m.config.driver = _tx.drv
+ return _m
+}
+
+// String implements the fmt.Stringer.
+func (_m *UsageCleanupTask) String() string {
+ var builder strings.Builder
+ builder.WriteString("UsageCleanupTask(")
+ builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID))
+ builder.WriteString("created_at=")
+ builder.WriteString(_m.CreatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("updated_at=")
+ builder.WriteString(_m.UpdatedAt.Format(time.ANSIC))
+ builder.WriteString(", ")
+ builder.WriteString("status=")
+ builder.WriteString(_m.Status)
+ builder.WriteString(", ")
+ builder.WriteString("filters=")
+ builder.WriteString(fmt.Sprintf("%v", _m.Filters))
+ builder.WriteString(", ")
+ builder.WriteString("created_by=")
+ builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy))
+ builder.WriteString(", ")
+ builder.WriteString("deleted_rows=")
+ builder.WriteString(fmt.Sprintf("%v", _m.DeletedRows))
+ builder.WriteString(", ")
+ if v := _m.ErrorMessage; v != nil {
+ builder.WriteString("error_message=")
+ builder.WriteString(*v)
+ }
+ builder.WriteString(", ")
+ if v := _m.CanceledBy; v != nil {
+ builder.WriteString("canceled_by=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
+ if v := _m.CanceledAt; v != nil {
+ builder.WriteString("canceled_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.StartedAt; v != nil {
+ builder.WriteString("started_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteString(", ")
+ if v := _m.FinishedAt; v != nil {
+ builder.WriteString("finished_at=")
+ builder.WriteString(v.Format(time.ANSIC))
+ }
+ builder.WriteByte(')')
+ return builder.String()
+}
+
+// UsageCleanupTasks is a parsable slice of UsageCleanupTask.
+type UsageCleanupTasks []*UsageCleanupTask
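
One consequence of the Optional().Nillable() schema fields worth flagging: they land on the model as pointers (*string, *int64, *time.Time), so callers nil-check before dereferencing. A short sketch, assuming a loaded task t and the standard log/time packages:

    // t is a *ent.UsageCleanupTask returned from a query.
    if at := t.FinishedAt; at != nil {
        log.Printf("cleanup task %d finished at %s (%d rows deleted)",
            t.ID, at.Format(time.RFC3339), t.DeletedRows)
    }
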
diff --git a/backend/ent/usagecleanuptask/usagecleanuptask.go b/backend/ent/usagecleanuptask/usagecleanuptask.go
new file mode 100644
index 00000000..a8ddd9a0
--- /dev/null
+++ b/backend/ent/usagecleanuptask/usagecleanuptask.go
@@ -0,0 +1,137 @@
+// Code generated by ent, DO NOT EDIT.
+
+package usagecleanuptask
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+)
+
+const (
+ // Label holds the string label denoting the usagecleanuptask type in the database.
+ Label = "usage_cleanup_task"
+ // FieldID holds the string denoting the id field in the database.
+ FieldID = "id"
+ // FieldCreatedAt holds the string denoting the created_at field in the database.
+ FieldCreatedAt = "created_at"
+ // FieldUpdatedAt holds the string denoting the updated_at field in the database.
+ FieldUpdatedAt = "updated_at"
+ // FieldStatus holds the string denoting the status field in the database.
+ FieldStatus = "status"
+ // FieldFilters holds the string denoting the filters field in the database.
+ FieldFilters = "filters"
+ // FieldCreatedBy holds the string denoting the created_by field in the database.
+ FieldCreatedBy = "created_by"
+ // FieldDeletedRows holds the string denoting the deleted_rows field in the database.
+ FieldDeletedRows = "deleted_rows"
+ // FieldErrorMessage holds the string denoting the error_message field in the database.
+ FieldErrorMessage = "error_message"
+ // FieldCanceledBy holds the string denoting the canceled_by field in the database.
+ FieldCanceledBy = "canceled_by"
+ // FieldCanceledAt holds the string denoting the canceled_at field in the database.
+ FieldCanceledAt = "canceled_at"
+ // FieldStartedAt holds the string denoting the started_at field in the database.
+ FieldStartedAt = "started_at"
+ // FieldFinishedAt holds the string denoting the finished_at field in the database.
+ FieldFinishedAt = "finished_at"
+ // Table holds the table name of the usagecleanuptask in the database.
+ Table = "usage_cleanup_tasks"
+)
+
+// Columns holds all SQL columns for usagecleanuptask fields.
+var Columns = []string{
+ FieldID,
+ FieldCreatedAt,
+ FieldUpdatedAt,
+ FieldStatus,
+ FieldFilters,
+ FieldCreatedBy,
+ FieldDeletedRows,
+ FieldErrorMessage,
+ FieldCanceledBy,
+ FieldCanceledAt,
+ FieldStartedAt,
+ FieldFinishedAt,
+}
+
+// ValidColumn reports if the column name is valid (part of the table columns).
+func ValidColumn(column string) bool {
+ for i := range Columns {
+ if column == Columns[i] {
+ return true
+ }
+ }
+ return false
+}
+
+var (
+ // DefaultCreatedAt holds the default value on creation for the "created_at" field.
+ DefaultCreatedAt func() time.Time
+ // DefaultUpdatedAt holds the default value on creation for the "updated_at" field.
+ DefaultUpdatedAt func() time.Time
+ // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field.
+ UpdateDefaultUpdatedAt func() time.Time
+ // StatusValidator is a validator for the "status" field. It is called by the builders before save.
+ StatusValidator func(string) error
+ // DefaultDeletedRows holds the default value on creation for the "deleted_rows" field.
+ DefaultDeletedRows int64
+)
+
+// OrderOption defines the ordering options for the UsageCleanupTask queries.
+type OrderOption func(*sql.Selector)
+
+// ByID orders the results by the id field.
+func ByID(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldID, opts...).ToFunc()
+}
+
+// ByCreatedAt orders the results by the created_at field.
+func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedAt, opts...).ToFunc()
+}
+
+// ByUpdatedAt orders the results by the updated_at field.
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc()
+}
+
+// ByStatus orders the results by the status field.
+func ByStatus(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStatus, opts...).ToFunc()
+}
+
+// ByCreatedBy orders the results by the created_by field.
+func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCreatedBy, opts...).ToFunc()
+}
+
+// ByDeletedRows orders the results by the deleted_rows field.
+func ByDeletedRows(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldDeletedRows, opts...).ToFunc()
+}
+
+// ByErrorMessage orders the results by the error_message field.
+func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldErrorMessage, opts...).ToFunc()
+}
+
+// ByCanceledBy orders the results by the canceled_by field.
+func ByCanceledBy(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCanceledBy, opts...).ToFunc()
+}
+
+// ByCanceledAt orders the results by the canceled_at field.
+func ByCanceledAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldCanceledAt, opts...).ToFunc()
+}
+
+// ByStartedAt orders the results by the started_at field.
+func ByStartedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldStartedAt, opts...).ToFunc()
+}
+
+// ByFinishedAt orders the results by the finished_at field.
+func ByFinishedAt(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldFinishedAt, opts...).ToFunc()
+}
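
The OrderOption helpers above compose with the generated query builder (emitted in usagecleanuptask_query.go, which this excerpt does not include). A sketch of listing the newest tasks first, where sql is entgo.io/ent/dialect/sql and client/ctx are assumed from the surrounding service code:

    recent, err := client.UsageCleanupTask.
        Query().
        Order(usagecleanuptask.ByCreatedAt(sql.OrderDesc())).
        Limit(20).
        All(ctx)
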
diff --git a/backend/ent/usagecleanuptask/where.go b/backend/ent/usagecleanuptask/where.go
new file mode 100644
index 00000000..99e790ca
--- /dev/null
+++ b/backend/ent/usagecleanuptask/where.go
@@ -0,0 +1,620 @@
+// Code generated by ent, DO NOT EDIT.
+
+package usagecleanuptask
+
+import (
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+)
+
+// ID filters vertices based on their ID field.
+func ID(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
+}
+
+// IDEQ applies the EQ predicate on the ID field.
+func IDEQ(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id))
+}
+
+// IDNEQ applies the NEQ predicate on the ID field.
+func IDNEQ(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldID, id))
+}
+
+// IDIn applies the In predicate on the ID field.
+func IDIn(ids ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldID, ids...))
+}
+
+// IDNotIn applies the NotIn predicate on the ID field.
+func IDNotIn(ids ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldID, ids...))
+}
+
+// IDGT applies the GT predicate on the ID field.
+func IDGT(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldID, id))
+}
+
+// IDGTE applies the GTE predicate on the ID field.
+func IDGTE(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldID, id))
+}
+
+// IDLT applies the LT predicate on the ID field.
+func IDLT(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldID, id))
+}
+
+// IDLTE applies the LTE predicate on the ID field.
+func IDLTE(id int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldID, id))
+}
+
+// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ.
+func CreatedAt(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ.
+func UpdatedAt(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
+func Status(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
+}
+
+// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ.
+func CreatedBy(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// DeletedRows applies equality check predicate on the "deleted_rows" field. It's identical to DeletedRowsEQ.
+func DeletedRows(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
+}
+
+// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ.
+func ErrorMessage(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
+}
+
+// CanceledBy applies equality check predicate on the "canceled_by" field. It's identical to CanceledByEQ.
+func CanceledBy(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
+}
+
+// CanceledAt applies equality check predicate on the "canceled_at" field. It's identical to CanceledAtEQ.
+func CanceledAt(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
+}
+
+// StartedAt applies equality check predicate on the "started_at" field. It's identical to StartedAtEQ.
+func StartedAt(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
+}
+
+// FinishedAt applies equality check predicate on the "finished_at" field. It's identical to FinishedAtEQ.
+func FinishedAt(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
+}
+
+// CreatedAtEQ applies the EQ predicate on the "created_at" field.
+func CreatedAtEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtNEQ applies the NEQ predicate on the "created_at" field.
+func CreatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedAt, v))
+}
+
+// CreatedAtIn applies the In predicate on the "created_at" field.
+func CreatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtNotIn applies the NotIn predicate on the "created_at" field.
+func CreatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedAt, vs...))
+}
+
+// CreatedAtGT applies the GT predicate on the "created_at" field.
+func CreatedAtGT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedAt, v))
+}
+
+// CreatedAtGTE applies the GTE predicate on the "created_at" field.
+func CreatedAtGTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedAt, v))
+}
+
+// CreatedAtLT applies the LT predicate on the "created_at" field.
+func CreatedAtLT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedAt, v))
+}
+
+// CreatedAtLTE applies the LTE predicate on the "created_at" field.
+func CreatedAtLTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedAt, v))
+}
+
+// UpdatedAtEQ applies the EQ predicate on the "updated_at" field.
+func UpdatedAtEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field.
+func UpdatedAtNEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldUpdatedAt, v))
+}
+
+// UpdatedAtIn applies the In predicate on the "updated_at" field.
+func UpdatedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field.
+func UpdatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldUpdatedAt, vs...))
+}
+
+// UpdatedAtGT applies the GT predicate on the "updated_at" field.
+func UpdatedAtGT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtGTE applies the GTE predicate on the "updated_at" field.
+func UpdatedAtGTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLT applies the LT predicate on the "updated_at" field.
+func UpdatedAtLT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldUpdatedAt, v))
+}
+
+// UpdatedAtLTE applies the LTE predicate on the "updated_at" field.
+func UpdatedAtLTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldUpdatedAt, v))
+}
+
+// StatusEQ applies the EQ predicate on the "status" field.
+func StatusEQ(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v))
+}
+
+// StatusNEQ applies the NEQ predicate on the "status" field.
+func StatusNEQ(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStatus, v))
+}
+
+// StatusIn applies the In predicate on the "status" field.
+func StatusIn(vs ...string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldStatus, vs...))
+}
+
+// StatusNotIn applies the NotIn predicate on the "status" field.
+func StatusNotIn(vs ...string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStatus, vs...))
+}
+
+// StatusGT applies the GT predicate on the "status" field.
+func StatusGT(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldStatus, v))
+}
+
+// StatusGTE applies the GTE predicate on the "status" field.
+func StatusGTE(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldStatus, v))
+}
+
+// StatusLT applies the LT predicate on the "status" field.
+func StatusLT(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldStatus, v))
+}
+
+// StatusLTE applies the LTE predicate on the "status" field.
+func StatusLTE(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldStatus, v))
+}
+
+// StatusContains applies the Contains predicate on the "status" field.
+func StatusContains(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldContains(FieldStatus, v))
+}
+
+// StatusHasPrefix applies the HasPrefix predicate on the "status" field.
+func StatusHasPrefix(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldStatus, v))
+}
+
+// StatusHasSuffix applies the HasSuffix predicate on the "status" field.
+func StatusHasSuffix(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldStatus, v))
+}
+
+// StatusEqualFold applies the EqualFold predicate on the "status" field.
+func StatusEqualFold(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldStatus, v))
+}
+
+// StatusContainsFold applies the ContainsFold predicate on the "status" field.
+func StatusContainsFold(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldStatus, v))
+}
+
+// CreatedByEQ applies the EQ predicate on the "created_by" field.
+func CreatedByEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v))
+}
+
+// CreatedByNEQ applies the NEQ predicate on the "created_by" field.
+func CreatedByNEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedBy, v))
+}
+
+// CreatedByIn applies the In predicate on the "created_by" field.
+func CreatedByIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByNotIn applies the NotIn predicate on the "created_by" field.
+func CreatedByNotIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedBy, vs...))
+}
+
+// CreatedByGT applies the GT predicate on the "created_by" field.
+func CreatedByGT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedBy, v))
+}
+
+// CreatedByGTE applies the GTE predicate on the "created_by" field.
+func CreatedByGTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedBy, v))
+}
+
+// CreatedByLT applies the LT predicate on the "created_by" field.
+func CreatedByLT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedBy, v))
+}
+
+// CreatedByLTE applies the LTE predicate on the "created_by" field.
+func CreatedByLTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedBy, v))
+}
+
+// DeletedRowsEQ applies the EQ predicate on the "deleted_rows" field.
+func DeletedRowsEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v))
+}
+
+// DeletedRowsNEQ applies the NEQ predicate on the "deleted_rows" field.
+func DeletedRowsNEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldDeletedRows, v))
+}
+
+// DeletedRowsIn applies the In predicate on the "deleted_rows" field.
+func DeletedRowsIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldDeletedRows, vs...))
+}
+
+// DeletedRowsNotIn applies the NotIn predicate on the "deleted_rows" field.
+func DeletedRowsNotIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldDeletedRows, vs...))
+}
+
+// DeletedRowsGT applies the GT predicate on the "deleted_rows" field.
+func DeletedRowsGT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldDeletedRows, v))
+}
+
+// DeletedRowsGTE applies the GTE predicate on the "deleted_rows" field.
+func DeletedRowsGTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldDeletedRows, v))
+}
+
+// DeletedRowsLT applies the LT predicate on the "deleted_rows" field.
+func DeletedRowsLT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldDeletedRows, v))
+}
+
+// DeletedRowsLTE applies the LTE predicate on the "deleted_rows" field.
+func DeletedRowsLTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldDeletedRows, v))
+}
+
+// ErrorMessageEQ applies the EQ predicate on the "error_message" field.
+func ErrorMessageEQ(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v))
+}
+
+// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field.
+func ErrorMessageNEQ(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldErrorMessage, v))
+}
+
+// ErrorMessageIn applies the In predicate on the "error_message" field.
+func ErrorMessageIn(vs ...string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldErrorMessage, vs...))
+}
+
+// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field.
+func ErrorMessageNotIn(vs ...string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldErrorMessage, vs...))
+}
+
+// ErrorMessageGT applies the GT predicate on the "error_message" field.
+func ErrorMessageGT(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldErrorMessage, v))
+}
+
+// ErrorMessageGTE applies the GTE predicate on the "error_message" field.
+func ErrorMessageGTE(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldErrorMessage, v))
+}
+
+// ErrorMessageLT applies the LT predicate on the "error_message" field.
+func ErrorMessageLT(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldErrorMessage, v))
+}
+
+// ErrorMessageLTE applies the LTE predicate on the "error_message" field.
+func ErrorMessageLTE(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldErrorMessage, v))
+}
+
+// ErrorMessageContains applies the Contains predicate on the "error_message" field.
+func ErrorMessageContains(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldContains(FieldErrorMessage, v))
+}
+
+// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field.
+func ErrorMessageHasPrefix(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldErrorMessage, v))
+}
+
+// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field.
+func ErrorMessageHasSuffix(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldErrorMessage, v))
+}
+
+// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field.
+func ErrorMessageIsNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIsNull(FieldErrorMessage))
+}
+
+// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field.
+func ErrorMessageNotNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotNull(FieldErrorMessage))
+}
+
+// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field.
+func ErrorMessageEqualFold(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldErrorMessage, v))
+}
+
+// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field.
+func ErrorMessageContainsFold(v string) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldErrorMessage, v))
+}
+
+// CanceledByEQ applies the EQ predicate on the "canceled_by" field.
+func CanceledByEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v))
+}
+
+// CanceledByNEQ applies the NEQ predicate on the "canceled_by" field.
+func CanceledByNEQ(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledBy, v))
+}
+
+// CanceledByIn applies the In predicate on the "canceled_by" field.
+func CanceledByIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledBy, vs...))
+}
+
+// CanceledByNotIn applies the NotIn predicate on the "canceled_by" field.
+func CanceledByNotIn(vs ...int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledBy, vs...))
+}
+
+// CanceledByGT applies the GT predicate on the "canceled_by" field.
+func CanceledByGT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledBy, v))
+}
+
+// CanceledByGTE applies the GTE predicate on the "canceled_by" field.
+func CanceledByGTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledBy, v))
+}
+
+// CanceledByLT applies the LT predicate on the "canceled_by" field.
+func CanceledByLT(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledBy, v))
+}
+
+// CanceledByLTE applies the LTE predicate on the "canceled_by" field.
+func CanceledByLTE(v int64) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledBy, v))
+}
+
+// CanceledByIsNil applies the IsNil predicate on the "canceled_by" field.
+func CanceledByIsNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledBy))
+}
+
+// CanceledByNotNil applies the NotNil predicate on the "canceled_by" field.
+func CanceledByNotNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledBy))
+}
+
+// CanceledAtEQ applies the EQ predicate on the "canceled_at" field.
+func CanceledAtEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v))
+}
+
+// CanceledAtNEQ applies the NEQ predicate on the "canceled_at" field.
+func CanceledAtNEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledAt, v))
+}
+
+// CanceledAtIn applies the In predicate on the "canceled_at" field.
+func CanceledAtIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledAt, vs...))
+}
+
+// CanceledAtNotIn applies the NotIn predicate on the "canceled_at" field.
+func CanceledAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledAt, vs...))
+}
+
+// CanceledAtGT applies the GT predicate on the "canceled_at" field.
+func CanceledAtGT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledAt, v))
+}
+
+// CanceledAtGTE applies the GTE predicate on the "canceled_at" field.
+func CanceledAtGTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledAt, v))
+}
+
+// CanceledAtLT applies the LT predicate on the "canceled_at" field.
+func CanceledAtLT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledAt, v))
+}
+
+// CanceledAtLTE applies the LTE predicate on the "canceled_at" field.
+func CanceledAtLTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledAt, v))
+}
+
+// CanceledAtIsNil applies the IsNil predicate on the "canceled_at" field.
+func CanceledAtIsNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledAt))
+}
+
+// CanceledAtNotNil applies the NotNil predicate on the "canceled_at" field.
+func CanceledAtNotNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledAt))
+}
+
+// StartedAtEQ applies the EQ predicate on the "started_at" field.
+func StartedAtEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v))
+}
+
+// StartedAtNEQ applies the NEQ predicate on the "started_at" field.
+func StartedAtNEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStartedAt, v))
+}
+
+// StartedAtIn applies the In predicate on the "started_at" field.
+func StartedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldStartedAt, vs...))
+}
+
+// StartedAtNotIn applies the NotIn predicate on the "started_at" field.
+func StartedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStartedAt, vs...))
+}
+
+// StartedAtGT applies the GT predicate on the "started_at" field.
+func StartedAtGT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldStartedAt, v))
+}
+
+// StartedAtGTE applies the GTE predicate on the "started_at" field.
+func StartedAtGTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldStartedAt, v))
+}
+
+// StartedAtLT applies the LT predicate on the "started_at" field.
+func StartedAtLT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldStartedAt, v))
+}
+
+// StartedAtLTE applies the LTE predicate on the "started_at" field.
+func StartedAtLTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldStartedAt, v))
+}
+
+// StartedAtIsNil applies the IsNil predicate on the "started_at" field.
+func StartedAtIsNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIsNull(FieldStartedAt))
+}
+
+// StartedAtNotNil applies the NotNil predicate on the "started_at" field.
+func StartedAtNotNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotNull(FieldStartedAt))
+}
+
+// FinishedAtEQ applies the EQ predicate on the "finished_at" field.
+func FinishedAtEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v))
+}
+
+// FinishedAtNEQ applies the NEQ predicate on the "finished_at" field.
+func FinishedAtNEQ(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNEQ(FieldFinishedAt, v))
+}
+
+// FinishedAtIn applies the In predicate on the "finished_at" field.
+func FinishedAtIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIn(FieldFinishedAt, vs...))
+}
+
+// FinishedAtNotIn applies the NotIn predicate on the "finished_at" field.
+func FinishedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotIn(FieldFinishedAt, vs...))
+}
+
+// FinishedAtGT applies the GT predicate on the "finished_at" field.
+func FinishedAtGT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGT(FieldFinishedAt, v))
+}
+
+// FinishedAtGTE applies the GTE predicate on the "finished_at" field.
+func FinishedAtGTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldGTE(FieldFinishedAt, v))
+}
+
+// FinishedAtLT applies the LT predicate on the "finished_at" field.
+func FinishedAtLT(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLT(FieldFinishedAt, v))
+}
+
+// FinishedAtLTE applies the LTE predicate on the "finished_at" field.
+func FinishedAtLTE(v time.Time) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldLTE(FieldFinishedAt, v))
+}
+
+// FinishedAtIsNil applies the IsNil predicate on the "finished_at" field.
+func FinishedAtIsNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldIsNull(FieldFinishedAt))
+}
+
+// FinishedAtNotNil applies the NotNil predicate on the "finished_at" field.
+func FinishedAtNotNil() predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.FieldNotNull(FieldFinishedAt))
+}
+
+// And groups predicates with the AND operator between them.
+func And(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.AndPredicates(predicates...))
+}
+
+// Or groups predicates with the OR operator between them.
+func Or(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.OrPredicates(predicates...))
+}
+
+// Not applies the not operator on the given predicate.
+func Not(p predicate.UsageCleanupTask) predicate.UsageCleanupTask {
+ return predicate.UsageCleanupTask(sql.NotPredicates(p))
+}
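
And a sketch of composing the predicates above, e.g. counting stale pending tasks older than an assumed cutoff time:

    // cutoff is a time.Time chosen by the caller; client/ctx as before.
    stale := usagecleanuptask.And(
        usagecleanuptask.StatusEQ("pending"),
        usagecleanuptask.CreatedAtLT(cutoff),
        usagecleanuptask.CanceledAtIsNil(),
    )
    n, err := client.UsageCleanupTask.Query().Where(stale).Count(ctx)
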
diff --git a/backend/ent/usagecleanuptask_create.go b/backend/ent/usagecleanuptask_create.go
new file mode 100644
index 00000000..0b1dcff5
--- /dev/null
+++ b/backend/ent/usagecleanuptask_create.go
@@ -0,0 +1,1190 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskCreate is the builder for creating a UsageCleanupTask entity.
+type UsageCleanupTaskCreate struct {
+ config
+ mutation *UsageCleanupTaskMutation
+ hooks []Hook
+ conflict []sql.ConflictOption
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (_c *UsageCleanupTaskCreate) SetCreatedAt(v time.Time) *UsageCleanupTaskCreate {
+ _c.mutation.SetCreatedAt(v)
+ return _c
+}
+
+// SetNillableCreatedAt sets the "created_at" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableCreatedAt(v *time.Time) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetCreatedAt(*v)
+ }
+ return _c
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_c *UsageCleanupTaskCreate) SetUpdatedAt(v time.Time) *UsageCleanupTaskCreate {
+ _c.mutation.SetUpdatedAt(v)
+ return _c
+}
+
+// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableUpdatedAt(v *time.Time) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetUpdatedAt(*v)
+ }
+ return _c
+}
+
+// SetStatus sets the "status" field.
+func (_c *UsageCleanupTaskCreate) SetStatus(v string) *UsageCleanupTaskCreate {
+ _c.mutation.SetStatus(v)
+ return _c
+}
+
+// SetFilters sets the "filters" field.
+func (_c *UsageCleanupTaskCreate) SetFilters(v json.RawMessage) *UsageCleanupTaskCreate {
+ _c.mutation.SetFilters(v)
+ return _c
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_c *UsageCleanupTaskCreate) SetCreatedBy(v int64) *UsageCleanupTaskCreate {
+ _c.mutation.SetCreatedBy(v)
+ return _c
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (_c *UsageCleanupTaskCreate) SetDeletedRows(v int64) *UsageCleanupTaskCreate {
+ _c.mutation.SetDeletedRows(v)
+ return _c
+}
+
+// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetDeletedRows(*v)
+ }
+ return _c
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (_c *UsageCleanupTaskCreate) SetErrorMessage(v string) *UsageCleanupTaskCreate {
+ _c.mutation.SetErrorMessage(v)
+ return _c
+}
+
+// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableErrorMessage(v *string) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetErrorMessage(*v)
+ }
+ return _c
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (_c *UsageCleanupTaskCreate) SetCanceledBy(v int64) *UsageCleanupTaskCreate {
+ _c.mutation.SetCanceledBy(v)
+ return _c
+}
+
+// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetCanceledBy(*v)
+ }
+ return _c
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (_c *UsageCleanupTaskCreate) SetCanceledAt(v time.Time) *UsageCleanupTaskCreate {
+ _c.mutation.SetCanceledAt(v)
+ return _c
+}
+
+// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetCanceledAt(*v)
+ }
+ return _c
+}
+
+// SetStartedAt sets the "started_at" field.
+func (_c *UsageCleanupTaskCreate) SetStartedAt(v time.Time) *UsageCleanupTaskCreate {
+ _c.mutation.SetStartedAt(v)
+ return _c
+}
+
+// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetStartedAt(*v)
+ }
+ return _c
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (_c *UsageCleanupTaskCreate) SetFinishedAt(v time.Time) *UsageCleanupTaskCreate {
+ _c.mutation.SetFinishedAt(v)
+ return _c
+}
+
+// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
+func (_c *UsageCleanupTaskCreate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskCreate {
+ if v != nil {
+ _c.SetFinishedAt(*v)
+ }
+ return _c
+}
+
+// Mutation returns the UsageCleanupTaskMutation object of the builder.
+func (_c *UsageCleanupTaskCreate) Mutation() *UsageCleanupTaskMutation {
+ return _c.mutation
+}
+
+// Save creates the UsageCleanupTask in the database.
+func (_c *UsageCleanupTaskCreate) Save(ctx context.Context) (*UsageCleanupTask, error) {
+ _c.defaults()
+ return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks)
+}
+
+// SaveX calls Save and panics if Save returns an error.
+func (_c *UsageCleanupTaskCreate) SaveX(ctx context.Context) *UsageCleanupTask {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *UsageCleanupTaskCreate) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *UsageCleanupTaskCreate) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_c *UsageCleanupTaskCreate) defaults() {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ v := usagecleanuptask.DefaultCreatedAt()
+ _c.mutation.SetCreatedAt(v)
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ v := usagecleanuptask.DefaultUpdatedAt()
+ _c.mutation.SetUpdatedAt(v)
+ }
+ if _, ok := _c.mutation.DeletedRows(); !ok {
+ v := usagecleanuptask.DefaultDeletedRows
+ _c.mutation.SetDeletedRows(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_c *UsageCleanupTaskCreate) check() error {
+ if _, ok := _c.mutation.CreatedAt(); !ok {
+ return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageCleanupTask.created_at"`)}
+ }
+ if _, ok := _c.mutation.UpdatedAt(); !ok {
+ return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UsageCleanupTask.updated_at"`)}
+ }
+ if _, ok := _c.mutation.Status(); !ok {
+ return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "UsageCleanupTask.status"`)}
+ }
+ if v, ok := _c.mutation.Status(); ok {
+ if err := usagecleanuptask.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
+ }
+ }
+ if _, ok := _c.mutation.Filters(); !ok {
+ return &ValidationError{Name: "filters", err: errors.New(`ent: missing required field "UsageCleanupTask.filters"`)}
+ }
+ if _, ok := _c.mutation.CreatedBy(); !ok {
+ return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "UsageCleanupTask.created_by"`)}
+ }
+ if _, ok := _c.mutation.DeletedRows(); !ok {
+ return &ValidationError{Name: "deleted_rows", err: errors.New(`ent: missing required field "UsageCleanupTask.deleted_rows"`)}
+ }
+ return nil
+}
+
+func (_c *UsageCleanupTaskCreate) sqlSave(ctx context.Context) (*UsageCleanupTask, error) {
+ if err := _c.check(); err != nil {
+ return nil, err
+ }
+ _node, _spec := _c.createSpec()
+ if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ id := _spec.ID.Value.(int64)
+ _node.ID = int64(id)
+ _c.mutation.id = &_node.ID
+ _c.mutation.done = true
+ return _node, nil
+}
+
+func (_c *UsageCleanupTaskCreate) createSpec() (*UsageCleanupTask, *sqlgraph.CreateSpec) {
+ var (
+ _node = &UsageCleanupTask{config: _c.config}
+ _spec = sqlgraph.NewCreateSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+ )
+ _spec.OnConflict = _c.conflict
+ if value, ok := _c.mutation.CreatedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldCreatedAt, field.TypeTime, value)
+ _node.CreatedAt = value
+ }
+ if value, ok := _c.mutation.UpdatedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
+ _node.UpdatedAt = value
+ }
+ if value, ok := _c.mutation.Status(); ok {
+ _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
+ _node.Status = value
+ }
+ if value, ok := _c.mutation.Filters(); ok {
+ _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
+ _node.Filters = value
+ }
+ if value, ok := _c.mutation.CreatedBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
+ _node.CreatedBy = value
+ }
+ if value, ok := _c.mutation.DeletedRows(); ok {
+ _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
+ _node.DeletedRows = value
+ }
+ if value, ok := _c.mutation.ErrorMessage(); ok {
+ _spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
+ _node.ErrorMessage = &value
+ }
+ if value, ok := _c.mutation.CanceledBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+ _node.CanceledBy = &value
+ }
+ if value, ok := _c.mutation.CanceledAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
+ _node.CanceledAt = &value
+ }
+ if value, ok := _c.mutation.StartedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
+ _node.StartedAt = &value
+ }
+ if value, ok := _c.mutation.FinishedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
+ _node.FinishedAt = &value
+ }
+ return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.UsageCleanupTask.Create().
+// SetCreatedAt(v).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.UsageCleanupTaskUpsert) {
+// SetCreatedAt(v+v).
+// }).
+// Exec(ctx)
+func (_c *UsageCleanupTaskCreate) OnConflict(opts ...sql.ConflictOption) *UsageCleanupTaskUpsertOne {
+ _c.conflict = opts
+ return &UsageCleanupTaskUpsertOne{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.Create().
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *UsageCleanupTaskCreate) OnConflictColumns(columns ...string) *UsageCleanupTaskUpsertOne {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &UsageCleanupTaskUpsertOne{
+ create: _c,
+ }
+}
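+
+// Illustrative sketch (editorial, not ent-generated): a typical upsert through
+// this builder, assuming a *Client named client, a context ctx, and that the
+// table has a unique column to use as the conflict target (FieldID is used
+// here purely for illustration):
+//
+//	client.UsageCleanupTask.Create().
+//		SetStatus("pending").
+//		OnConflictColumns(usagecleanuptask.FieldID).
+//		UpdateNewValues().
+//		ExecX(ctx)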
+
+type (
+ // UsageCleanupTaskUpsertOne is the builder for "upsert"-ing
+ // one UsageCleanupTask node.
+ UsageCleanupTaskUpsertOne struct {
+ create *UsageCleanupTaskCreate
+ }
+
+ // UsageCleanupTaskUpsert is the "OnConflict" setter.
+ UsageCleanupTaskUpsert struct {
+ *sql.UpdateSet
+ }
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *UsageCleanupTaskUpsert) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldUpdatedAt, v)
+ return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateUpdatedAt() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldUpdatedAt)
+ return u
+}
+
+// SetStatus sets the "status" field.
+func (u *UsageCleanupTaskUpsert) SetStatus(v string) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldStatus, v)
+ return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateStatus() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldStatus)
+ return u
+}
+
+// SetFilters sets the "filters" field.
+func (u *UsageCleanupTaskUpsert) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldFilters, v)
+ return u
+}
+
+// UpdateFilters sets the "filters" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateFilters() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldFilters)
+ return u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *UsageCleanupTaskUpsert) SetCreatedBy(v int64) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldCreatedBy, v)
+ return u
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateCreatedBy() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldCreatedBy)
+ return u
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *UsageCleanupTaskUpsert) AddCreatedBy(v int64) *UsageCleanupTaskUpsert {
+ u.Add(usagecleanuptask.FieldCreatedBy, v)
+ return u
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsert) SetDeletedRows(v int64) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldDeletedRows, v)
+ return u
+}
+
+// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateDeletedRows() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldDeletedRows)
+ return u
+}
+
+// AddDeletedRows adds v to the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsert) AddDeletedRows(v int64) *UsageCleanupTaskUpsert {
+ u.Add(usagecleanuptask.FieldDeletedRows, v)
+ return u
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (u *UsageCleanupTaskUpsert) SetErrorMessage(v string) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldErrorMessage, v)
+ return u
+}
+
+// UpdateErrorMessage sets the "error_message" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateErrorMessage() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldErrorMessage)
+ return u
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (u *UsageCleanupTaskUpsert) ClearErrorMessage() *UsageCleanupTaskUpsert {
+ u.SetNull(usagecleanuptask.FieldErrorMessage)
+ return u
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (u *UsageCleanupTaskUpsert) SetCanceledBy(v int64) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldCanceledBy, v)
+ return u
+}
+
+// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateCanceledBy() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldCanceledBy)
+ return u
+}
+
+// AddCanceledBy adds v to the "canceled_by" field.
+func (u *UsageCleanupTaskUpsert) AddCanceledBy(v int64) *UsageCleanupTaskUpsert {
+ u.Add(usagecleanuptask.FieldCanceledBy, v)
+ return u
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (u *UsageCleanupTaskUpsert) ClearCanceledBy() *UsageCleanupTaskUpsert {
+ u.SetNull(usagecleanuptask.FieldCanceledBy)
+ return u
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (u *UsageCleanupTaskUpsert) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldCanceledAt, v)
+ return u
+}
+
+// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateCanceledAt() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldCanceledAt)
+ return u
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (u *UsageCleanupTaskUpsert) ClearCanceledAt() *UsageCleanupTaskUpsert {
+ u.SetNull(usagecleanuptask.FieldCanceledAt)
+ return u
+}
+
+// SetStartedAt sets the "started_at" field.
+func (u *UsageCleanupTaskUpsert) SetStartedAt(v time.Time) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldStartedAt, v)
+ return u
+}
+
+// UpdateStartedAt sets the "started_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateStartedAt() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldStartedAt)
+ return u
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (u *UsageCleanupTaskUpsert) ClearStartedAt() *UsageCleanupTaskUpsert {
+ u.SetNull(usagecleanuptask.FieldStartedAt)
+ return u
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (u *UsageCleanupTaskUpsert) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsert {
+ u.Set(usagecleanuptask.FieldFinishedAt, v)
+ return u
+}
+
+// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateFinishedAt() *UsageCleanupTaskUpsert {
+ u.SetExcluded(usagecleanuptask.FieldFinishedAt)
+ return u
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (u *UsageCleanupTaskUpsert) ClearFinishedAt() *UsageCleanupTaskUpsert {
+ u.SetNull(usagecleanuptask.FieldFinishedAt)
+ return u
+}
+
+// UpdateNewValues updates the mutable fields using the new values that were set on create.
+// Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.Create().
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *UsageCleanupTaskUpsertOne) UpdateNewValues() *UsageCleanupTaskUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ if _, exists := u.create.mutation.CreatedAt(); exists {
+ s.SetIgnore(usagecleanuptask.FieldCreatedAt)
+ }
+ }))
+ return u
+}
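+
+// Editorial note (not ent-generated): besides resolving the conflict with the
+// newly proposed values, the ResolveWith callback above ignores "created_at"
+// whenever it was set on create, so the original creation timestamp survives
+// repeated upserts. A minimal sketch, assuming client and ctx are in scope:
+//
+//	id := client.UsageCleanupTask.Create().
+//		SetStatus("pending").
+//		OnConflictColumns(usagecleanuptask.FieldID).
+//		UpdateNewValues().
+//		IDX(ctx) // ID of the inserted or updated row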
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.Create().
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *UsageCleanupTaskUpsertOne) Ignore() *UsageCleanupTaskUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *UsageCleanupTaskUpsertOne) DoNothing() *UsageCleanupTaskUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding the `UPDATE` values of fields. See the UsageCleanupTaskCreate.OnConflict
+// documentation for more info.
+func (u *UsageCleanupTaskUpsertOne) Update(set func(*UsageCleanupTaskUpsert)) *UsageCleanupTaskUpsertOne {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&UsageCleanupTaskUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *UsageCleanupTaskUpsertOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateUpdatedAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *UsageCleanupTaskUpsertOne) SetStatus(v string) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateStatus() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetFilters sets the "filters" field.
+func (u *UsageCleanupTaskUpsertOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetFilters(v)
+ })
+}
+
+// UpdateFilters sets the "filters" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateFilters() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateFilters()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *UsageCleanupTaskUpsertOne) SetCreatedBy(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *UsageCleanupTaskUpsertOne) AddCreatedBy(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateCreatedBy() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsertOne) SetDeletedRows(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetDeletedRows(v)
+ })
+}
+
+// AddDeletedRows adds v to the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsertOne) AddDeletedRows(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddDeletedRows(v)
+ })
+}
+
+// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateDeletedRows() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateDeletedRows()
+ })
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (u *UsageCleanupTaskUpsertOne) SetErrorMessage(v string) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetErrorMessage(v)
+ })
+}
+
+// UpdateErrorMessage sets the "error_message" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateErrorMessage() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateErrorMessage()
+ })
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (u *UsageCleanupTaskUpsertOne) ClearErrorMessage() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearErrorMessage()
+ })
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertOne) SetCanceledBy(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCanceledBy(v)
+ })
+}
+
+// AddCanceledBy adds v to the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertOne) AddCanceledBy(v int64) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddCanceledBy(v)
+ })
+}
+
+// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateCanceledBy() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCanceledBy()
+ })
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertOne) ClearCanceledBy() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearCanceledBy()
+ })
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (u *UsageCleanupTaskUpsertOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCanceledAt(v)
+ })
+}
+
+// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateCanceledAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCanceledAt()
+ })
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (u *UsageCleanupTaskUpsertOne) ClearCanceledAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearCanceledAt()
+ })
+}
+
+// SetStartedAt sets the "started_at" field.
+func (u *UsageCleanupTaskUpsertOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetStartedAt(v)
+ })
+}
+
+// UpdateStartedAt sets the "started_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateStartedAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateStartedAt()
+ })
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (u *UsageCleanupTaskUpsertOne) ClearStartedAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearStartedAt()
+ })
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (u *UsageCleanupTaskUpsertOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetFinishedAt(v)
+ })
+}
+
+// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertOne) UpdateFinishedAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateFinishedAt()
+ })
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (u *UsageCleanupTaskUpsertOne) ClearFinishedAt() *UsageCleanupTaskUpsertOne {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearFinishedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *UsageCleanupTaskUpsertOne) Exec(ctx context.Context) error {
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for UsageCleanupTaskCreate.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *UsageCleanupTaskUpsertOne) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// ID executes the UPSERT query and returns the inserted/updated ID.
+func (u *UsageCleanupTaskUpsertOne) ID(ctx context.Context) (id int64, err error) {
+ node, err := u.create.Save(ctx)
+ if err != nil {
+ return id, err
+ }
+ return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *UsageCleanupTaskUpsertOne) IDX(ctx context.Context) int64 {
+ id, err := u.ID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// UsageCleanupTaskCreateBulk is the builder for creating many UsageCleanupTask entities in bulk.
+type UsageCleanupTaskCreateBulk struct {
+ config
+ err error
+ builders []*UsageCleanupTaskCreate
+ conflict []sql.ConflictOption
+}
+
+// Save creates the UsageCleanupTask entities in the database.
+func (_c *UsageCleanupTaskCreateBulk) Save(ctx context.Context) ([]*UsageCleanupTask, error) {
+ if _c.err != nil {
+ return nil, _c.err
+ }
+ specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+ nodes := make([]*UsageCleanupTask, len(_c.builders))
+ mutators := make([]Mutator, len(_c.builders))
+ for i := range _c.builders {
+ func(i int, root context.Context) {
+ builder := _c.builders[i]
+ builder.defaults()
+ var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+ mutation, ok := m.(*UsageCleanupTaskMutation)
+ if !ok {
+ return nil, fmt.Errorf("unexpected mutation type %T", m)
+ }
+ if err := builder.check(); err != nil {
+ return nil, err
+ }
+ builder.mutation = mutation
+ var err error
+ nodes[i], specs[i] = builder.createSpec()
+ if i < len(mutators)-1 {
+ _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+ } else {
+ spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+ spec.OnConflict = _c.conflict
+ // Invoke the actual operation on the latest mutation in the chain.
+ if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+ if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ }
+ }
+ if err != nil {
+ return nil, err
+ }
+ mutation.id = &nodes[i].ID
+ if specs[i].ID.Value != nil {
+ id := specs[i].ID.Value.(int64)
+ nodes[i].ID = int64(id)
+ }
+ mutation.done = true
+ return nodes[i], nil
+ })
+ for i := len(builder.hooks) - 1; i >= 0; i-- {
+ mut = builder.hooks[i](mut)
+ }
+ mutators[i] = mut
+ }(i, ctx)
+ }
+ if len(mutators) > 0 {
+ if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+ return nil, err
+ }
+ }
+ return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *UsageCleanupTaskCreateBulk) SaveX(ctx context.Context) []*UsageCleanupTask {
+ v, err := _c.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return v
+}
+
+// Exec executes the query.
+func (_c *UsageCleanupTaskCreateBulk) Exec(ctx context.Context) error {
+ _, err := _c.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *UsageCleanupTaskCreateBulk) ExecX(ctx context.Context) {
+ if err := _c.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+// client.UsageCleanupTask.CreateBulk(builders...).
+// OnConflict(
+// // Update the row with the new values
+// // that was proposed for insertion.
+// sql.ResolveWithNewValues(),
+// ).
+// // Override some of the fields with custom
+// // update values.
+// Update(func(u *ent.UsageCleanupTaskUpsert) {
+// u.SetCreatedAt(v + v)
+// }).
+// Exec(ctx)
+func (_c *UsageCleanupTaskCreateBulk) OnConflict(opts ...sql.ConflictOption) *UsageCleanupTaskUpsertBulk {
+ _c.conflict = opts
+ return &UsageCleanupTaskUpsertBulk{
+ create: _c,
+ }
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as the conflict target. Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.CreateBulk(builders...).
+// OnConflict(sql.ConflictColumns(columns...)).
+// Exec(ctx)
+func (_c *UsageCleanupTaskCreateBulk) OnConflictColumns(columns ...string) *UsageCleanupTaskUpsertBulk {
+ _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+ return &UsageCleanupTaskUpsertBulk{
+ create: _c,
+ }
+}
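+
+// Illustrative sketch (editorial, not ent-generated): a bulk upsert with a
+// shared conflict clause. OnConflict must be set on the bulk builder, not on
+// the individual creates; the bulk Exec further below rejects per-builder
+// conflicts. Assuming client and ctx are in scope:
+//
+//	builders := []*ent.UsageCleanupTaskCreate{
+//		client.UsageCleanupTask.Create().SetStatus("pending"),
+//		client.UsageCleanupTask.Create().SetStatus("pending"),
+//	}
+//	client.UsageCleanupTask.CreateBulk(builders...).
+//		OnConflict(sql.ResolveWithNewValues()).
+//		ExecX(ctx)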
+
+// UsageCleanupTaskUpsertBulk is the builder for "upsert"-ing
+// a bulk of UsageCleanupTask nodes.
+type UsageCleanupTaskUpsertBulk struct {
+ create *UsageCleanupTaskCreateBulk
+}
+
+// UpdateNewValues updates the mutable fields using the new values that
+// were set on create. Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.CreateBulk(builders...).
+// OnConflict(
+// sql.ResolveWithNewValues(),
+// ).
+// Exec(ctx)
+func (u *UsageCleanupTaskUpsertBulk) UpdateNewValues() *UsageCleanupTaskUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues())
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) {
+ for _, b := range u.create.builders {
+ if _, exists := b.mutation.CreatedAt(); exists {
+ s.SetIgnore(usagecleanuptask.FieldCreatedAt)
+ }
+ }
+ }))
+ return u
+}
+
+// Ignore sets each column to itself in case of conflict.
+// Using this option is equivalent to using:
+//
+// client.UsageCleanupTask.CreateBulk(builders...).
+// OnConflict(sql.ResolveWithIgnore()).
+// Exec(ctx)
+func (u *UsageCleanupTaskUpsertBulk) Ignore() *UsageCleanupTaskUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore())
+ return u
+}
+
+// DoNothing configures the conflict_action to `DO NOTHING`.
+// Supported only by SQLite and PostgreSQL.
+func (u *UsageCleanupTaskUpsertBulk) DoNothing() *UsageCleanupTaskUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.DoNothing())
+ return u
+}
+
+// Update allows overriding the `UPDATE` values of fields. See the UsageCleanupTaskCreateBulk.OnConflict
+// documentation for more info.
+func (u *UsageCleanupTaskUpsertBulk) Update(set func(*UsageCleanupTaskUpsert)) *UsageCleanupTaskUpsertBulk {
+ u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) {
+ set(&UsageCleanupTaskUpsert{UpdateSet: update})
+ }))
+ return u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *UsageCleanupTaskUpsertBulk) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetUpdatedAt(v)
+ })
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateUpdatedAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateUpdatedAt()
+ })
+}
+
+// SetStatus sets the "status" field.
+func (u *UsageCleanupTaskUpsertBulk) SetStatus(v string) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetStatus(v)
+ })
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateStatus() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateStatus()
+ })
+}
+
+// SetFilters sets the "filters" field.
+func (u *UsageCleanupTaskUpsertBulk) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetFilters(v)
+ })
+}
+
+// UpdateFilters sets the "filters" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateFilters() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateFilters()
+ })
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (u *UsageCleanupTaskUpsertBulk) SetCreatedBy(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCreatedBy(v)
+ })
+}
+
+// AddCreatedBy adds v to the "created_by" field.
+func (u *UsageCleanupTaskUpsertBulk) AddCreatedBy(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddCreatedBy(v)
+ })
+}
+
+// UpdateCreatedBy sets the "created_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateCreatedBy() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCreatedBy()
+ })
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsertBulk) SetDeletedRows(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetDeletedRows(v)
+ })
+}
+
+// AddDeletedRows adds v to the "deleted_rows" field.
+func (u *UsageCleanupTaskUpsertBulk) AddDeletedRows(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddDeletedRows(v)
+ })
+}
+
+// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateDeletedRows() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateDeletedRows()
+ })
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (u *UsageCleanupTaskUpsertBulk) SetErrorMessage(v string) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetErrorMessage(v)
+ })
+}
+
+// UpdateErrorMessage sets the "error_message" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateErrorMessage() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateErrorMessage()
+ })
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (u *UsageCleanupTaskUpsertBulk) ClearErrorMessage() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearErrorMessage()
+ })
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertBulk) SetCanceledBy(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCanceledBy(v)
+ })
+}
+
+// AddCanceledBy adds v to the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertBulk) AddCanceledBy(v int64) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.AddCanceledBy(v)
+ })
+}
+
+// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateCanceledBy() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCanceledBy()
+ })
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (u *UsageCleanupTaskUpsertBulk) ClearCanceledBy() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearCanceledBy()
+ })
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (u *UsageCleanupTaskUpsertBulk) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetCanceledAt(v)
+ })
+}
+
+// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateCanceledAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateCanceledAt()
+ })
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (u *UsageCleanupTaskUpsertBulk) ClearCanceledAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearCanceledAt()
+ })
+}
+
+// SetStartedAt sets the "started_at" field.
+func (u *UsageCleanupTaskUpsertBulk) SetStartedAt(v time.Time) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetStartedAt(v)
+ })
+}
+
+// UpdateStartedAt sets the "started_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateStartedAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateStartedAt()
+ })
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (u *UsageCleanupTaskUpsertBulk) ClearStartedAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearStartedAt()
+ })
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (u *UsageCleanupTaskUpsertBulk) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.SetFinishedAt(v)
+ })
+}
+
+// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsertBulk) UpdateFinishedAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.UpdateFinishedAt()
+ })
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (u *UsageCleanupTaskUpsertBulk) ClearFinishedAt() *UsageCleanupTaskUpsertBulk {
+ return u.Update(func(s *UsageCleanupTaskUpsert) {
+ s.ClearFinishedAt()
+ })
+}
+
+// Exec executes the query.
+func (u *UsageCleanupTaskUpsertBulk) Exec(ctx context.Context) error {
+ if u.create.err != nil {
+ return u.create.err
+ }
+ for i, b := range u.create.builders {
+ if len(b.conflict) != 0 {
+ return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UsageCleanupTaskCreateBulk instead", i)
+ }
+ }
+ if len(u.create.conflict) == 0 {
+ return errors.New("ent: missing options for UsageCleanupTaskCreateBulk.OnConflict")
+ }
+ return u.create.Exec(ctx)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (u *UsageCleanupTaskUpsertBulk) ExecX(ctx context.Context) {
+ if err := u.create.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
diff --git a/backend/ent/usagecleanuptask_delete.go b/backend/ent/usagecleanuptask_delete.go
new file mode 100644
index 00000000..158555f7
--- /dev/null
+++ b/backend/ent/usagecleanuptask_delete.go
@@ -0,0 +1,88 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskDelete is the builder for deleting a UsageCleanupTask entity.
+type UsageCleanupTaskDelete struct {
+ config
+ hooks []Hook
+ mutation *UsageCleanupTaskMutation
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskDelete builder.
+func (_d *UsageCleanupTaskDelete) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDelete {
+ _d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *UsageCleanupTaskDelete) Exec(ctx context.Context) (int, error) {
+ return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *UsageCleanupTaskDelete) ExecX(ctx context.Context) int {
+ n, err := _d.Exec(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return n
+}
+
+func (_d *UsageCleanupTaskDelete) sqlExec(ctx context.Context) (int, error) {
+ _spec := sqlgraph.NewDeleteSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+ if ps := _d.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+ if err != nil && sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ _d.mutation.done = true
+ return affected, err
+}
+
+// UsageCleanupTaskDeleteOne is the builder for deleting a single UsageCleanupTask entity.
+type UsageCleanupTaskDeleteOne struct {
+ _d *UsageCleanupTaskDelete
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskDelete builder.
+func (_d *UsageCleanupTaskDeleteOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDeleteOne {
+ _d._d.mutation.Where(ps...)
+ return _d
+}
+
+// Exec executes the deletion query.
+func (_d *UsageCleanupTaskDeleteOne) Exec(ctx context.Context) error {
+ n, err := _d._d.Exec(ctx)
+ switch {
+ case err != nil:
+ return err
+ case n == 0:
+ return &NotFoundError{usagecleanuptask.Label}
+ default:
+ return nil
+ }
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *UsageCleanupTaskDeleteOne) ExecX(ctx context.Context) {
+ if err := _d.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
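+
+// Illustrative sketch (editorial, not ent-generated): bulk-delete finished
+// tasks older than a cutoff, assuming client and ctx are in scope and that the
+// generated predicate helpers StatusEQ and CreatedAtLT exist for these fields:
+//
+//	n, err := client.UsageCleanupTask.Delete().
+//		Where(
+//			usagecleanuptask.StatusEQ("finished"),
+//			usagecleanuptask.CreatedAtLT(cutoff),
+//		).
+//		Exec(ctx) // n is the number of rows deleted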
diff --git a/backend/ent/usagecleanuptask_query.go b/backend/ent/usagecleanuptask_query.go
new file mode 100644
index 00000000..9d8d5410
--- /dev/null
+++ b/backend/ent/usagecleanuptask_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "fmt"
+ "math"
+
+ "entgo.io/ent"
+ "entgo.io/ent/dialect"
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskQuery is the builder for querying UsageCleanupTask entities.
+type UsageCleanupTaskQuery struct {
+ config
+ ctx *QueryContext
+ order []usagecleanuptask.OrderOption
+ inters []Interceptor
+ predicates []predicate.UsageCleanupTask
+ modifiers []func(*sql.Selector)
+ // intermediate query (i.e. traversal path).
+ sql *sql.Selector
+ path func(context.Context) (*sql.Selector, error)
+}
+
+// Where adds a new predicate for the UsageCleanupTaskQuery builder.
+func (_q *UsageCleanupTaskQuery) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskQuery {
+ _q.predicates = append(_q.predicates, ps...)
+ return _q
+}
+
+// Limit the number of records to be returned by this query.
+func (_q *UsageCleanupTaskQuery) Limit(limit int) *UsageCleanupTaskQuery {
+ _q.ctx.Limit = &limit
+ return _q
+}
+
+// Offset to start from.
+func (_q *UsageCleanupTaskQuery) Offset(offset int) *UsageCleanupTaskQuery {
+ _q.ctx.Offset = &offset
+ return _q
+}
+
+// Unique configures the query builder to filter duplicate records on query.
+// By default, unique is set to true, and can be disabled using this method.
+func (_q *UsageCleanupTaskQuery) Unique(unique bool) *UsageCleanupTaskQuery {
+ _q.ctx.Unique = &unique
+ return _q
+}
+
+// Order specifies how the records should be ordered.
+func (_q *UsageCleanupTaskQuery) Order(o ...usagecleanuptask.OrderOption) *UsageCleanupTaskQuery {
+ _q.order = append(_q.order, o...)
+ return _q
+}
+
+// First returns the first UsageCleanupTask entity from the query.
+// Returns a *NotFoundError when no UsageCleanupTask was found.
+func (_q *UsageCleanupTaskQuery) First(ctx context.Context) (*UsageCleanupTask, error) {
+ nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst))
+ if err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nil, &NotFoundError{usagecleanuptask.Label}
+ }
+ return nodes[0], nil
+}
+
+// FirstX is like First, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) FirstX(ctx context.Context) *UsageCleanupTask {
+ node, err := _q.First(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return node
+}
+
+// FirstID returns the first UsageCleanupTask ID from the query.
+// Returns a *NotFoundError when no UsageCleanupTask ID was found.
+func (_q *UsageCleanupTaskQuery) FirstID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil {
+ return
+ }
+ if len(ids) == 0 {
+ err = &NotFoundError{usagecleanuptask.Label}
+ return
+ }
+ return ids[0], nil
+}
+
+// FirstIDX is like FirstID, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) FirstIDX(ctx context.Context) int64 {
+ id, err := _q.FirstID(ctx)
+ if err != nil && !IsNotFound(err) {
+ panic(err)
+ }
+ return id
+}
+
+// Only returns a single UsageCleanupTask entity found by the query, ensuring it only returns one.
+// Returns a *NotSingularError when more than one UsageCleanupTask entity is found.
+// Returns a *NotFoundError when no UsageCleanupTask entities are found.
+func (_q *UsageCleanupTaskQuery) Only(ctx context.Context) (*UsageCleanupTask, error) {
+ nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly))
+ if err != nil {
+ return nil, err
+ }
+ switch len(nodes) {
+ case 1:
+ return nodes[0], nil
+ case 0:
+ return nil, &NotFoundError{usagecleanuptask.Label}
+ default:
+ return nil, &NotSingularError{usagecleanuptask.Label}
+ }
+}
+
+// OnlyX is like Only, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) OnlyX(ctx context.Context) *UsageCleanupTask {
+ node, err := _q.Only(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// OnlyID is like Only, but returns the only UsageCleanupTask ID in the query.
+// Returns a *NotSingularError when more than one UsageCleanupTask ID is found.
+// Returns a *NotFoundError when no entities are found.
+func (_q *UsageCleanupTaskQuery) OnlyID(ctx context.Context) (id int64, err error) {
+ var ids []int64
+ if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil {
+ return
+ }
+ switch len(ids) {
+ case 1:
+ id = ids[0]
+ case 0:
+ err = &NotFoundError{usagecleanuptask.Label}
+ default:
+ err = &NotSingularError{usagecleanuptask.Label}
+ }
+ return
+}
+
+// OnlyIDX is like OnlyID, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) OnlyIDX(ctx context.Context) int64 {
+ id, err := _q.OnlyID(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return id
+}
+
+// All executes the query and returns a list of UsageCleanupTasks.
+func (_q *UsageCleanupTaskQuery) All(ctx context.Context) ([]*UsageCleanupTask, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return nil, err
+ }
+ qr := querierAll[[]*UsageCleanupTask, *UsageCleanupTaskQuery]()
+ return withInterceptors[[]*UsageCleanupTask](ctx, _q, qr, _q.inters)
+}
+
+// AllX is like All, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) AllX(ctx context.Context) []*UsageCleanupTask {
+ nodes, err := _q.All(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return nodes
+}
+
+// IDs executes the query and returns a list of UsageCleanupTask IDs.
+func (_q *UsageCleanupTaskQuery) IDs(ctx context.Context) (ids []int64, err error) {
+ if _q.ctx.Unique == nil && _q.path != nil {
+ _q.Unique(true)
+ }
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs)
+ if err = _q.Select(usagecleanuptask.FieldID).Scan(ctx, &ids); err != nil {
+ return nil, err
+ }
+ return ids, nil
+}
+
+// IDsX is like IDs, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) IDsX(ctx context.Context) []int64 {
+ ids, err := _q.IDs(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return ids
+}
+
+// Count returns the count of the given query.
+func (_q *UsageCleanupTaskQuery) Count(ctx context.Context) (int, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount)
+ if err := _q.prepareQuery(ctx); err != nil {
+ return 0, err
+ }
+ return withInterceptors[int](ctx, _q, querierCount[*UsageCleanupTaskQuery](), _q.inters)
+}
+
+// CountX is like Count, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) CountX(ctx context.Context) int {
+ count, err := _q.Count(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return count
+}
+
+// Exist returns true if the query has elements in the graph.
+func (_q *UsageCleanupTaskQuery) Exist(ctx context.Context) (bool, error) {
+ ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist)
+ switch _, err := _q.FirstID(ctx); {
+ case IsNotFound(err):
+ return false, nil
+ case err != nil:
+ return false, fmt.Errorf("ent: check existence: %w", err)
+ default:
+ return true, nil
+ }
+}
+
+// ExistX is like Exist, but panics if an error occurs.
+func (_q *UsageCleanupTaskQuery) ExistX(ctx context.Context) bool {
+ exist, err := _q.Exist(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return exist
+}
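+
+// Illustrative sketch (editorial, not ent-generated): fetch the most recent
+// pending task, if any, assuming client and ctx are in scope and the generated
+// StatusEQ predicate and ByCreatedAt order helper exist:
+//
+//	task, err := client.UsageCleanupTask.Query().
+//		Where(usagecleanuptask.StatusEQ("pending")).
+//		Order(usagecleanuptask.ByCreatedAt(sql.OrderDesc())).
+//		First(ctx)
+//	if ent.IsNotFound(err) {
+//		// no pending task
+//	}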
+
+// Clone returns a duplicate of the UsageCleanupTaskQuery builder, including all associated steps. It can be
+// used to prepare common query builders and use them differently after the clone is made.
+func (_q *UsageCleanupTaskQuery) Clone() *UsageCleanupTaskQuery {
+ if _q == nil {
+ return nil
+ }
+ return &UsageCleanupTaskQuery{
+ config: _q.config,
+ ctx: _q.ctx.Clone(),
+ order: append([]usagecleanuptask.OrderOption{}, _q.order...),
+ inters: append([]Interceptor{}, _q.inters...),
+ predicates: append([]predicate.UsageCleanupTask{}, _q.predicates...),
+ // clone intermediate query.
+ sql: _q.sql.Clone(),
+ path: _q.path,
+ }
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// Count int `json:"count,omitempty"`
+// }
+//
+// client.UsageCleanupTask.Query().
+// GroupBy(usagecleanuptask.FieldCreatedAt).
+// Aggregate(ent.Count()).
+// Scan(ctx, &v)
+func (_q *UsageCleanupTaskQuery) GroupBy(field string, fields ...string) *UsageCleanupTaskGroupBy {
+ _q.ctx.Fields = append([]string{field}, fields...)
+ grbuild := &UsageCleanupTaskGroupBy{build: _q}
+ grbuild.flds = &_q.ctx.Fields
+ grbuild.label = usagecleanuptask.Label
+ grbuild.scan = grbuild.Scan
+ return grbuild
+}
+
+// Select allows the selection of one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+// var v []struct {
+// CreatedAt time.Time `json:"created_at,omitempty"`
+// }
+//
+// client.UsageCleanupTask.Query().
+// Select(usagecleanuptask.FieldCreatedAt).
+// Scan(ctx, &v)
+func (_q *UsageCleanupTaskQuery) Select(fields ...string) *UsageCleanupTaskSelect {
+ _q.ctx.Fields = append(_q.ctx.Fields, fields...)
+ sbuild := &UsageCleanupTaskSelect{UsageCleanupTaskQuery: _q}
+ sbuild.label = usagecleanuptask.Label
+ sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+ return sbuild
+}
+
+// Aggregate returns a UsageCleanupTaskSelect configured with the given aggregations.
+func (_q *UsageCleanupTaskQuery) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
+ return _q.Select().Aggregate(fns...)
+}
+
+func (_q *UsageCleanupTaskQuery) prepareQuery(ctx context.Context) error {
+ for _, inter := range _q.inters {
+ if inter == nil {
+ return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)")
+ }
+ if trv, ok := inter.(Traverser); ok {
+ if err := trv.Traverse(ctx, _q); err != nil {
+ return err
+ }
+ }
+ }
+ for _, f := range _q.ctx.Fields {
+ if !usagecleanuptask.ValidColumn(f) {
+ return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ }
+ if _q.path != nil {
+ prev, err := _q.path(ctx)
+ if err != nil {
+ return err
+ }
+ _q.sql = prev
+ }
+ return nil
+}
+
+func (_q *UsageCleanupTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageCleanupTask, error) {
+ var (
+ nodes = []*UsageCleanupTask{}
+ _spec = _q.querySpec()
+ )
+ _spec.ScanValues = func(columns []string) ([]any, error) {
+ return (*UsageCleanupTask).scanValues(nil, columns)
+ }
+ _spec.Assign = func(columns []string, values []any) error {
+ node := &UsageCleanupTask{config: _q.config}
+ nodes = append(nodes, node)
+ return node.assignValues(columns, values)
+ }
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ for i := range hooks {
+ hooks[i](ctx, _spec)
+ }
+ if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil {
+ return nil, err
+ }
+ if len(nodes) == 0 {
+ return nodes, nil
+ }
+ return nodes, nil
+}
+
+func (_q *UsageCleanupTaskQuery) sqlCount(ctx context.Context) (int, error) {
+ _spec := _q.querySpec()
+ if len(_q.modifiers) > 0 {
+ _spec.Modifiers = _q.modifiers
+ }
+ _spec.Node.Columns = _q.ctx.Fields
+ if len(_q.ctx.Fields) > 0 {
+ _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique
+ }
+ return sqlgraph.CountNodes(ctx, _q.driver, _spec)
+}
+
+func (_q *UsageCleanupTaskQuery) querySpec() *sqlgraph.QuerySpec {
+ _spec := sqlgraph.NewQuerySpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+ _spec.From = _q.sql
+ if unique := _q.ctx.Unique; unique != nil {
+ _spec.Unique = *unique
+ } else if _q.path != nil {
+ _spec.Unique = true
+ }
+ if fields := _q.ctx.Fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
+ for i := range fields {
+ if fields[i] != usagecleanuptask.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, fields[i])
+ }
+ }
+ }
+ if ps := _q.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ _spec.Limit = *limit
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+ _spec.Offset = *offset
+ }
+ if ps := _q.order; len(ps) > 0 {
+ _spec.Order = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ return _spec
+}
+
+func (_q *UsageCleanupTaskQuery) sqlQuery(ctx context.Context) *sql.Selector {
+ builder := sql.Dialect(_q.driver.Dialect())
+ t1 := builder.Table(usagecleanuptask.Table)
+ columns := _q.ctx.Fields
+ if len(columns) == 0 {
+ columns = usagecleanuptask.Columns
+ }
+ selector := builder.Select(t1.Columns(columns...)...).From(t1)
+ if _q.sql != nil {
+ selector = _q.sql
+ selector.Select(selector.Columns(columns...)...)
+ }
+ if _q.ctx.Unique != nil && *_q.ctx.Unique {
+ selector.Distinct()
+ }
+ for _, m := range _q.modifiers {
+ m(selector)
+ }
+ for _, p := range _q.predicates {
+ p(selector)
+ }
+ for _, p := range _q.order {
+ p(selector)
+ }
+ if offset := _q.ctx.Offset; offset != nil {
+// a limit is mandatory for the offset clause. We start
+// with a default value, and override it below if needed.
+ selector.Offset(*offset).Limit(math.MaxInt32)
+ }
+ if limit := _q.ctx.Limit; limit != nil {
+ selector.Limit(*limit)
+ }
+ return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevents them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *UsageCleanupTaskQuery) ForUpdate(opts ...sql.LockOption) *UsageCleanupTaskQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForUpdate(opts...)
+ })
+ return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *UsageCleanupTaskQuery) ForShare(opts ...sql.LockOption) *UsageCleanupTaskQuery {
+ if _q.driver.Dialect() == dialect.Postgres {
+ _q.Unique(false)
+ }
+ _q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+ s.ForShare(opts...)
+ })
+ return _q
+}
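+
+// Illustrative sketch (editorial, not ent-generated): claim a task under a row
+// lock. ForUpdate only takes effect inside a transaction; assuming client and
+// ctx are in scope:
+//
+//	tx, err := client.Tx(ctx)
+//	// handle err, then:
+//	task, err := tx.UsageCleanupTask.Query().
+//		Where(usagecleanuptask.StatusEQ("pending")).
+//		ForUpdate().
+//		First(ctx)
+//	// mark the task as running, then tx.Commit() (or tx.Rollback() on error)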
+
+// UsageCleanupTaskGroupBy is the group-by builder for UsageCleanupTask entities.
+type UsageCleanupTaskGroupBy struct {
+ selector
+ build *UsageCleanupTaskQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *UsageCleanupTaskGroupBy) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskGroupBy {
+ _g.fns = append(_g.fns, fns...)
+ return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *UsageCleanupTaskGroupBy) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+ if err := _g.build.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *UsageCleanupTaskGroupBy) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
+ selector := root.sqlQuery(ctx).Select()
+ aggregation := make([]string, 0, len(_g.fns))
+ for _, fn := range _g.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ if len(selector.SelectedColumns()) == 0 {
+ columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+ for _, f := range *_g.flds {
+ columns = append(columns, selector.C(f))
+ }
+ columns = append(columns, aggregation...)
+ selector.Select(columns...)
+ }
+ selector.GroupBy(selector.Columns(*_g.flds...)...)
+ if err := selector.Err(); err != nil {
+ return err
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
+
+// UsageCleanupTaskSelect is the builder for selecting fields of UsageCleanupTask entities.
+type UsageCleanupTaskSelect struct {
+ *UsageCleanupTaskQuery
+ selector
+}
+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *UsageCleanupTaskSelect) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
+ _s.fns = append(_s.fns, fns...)
+ return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *UsageCleanupTaskSelect) Scan(ctx context.Context, v any) error {
+ ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+ if err := _s.prepareQuery(ctx); err != nil {
+ return err
+ }
+ return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskSelect](ctx, _s.UsageCleanupTaskQuery, _s, _s.inters, v)
+}
+
+func (_s *UsageCleanupTaskSelect) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
+ selector := root.sqlQuery(ctx)
+ aggregation := make([]string, 0, len(_s.fns))
+ for _, fn := range _s.fns {
+ aggregation = append(aggregation, fn(selector))
+ }
+ switch n := len(*_s.selector.flds); {
+ case n == 0 && len(aggregation) > 0:
+ selector.Select(aggregation...)
+ case n != 0 && len(aggregation) > 0:
+ selector.AppendSelect(aggregation...)
+ }
+ rows := &sql.Rows{}
+ query, args := selector.Query()
+ if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+ return err
+ }
+ defer rows.Close()
+ return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/usagecleanuptask_update.go b/backend/ent/usagecleanuptask_update.go
new file mode 100644
index 00000000..604202c6
--- /dev/null
+++ b/backend/ent/usagecleanuptask_update.go
@@ -0,0 +1,702 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "time"
+
+ "entgo.io/ent/dialect/sql"
+ "entgo.io/ent/dialect/sql/sqlgraph"
+ "entgo.io/ent/dialect/sql/sqljson"
+ "entgo.io/ent/schema/field"
+ "github.com/Wei-Shaw/sub2api/ent/predicate"
+ "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskUpdate is the builder for updating UsageCleanupTask entities.
+type UsageCleanupTaskUpdate struct {
+ config
+ hooks []Hook
+ mutation *UsageCleanupTaskMutation
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskUpdate builder.
+func (_u *UsageCleanupTaskUpdate) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdate {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *UsageCleanupTaskUpdate) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdate {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *UsageCleanupTaskUpdate) SetStatus(v string) *UsageCleanupTaskUpdate {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableStatus(v *string) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetFilters sets the "filters" field.
+func (_u *UsageCleanupTaskUpdate) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
+ _u.mutation.SetFilters(v)
+ return _u
+}
+
+// AppendFilters appends value to the "filters" field.
+func (_u *UsageCleanupTaskUpdate) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
+ _u.mutation.AppendFilters(v)
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *UsageCleanupTaskUpdate) SetCreatedBy(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *UsageCleanupTaskUpdate) AddCreatedBy(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (_u *UsageCleanupTaskUpdate) SetDeletedRows(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.ResetDeletedRows()
+ _u.mutation.SetDeletedRows(v)
+ return _u
+}
+
+// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetDeletedRows(*v)
+ }
+ return _u
+}
+
+// AddDeletedRows adds value to the "deleted_rows" field.
+func (_u *UsageCleanupTaskUpdate) AddDeletedRows(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.AddDeletedRows(v)
+ return _u
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (_u *UsageCleanupTaskUpdate) SetErrorMessage(v string) *UsageCleanupTaskUpdate {
+ _u.mutation.SetErrorMessage(v)
+ return _u
+}
+
+// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetErrorMessage(*v)
+ }
+ return _u
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (_u *UsageCleanupTaskUpdate) ClearErrorMessage() *UsageCleanupTaskUpdate {
+ _u.mutation.ClearErrorMessage()
+ return _u
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdate) SetCanceledBy(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.ResetCanceledBy()
+ _u.mutation.SetCanceledBy(v)
+ return _u
+}
+
+// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetCanceledBy(*v)
+ }
+ return _u
+}
+
+// AddCanceledBy adds value to the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdate) AddCanceledBy(v int64) *UsageCleanupTaskUpdate {
+ _u.mutation.AddCanceledBy(v)
+ return _u
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdate) ClearCanceledBy() *UsageCleanupTaskUpdate {
+ _u.mutation.ClearCanceledBy()
+ return _u
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (_u *UsageCleanupTaskUpdate) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdate {
+ _u.mutation.SetCanceledAt(v)
+ return _u
+}
+
+// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetCanceledAt(*v)
+ }
+ return _u
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (_u *UsageCleanupTaskUpdate) ClearCanceledAt() *UsageCleanupTaskUpdate {
+ _u.mutation.ClearCanceledAt()
+ return _u
+}
+
+// SetStartedAt sets the "started_at" field.
+func (_u *UsageCleanupTaskUpdate) SetStartedAt(v time.Time) *UsageCleanupTaskUpdate {
+ _u.mutation.SetStartedAt(v)
+ return _u
+}
+
+// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetStartedAt(*v)
+ }
+ return _u
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (_u *UsageCleanupTaskUpdate) ClearStartedAt() *UsageCleanupTaskUpdate {
+ _u.mutation.ClearStartedAt()
+ return _u
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (_u *UsageCleanupTaskUpdate) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdate {
+ _u.mutation.SetFinishedAt(v)
+ return _u
+}
+
+// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdate {
+ if v != nil {
+ _u.SetFinishedAt(*v)
+ }
+ return _u
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (_u *UsageCleanupTaskUpdate) ClearFinishedAt() *UsageCleanupTaskUpdate {
+ _u.mutation.ClearFinishedAt()
+ return _u
+}
+
+// Mutation returns the UsageCleanupTaskMutation object of the builder.
+func (_u *UsageCleanupTaskUpdate) Mutation() *UsageCleanupTaskMutation {
+ return _u.mutation
+}
+
+// Save executes the query and returns the number of nodes affected by the update operation.
+func (_u *UsageCleanupTaskUpdate) Save(ctx context.Context) (int, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *UsageCleanupTaskUpdate) SaveX(ctx context.Context) int {
+ affected, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return affected
+}
+
+// Exec executes the query.
+func (_u *UsageCleanupTaskUpdate) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *UsageCleanupTaskUpdate) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
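+
+// Illustrative sketch (editorial, not ent-generated): cancel all pending tasks
+// in a single statement, assuming client, ctx, and an adminID are in scope:
+//
+//	n, err := client.UsageCleanupTask.Update().
+//		Where(usagecleanuptask.StatusEQ("pending")).
+//		SetStatus("canceled").
+//		SetCanceledBy(adminID).
+//		SetCanceledAt(time.Now()).
+//		Save(ctx) // n is the number of rows updated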
+
+// defaults sets the default values of the builder before save.
+func (_u *UsageCleanupTaskUpdate) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := usagecleanuptask.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *UsageCleanupTaskUpdate) check() error {
+ if v, ok := _u.mutation.Status(); ok {
+ if err := usagecleanuptask.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *UsageCleanupTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Filters(); ok {
+ _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedFilters(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, usagecleanuptask.FieldFilters, value)
+ })
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.DeletedRows(); ok {
+ _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedDeletedRows(); ok {
+ _spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ErrorMessage(); ok {
+ _spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
+ }
+ if _u.mutation.ErrorMessageCleared() {
+ _spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CanceledBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCanceledBy(); ok {
+ _spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+ }
+ if _u.mutation.CanceledByCleared() {
+ _spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.CanceledAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
+ }
+ if _u.mutation.CanceledAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.StartedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
+ }
+ if _u.mutation.StartedAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FinishedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
+ }
+ if _u.mutation.FinishedAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
+ }
+ if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{usagecleanuptask.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return 0, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
+
+// UsageCleanupTaskUpdateOne is the builder for updating a single UsageCleanupTask entity.
+type UsageCleanupTaskUpdateOne struct {
+ config
+ fields []string
+ hooks []Hook
+ mutation *UsageCleanupTaskMutation
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *UsageCleanupTaskUpdateOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetUpdatedAt(v)
+ return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *UsageCleanupTaskUpdateOne) SetStatus(v string) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetStatus(v)
+ return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableStatus(v *string) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetStatus(*v)
+ }
+ return _u
+}
+
+// SetFilters sets the "filters" field.
+func (_u *UsageCleanupTaskUpdateOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetFilters(v)
+ return _u
+}
+
+// AppendFilters appends value to the "filters" field.
+func (_u *UsageCleanupTaskUpdateOne) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne {
+ _u.mutation.AppendFilters(v)
+ return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *UsageCleanupTaskUpdateOne) SetCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.ResetCreatedBy()
+ _u.mutation.SetCreatedBy(v)
+ return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetCreatedBy(*v)
+ }
+ return _u
+}
+
+// AddCreatedBy adds value to the "created_by" field.
+func (_u *UsageCleanupTaskUpdateOne) AddCreatedBy(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.AddCreatedBy(v)
+ return _u
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (_u *UsageCleanupTaskUpdateOne) SetDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.ResetDeletedRows()
+ _u.mutation.SetDeletedRows(v)
+ return _u
+}
+
+// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetDeletedRows(*v)
+ }
+ return _u
+}
+
+// AddDeletedRows adds value to the "deleted_rows" field.
+func (_u *UsageCleanupTaskUpdateOne) AddDeletedRows(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.AddDeletedRows(v)
+ return _u
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (_u *UsageCleanupTaskUpdateOne) SetErrorMessage(v string) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetErrorMessage(v)
+ return _u
+}
+
+// SetNillableErrorMessage sets the "error_message" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetErrorMessage(*v)
+ }
+ return _u
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (_u *UsageCleanupTaskUpdateOne) ClearErrorMessage() *UsageCleanupTaskUpdateOne {
+ _u.mutation.ClearErrorMessage()
+ return _u
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdateOne) SetCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.ResetCanceledBy()
+ _u.mutation.SetCanceledBy(v)
+ return _u
+}
+
+// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetCanceledBy(*v)
+ }
+ return _u
+}
+
+// AddCanceledBy adds value to the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdateOne) AddCanceledBy(v int64) *UsageCleanupTaskUpdateOne {
+ _u.mutation.AddCanceledBy(v)
+ return _u
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (_u *UsageCleanupTaskUpdateOne) ClearCanceledBy() *UsageCleanupTaskUpdateOne {
+ _u.mutation.ClearCanceledBy()
+ return _u
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (_u *UsageCleanupTaskUpdateOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetCanceledAt(v)
+ return _u
+}
+
+// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetCanceledAt(*v)
+ }
+ return _u
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (_u *UsageCleanupTaskUpdateOne) ClearCanceledAt() *UsageCleanupTaskUpdateOne {
+ _u.mutation.ClearCanceledAt()
+ return _u
+}
+
+// SetStartedAt sets the "started_at" field.
+func (_u *UsageCleanupTaskUpdateOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetStartedAt(v)
+ return _u
+}
+
+// SetNillableStartedAt sets the "started_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetStartedAt(*v)
+ }
+ return _u
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (_u *UsageCleanupTaskUpdateOne) ClearStartedAt() *UsageCleanupTaskUpdateOne {
+ _u.mutation.ClearStartedAt()
+ return _u
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (_u *UsageCleanupTaskUpdateOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdateOne {
+ _u.mutation.SetFinishedAt(v)
+ return _u
+}
+
+// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdateOne) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdateOne {
+ if v != nil {
+ _u.SetFinishedAt(*v)
+ }
+ return _u
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (_u *UsageCleanupTaskUpdateOne) ClearFinishedAt() *UsageCleanupTaskUpdateOne {
+ _u.mutation.ClearFinishedAt()
+ return _u
+}
+
+// Mutation returns the UsageCleanupTaskMutation object of the builder.
+func (_u *UsageCleanupTaskUpdateOne) Mutation() *UsageCleanupTaskMutation {
+ return _u.mutation
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskUpdateOne builder.
+func (_u *UsageCleanupTaskUpdateOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdateOne {
+ _u.mutation.Where(ps...)
+ return _u
+}
+
+// Select allows selecting one or more fields (columns) of the returned entity.
+// The default is selecting all fields defined in the entity schema.
+func (_u *UsageCleanupTaskUpdateOne) Select(field string, fields ...string) *UsageCleanupTaskUpdateOne {
+ _u.fields = append([]string{field}, fields...)
+ return _u
+}
+
+// Save executes the query and returns the updated UsageCleanupTask entity.
+func (_u *UsageCleanupTaskUpdateOne) Save(ctx context.Context) (*UsageCleanupTask, error) {
+ _u.defaults()
+ return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks)
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_u *UsageCleanupTaskUpdateOne) SaveX(ctx context.Context) *UsageCleanupTask {
+ node, err := _u.Save(ctx)
+ if err != nil {
+ panic(err)
+ }
+ return node
+}
+
+// Exec executes the query on the entity.
+func (_u *UsageCleanupTaskUpdateOne) Exec(ctx context.Context) error {
+ _, err := _u.Save(ctx)
+ return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_u *UsageCleanupTaskUpdateOne) ExecX(ctx context.Context) {
+ if err := _u.Exec(ctx); err != nil {
+ panic(err)
+ }
+}
+
+// defaults sets the default values of the builder before save.
+func (_u *UsageCleanupTaskUpdateOne) defaults() {
+ if _, ok := _u.mutation.UpdatedAt(); !ok {
+ v := usagecleanuptask.UpdateDefaultUpdatedAt()
+ _u.mutation.SetUpdatedAt(v)
+ }
+}
+
+// check runs all checks and user-defined validators on the builder.
+func (_u *UsageCleanupTaskUpdateOne) check() error {
+ if v, ok := _u.mutation.Status(); ok {
+ if err := usagecleanuptask.StatusValidator(v); err != nil {
+ return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)}
+ }
+ }
+ return nil
+}
+
+func (_u *UsageCleanupTaskUpdateOne) sqlSave(ctx context.Context) (_node *UsageCleanupTask, err error) {
+ if err := _u.check(); err != nil {
+ return _node, err
+ }
+ _spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+ id, ok := _u.mutation.ID()
+ if !ok {
+ return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageCleanupTask.id" for update`)}
+ }
+ _spec.Node.ID.Value = id
+ if fields := _u.fields; len(fields) > 0 {
+ _spec.Node.Columns = make([]string, 0, len(fields))
+ _spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID)
+ for _, f := range fields {
+ if !usagecleanuptask.ValidColumn(f) {
+ return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)}
+ }
+ if f != usagecleanuptask.FieldID {
+ _spec.Node.Columns = append(_spec.Node.Columns, f)
+ }
+ }
+ }
+ if ps := _u.mutation.predicates; len(ps) > 0 {
+ _spec.Predicate = func(selector *sql.Selector) {
+ for i := range ps {
+ ps[i](selector)
+ }
+ }
+ }
+ if value, ok := _u.mutation.UpdatedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value)
+ }
+ if value, ok := _u.mutation.Status(); ok {
+ _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value)
+ }
+ if value, ok := _u.mutation.Filters(); ok {
+ _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value)
+ }
+ if value, ok := _u.mutation.AppendedFilters(); ok {
+ _spec.AddModifier(func(u *sql.UpdateBuilder) {
+ sqljson.Append(u, usagecleanuptask.FieldFilters, value)
+ })
+ }
+ if value, ok := _u.mutation.CreatedBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCreatedBy(); ok {
+ _spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.DeletedRows(); ok {
+ _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedDeletedRows(); ok {
+ _spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.ErrorMessage(); ok {
+ _spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
+ }
+ if _u.mutation.ErrorMessageCleared() {
+ _spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString)
+ }
+ if value, ok := _u.mutation.CanceledBy(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+ }
+ if value, ok := _u.mutation.AddedCanceledBy(); ok {
+ _spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+ }
+ if _u.mutation.CanceledByCleared() {
+ _spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64)
+ }
+ if value, ok := _u.mutation.CanceledAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
+ }
+ if _u.mutation.CanceledAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.StartedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
+ }
+ if _u.mutation.StartedAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime)
+ }
+ if value, ok := _u.mutation.FinishedAt(); ok {
+ _spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
+ }
+ if _u.mutation.FinishedAtCleared() {
+ _spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime)
+ }
+ _node = &UsageCleanupTask{config: _u.config}
+ _spec.Assign = _node.assignValues
+ _spec.ScanValues = _node.scanValues
+ if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil {
+ if _, ok := err.(*sqlgraph.NotFoundError); ok {
+ err = &NotFoundError{usagecleanuptask.Label}
+ } else if sqlgraph.IsConstraintError(err) {
+ err = &ConstraintError{msg: err.Error(), wrap: err}
+ }
+ return nil, err
+ }
+ _u.mutation.done = true
+ return _node, nil
+}
diff --git a/backend/go.mod b/backend/go.mod
index 9ebae69e..fd429b07 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -98,6 +98,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/morikuni/aec v1.0.0 // indirect
+ github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/olekukonko/tablewriter v0.0.5 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.1 // indirect
@@ -108,6 +109,7 @@ require (
github.com/quic-go/qpack v0.6.0 // indirect
github.com/quic-go/quic-go v0.57.1 // indirect
github.com/refraction-networking/utls v1.8.1 // indirect
+ github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
github.com/sagikazarmark/locafero v0.4.0 // indirect
@@ -140,7 +142,7 @@ require (
go.uber.org/automaxprocs v1.6.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/arch v0.3.0 // indirect
- golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
+ golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/mod v0.30.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
@@ -149,4 +151,8 @@ require (
google.golang.org/grpc v1.75.1 // indirect
google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
+ modernc.org/libc v1.67.6 // indirect
+ modernc.org/mathutil v1.7.1 // indirect
+ modernc.org/memory v1.11.0 // indirect
+ modernc.org/sqlite v1.44.1 // indirect
)
diff --git a/backend/go.sum b/backend/go.sum
index 4496603d..aa10718c 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -200,6 +200,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A=
github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc=
+github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
+github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@@ -225,6 +227,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo=
github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
+github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
@@ -339,6 +343,8 @@ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g=
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
+golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
+golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
@@ -366,6 +372,7 @@ golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
+golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -388,4 +395,12 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
+modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
+modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
+modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go
index d8684c39..ed1c7cc2 100644
--- a/backend/internal/handler/admin/usage_cleanup_handler_test.go
+++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go
@@ -3,8 +3,8 @@ package admin
import (
"bytes"
"context"
- "encoding/json"
"database/sql"
+ "encoding/json"
"errors"
"net/http"
"net/http/httptest"
diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go
index b6dfa42a..9c021357 100644
--- a/backend/internal/repository/usage_cleanup_repo.go
+++ b/backend/internal/repository/usage_cleanup_repo.go
@@ -7,43 +7,41 @@ import (
"errors"
"fmt"
"strings"
+ "time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
)
type usageCleanupRepository struct {
- sql sqlExecutor
+ client *dbent.Client
+ sql sqlExecutor
}
-func NewUsageCleanupRepository(sqlDB *sql.DB) service.UsageCleanupRepository {
- return &usageCleanupRepository{sql: sqlDB}
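+// NewUsageCleanupRepository prefers the ent client when one is provided and
+// falls back to the raw-SQL path otherwise; tests that only hold a *sql.DB
+// pass a nil client. A minimal wiring sketch (names hypothetical):
+//
+//	repo := repository.NewUsageCleanupRepository(entClient, sqlDB)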
+func NewUsageCleanupRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageCleanupRepository {
+ return newUsageCleanupRepositoryWithSQL(client, sqlDB)
+}
+
+func newUsageCleanupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageCleanupRepository {
+ return &usageCleanupRepository{client: client, sql: sqlq}
}
func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error {
if task == nil {
return nil
}
- filtersJSON, err := json.Marshal(task.Filters)
- if err != nil {
- return fmt.Errorf("marshal cleanup filters: %w", err)
+ if r.client != nil {
+ return r.createTaskWithEnt(ctx, task)
}
- query := `
- INSERT INTO usage_cleanup_tasks (
- status,
- filters,
- created_by,
- deleted_rows
- ) VALUES ($1, $2, $3, $4)
- RETURNING id, created_at, updated_at
- `
- if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
- return err
- }
- return nil
+ return r.createTaskWithSQL(ctx, task)
}
func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
+ if r.client != nil {
+ return r.listTasksWithEnt(ctx, params)
+ }
var total int64
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil {
return nil, nil, err
@@ -57,14 +55,14 @@ func (r *usageCleanupRepository) ListTasks(ctx context.Context, params paginatio
canceled_by, canceled_at,
started_at, finished_at, created_at, updated_at
FROM usage_cleanup_tasks
- ORDER BY created_at DESC
+ ORDER BY created_at DESC, id DESC
LIMIT $1 OFFSET $2
`
rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset())
if err != nil {
return nil, nil, err
}
- defer rows.Close()
+ defer func() { _ = rows.Close() }()
tasks := make([]service.UsageCleanupTask, 0)
for rows.Next() {
@@ -194,6 +192,9 @@ func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, stale
}
func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
+ if r.client != nil {
+ return r.getTaskStatusWithEnt(ctx, taskID)
+ }
var status string
if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil {
return "", err
@@ -202,6 +203,9 @@ func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64
}
func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error {
+ if r.client != nil {
+ return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows)
+ }
query := `
UPDATE usage_cleanup_tasks
SET deleted_rows = $1,
@@ -213,6 +217,9 @@ func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID
}
func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
+ if r.client != nil {
+ return r.cancelTaskWithEnt(ctx, taskID, canceledBy)
+ }
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
@@ -243,6 +250,9 @@ func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, c
}
func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error {
+ if r.client != nil {
+ return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows)
+ }
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
@@ -256,6 +266,9 @@ func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID i
}
func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
+ if r.client != nil {
+ return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg)
+ }
query := `
UPDATE usage_cleanup_tasks
SET status = $1,
@@ -295,7 +308,7 @@ func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filte
if err != nil {
return 0, err
}
- defer rows.Close()
+ defer func() { _ = rows.Close() }()
var deleted int64
for rows.Next() {
@@ -357,7 +370,182 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any)
if filters.BillingType != nil {
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx))
args = append(args, *filters.BillingType)
- idx++
}
return strings.Join(conditions, " AND "), args
}
+
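+// createTaskWithEnt persists the task through the ent client and copies the
+// generated ID and timestamps back onto the service model.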
+func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error {
+ client := clientFromContext(ctx, r.client)
+ filtersJSON, err := json.Marshal(task.Filters)
+ if err != nil {
+ return fmt.Errorf("marshal cleanup filters: %w", err)
+ }
+ created, err := client.UsageCleanupTask.
+ Create().
+ SetStatus(task.Status).
+ SetFilters(json.RawMessage(filtersJSON)).
+ SetCreatedBy(task.CreatedBy).
+ SetDeletedRows(task.DeletedRows).
+ Save(ctx)
+ if err != nil {
+ return err
+ }
+ task.ID = created.ID
+ task.CreatedAt = created.CreatedAt
+ task.UpdatedAt = created.UpdatedAt
+ return nil
+}
+
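+// createTaskWithSQL is the legacy raw-SQL path, kept for repositories built
+// without an ent client (e.g. the sqlmock-based tests).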
+func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error {
+ filtersJSON, err := json.Marshal(task.Filters)
+ if err != nil {
+ return fmt.Errorf("marshal cleanup filters: %w", err)
+ }
+ query := `
+ INSERT INTO usage_cleanup_tasks (
+ status,
+ filters,
+ created_by,
+ deleted_rows
+ ) VALUES ($1, $2, $3, $4)
+ RETURNING id, created_at, updated_at
+ `
+ if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil {
+ return err
+ }
+ return nil
+}
+
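+// listTasksWithEnt pages tasks in a deterministic order: created_at DESC with
+// id DESC as a tiebreaker, mirroring the raw-SQL query above.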
+func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) {
+ client := clientFromContext(ctx, r.client)
+ query := client.UsageCleanupTask.Query()
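+	// Count on a clone so the base query stays reusable for the paged fetch below.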
+ total, err := query.Clone().Count(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+ if total == 0 {
+ return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil
+ }
+ rows, err := query.
+ Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)).
+ Offset(params.Offset()).
+ Limit(params.Limit()).
+ All(ctx)
+ if err != nil {
+ return nil, nil, err
+ }
+ tasks := make([]service.UsageCleanupTask, 0, len(rows))
+ for _, row := range rows {
+ task, err := usageCleanupTaskFromEnt(row)
+ if err != nil {
+ return nil, nil, err
+ }
+ tasks = append(tasks, task)
+ }
+ return tasks, paginationResultFromTotal(int64(total), params), nil
+}
+
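+// getTaskStatusWithEnt maps ent's not-found error to sql.ErrNoRows so callers
+// keep the same contract as the raw-SQL path.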
+func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) {
+ client := clientFromContext(ctx, r.client)
+ task, err := client.UsageCleanupTask.Query().
+ Where(dbusagecleanuptask.IDEQ(taskID)).
+ Only(ctx)
+ if err != nil {
+ if dbent.IsNotFound(err) {
+ return "", sql.ErrNoRows
+ }
+ return "", err
+ }
+ return task.Status, nil
+}
+
+func (r *usageCleanupRepository) updateTaskProgressWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
+ client := clientFromContext(ctx, r.client)
+ now := time.Now()
+ _, err := client.UsageCleanupTask.Update().
+ Where(dbusagecleanuptask.IDEQ(taskID)).
+ SetDeletedRows(deletedRows).
+ SetUpdatedAt(now).
+ Save(ctx)
+ return err
+}
+
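+// cancelTaskWithEnt transitions only pending or running tasks and reports false
+// when nothing matched, which the service surfaces as a cancel conflict.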
+func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) {
+ client := clientFromContext(ctx, r.client)
+ now := time.Now()
+ affected, err := client.UsageCleanupTask.Update().
+ Where(
+ dbusagecleanuptask.IDEQ(taskID),
+ dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning),
+ ).
+ SetStatus(service.UsageCleanupStatusCanceled).
+ SetCanceledBy(canceledBy).
+ SetCanceledAt(now).
+ SetFinishedAt(now).
+ ClearErrorMessage().
+ SetUpdatedAt(now).
+ Save(ctx)
+ if err != nil {
+ return false, err
+ }
+ return affected > 0, nil
+}
+
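+// markTaskSucceededWithEnt records the final status, row count, and finish time
+// in a single update.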
+func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error {
+ client := clientFromContext(ctx, r.client)
+ now := time.Now()
+ _, err := client.UsageCleanupTask.Update().
+ Where(dbusagecleanuptask.IDEQ(taskID)).
+ SetStatus(service.UsageCleanupStatusSucceeded).
+ SetDeletedRows(deletedRows).
+ SetFinishedAt(now).
+ SetUpdatedAt(now).
+ Save(ctx)
+ return err
+}
+
+func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error {
+ client := clientFromContext(ctx, r.client)
+ now := time.Now()
+ _, err := client.UsageCleanupTask.Update().
+ Where(dbusagecleanuptask.IDEQ(taskID)).
+ SetStatus(service.UsageCleanupStatusFailed).
+ SetDeletedRows(deletedRows).
+ SetErrorMessage(errorMsg).
+ SetFinishedAt(now).
+ SetUpdatedAt(now).
+ Save(ctx)
+ return err
+}
+
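+// usageCleanupTaskFromEnt converts an ent row into the service model; malformed
+// filters JSON is surfaced as an error rather than silently dropped.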
+func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) {
+ task := service.UsageCleanupTask{
+ ID: row.ID,
+ Status: row.Status,
+ CreatedBy: row.CreatedBy,
+ DeletedRows: row.DeletedRows,
+ CreatedAt: row.CreatedAt,
+ UpdatedAt: row.UpdatedAt,
+ }
+ if len(row.Filters) > 0 {
+ if err := json.Unmarshal(row.Filters, &task.Filters); err != nil {
+ return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err)
+ }
+ }
+ if row.ErrorMessage != nil {
+ task.ErrorMsg = row.ErrorMessage
+ }
+ if row.CanceledBy != nil {
+ task.CanceledBy = row.CanceledBy
+ }
+ if row.CanceledAt != nil {
+ task.CanceledAt = row.CanceledAt
+ }
+ if row.StartedAt != nil {
+ task.StartedAt = row.StartedAt
+ }
+ if row.FinishedAt != nil {
+ task.FinishedAt = row.FinishedAt
+ }
+ return task, nil
+}
diff --git a/backend/internal/repository/usage_cleanup_repo_ent_test.go b/backend/internal/repository/usage_cleanup_repo_ent_test.go
new file mode 100644
index 00000000..6c20b2b9
--- /dev/null
+++ b/backend/internal/repository/usage_cleanup_repo_ent_test.go
@@ -0,0 +1,251 @@
+package repository
+
+import (
+ "context"
+ "database/sql"
+ "encoding/json"
+ "testing"
+ "time"
+
+ dbent "github.com/Wei-Shaw/sub2api/ent"
+ "github.com/Wei-Shaw/sub2api/ent/enttest"
+ dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/stretchr/testify/require"
+
+ "entgo.io/ent/dialect"
+ entsql "entgo.io/ent/dialect/sql"
+ _ "modernc.org/sqlite"
+)
+
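+// newUsageCleanupEntRepo backs the repository with a shared-cache in-memory
+// SQLite database (pure-Go modernc driver) so the ent paths run without Postgres.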
+func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) {
+ t.Helper()
+ db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared")
+ require.NoError(t, err)
+ t.Cleanup(func() { _ = db.Close() })
+ _, err = db.Exec("PRAGMA foreign_keys = ON")
+ require.NoError(t, err)
+
+ drv := entsql.OpenDB(dialect.SQLite, db)
+ client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
+ t.Cleanup(func() { _ = client.Close() })
+
+ repo := &usageCleanupRepository{client: client, sql: db}
+ return repo, client
+}
+
+func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) {
+ repo, _ := newUsageCleanupEntRepo(t)
+
+ start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end},
+ CreatedBy: 9,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task))
+ require.NotZero(t, task.ID)
+
+ task2 := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusRunning,
+ Filters: service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)},
+ CreatedBy: 10,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task2))
+
+ tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
+ require.NoError(t, err)
+ require.Len(t, tasks, 2)
+ require.Equal(t, int64(2), result.Total)
+ require.Greater(t, tasks[0].ID, tasks[1].ID)
+ require.Equal(t, start, tasks[1].Filters.StartTime)
+ require.Equal(t, end, tasks[1].Filters.EndTime)
+}
+
+func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) {
+ repo, _ := newUsageCleanupEntRepo(t)
+
+ tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
+ require.NoError(t, err)
+ require.Empty(t, tasks)
+ require.Equal(t, int64(0), result.Total)
+}
+
+func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) {
+ repo, client := newUsageCleanupEntRepo(t)
+
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 3,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task))
+
+ status, err := repo.GetTaskStatus(context.Background(), task.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.UsageCleanupStatusPending, status)
+
+ _, err = repo.GetTaskStatus(context.Background(), task.ID+99)
+ require.ErrorIs(t, err, sql.ErrNoRows)
+
+ require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42))
+ loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
+ require.NoError(t, err)
+ require.Equal(t, int64(42), loaded.DeletedRows)
+}
+
+func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) {
+ repo, client := newUsageCleanupEntRepo(t)
+
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 5,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task))
+
+ ok, err := repo.CancelTask(context.Background(), task.ID, 7)
+ require.NoError(t, err)
+ require.True(t, ok)
+
+ loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status)
+ require.NotNil(t, loaded.CanceledBy)
+ require.NotNil(t, loaded.CanceledAt)
+ require.NotNil(t, loaded.FinishedAt)
+
+ loaded.Status = service.UsageCleanupStatusSucceeded
+ _, err = client.UsageCleanupTask.Update().Where(dbusagecleanuptask.IDEQ(task.ID)).SetStatus(loaded.Status).Save(context.Background())
+ require.NoError(t, err)
+
+ ok, err = repo.CancelTask(context.Background(), task.ID, 7)
+ require.NoError(t, err)
+ require.False(t, ok)
+}
+
+func TestUsageCleanupRepositoryEntCancelError(t *testing.T) {
+ repo, client := newUsageCleanupEntRepo(t)
+
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusPending,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 5,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task))
+
+ require.NoError(t, client.Close())
+ _, err := repo.CancelTask(context.Background(), task.ID, 7)
+ require.Error(t, err)
+}
+
+func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) {
+ repo, client := newUsageCleanupEntRepo(t)
+
+ task := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusRunning,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 12,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task))
+
+ require.NoError(t, repo.MarkTaskSucceeded(context.Background(), task.ID, 6))
+ loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status)
+ require.Equal(t, int64(6), loaded.DeletedRows)
+ require.NotNil(t, loaded.FinishedAt)
+
+ task2 := &service.UsageCleanupTask{
+ Status: service.UsageCleanupStatusRunning,
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 12,
+ }
+ require.NoError(t, repo.CreateTask(context.Background(), task2))
+
+ require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom"))
+ loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID)
+ require.NoError(t, err)
+ require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status)
+ require.Equal(t, "boom", *loaded2.ErrorMessage)
+}
+
+func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) {
+ repo, _ := newUsageCleanupEntRepo(t)
+
+ task := &service.UsageCleanupTask{
+ Status: "invalid",
+ Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)},
+ CreatedBy: 1,
+ }
+ require.Error(t, repo.CreateTask(context.Background(), task))
+}
+
+func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) {
+ repo, client := newUsageCleanupEntRepo(t)
+
+ now := time.Now().UTC()
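+	// Insert malformed filters JSON through the raw driver so the list path has
+	// to surface the decode error.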
+ driver, ok := client.Driver().(*entsql.Driver)
+ require.True(t, ok)
+ _, err := driver.DB().ExecContext(
+ context.Background(),
+ `INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at)
+ VALUES (?, ?, ?, ?, ?, ?)`,
+ service.UsageCleanupStatusPending,
+ []byte("invalid-json"),
+ int64(1),
+ int64(0),
+ now,
+ now,
+ )
+ require.NoError(t, err)
+
+ _, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10})
+ require.Error(t, err)
+}
+
+func TestUsageCleanupTaskFromEntFull(t *testing.T) {
+ start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)
+ end := start.Add(24 * time.Hour)
+ errMsg := "failed"
+ canceledBy := int64(2)
+ canceledAt := start.Add(time.Minute)
+ startedAt := start.Add(2 * time.Minute)
+ finishedAt := start.Add(3 * time.Minute)
+ filters := service.UsageCleanupFilters{StartTime: start, EndTime: end}
+ filtersJSON, err := json.Marshal(filters)
+ require.NoError(t, err)
+
+ task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
+ ID: 10,
+ Status: service.UsageCleanupStatusFailed,
+ Filters: filtersJSON,
+ CreatedBy: 11,
+ DeletedRows: 7,
+ ErrorMessage: &errMsg,
+ CanceledBy: &canceledBy,
+ CanceledAt: &canceledAt,
+ StartedAt: &startedAt,
+ FinishedAt: &finishedAt,
+ CreatedAt: start,
+ UpdatedAt: end,
+ })
+ require.NoError(t, err)
+ require.Equal(t, int64(10), task.ID)
+ require.Equal(t, service.UsageCleanupStatusFailed, task.Status)
+ require.NotNil(t, task.ErrorMsg)
+ require.NotNil(t, task.CanceledBy)
+ require.NotNil(t, task.CanceledAt)
+ require.NotNil(t, task.StartedAt)
+ require.NotNil(t, task.FinishedAt)
+}
+
+func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) {
+ task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{
+ Filters: json.RawMessage("invalid-json"),
+ })
+ require.Error(t, err)
+ require.Empty(t, task)
+}
diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go
index e5582709..0ca30ec7 100644
--- a/backend/internal/repository/usage_cleanup_repo_test.go
+++ b/backend/internal/repository/usage_cleanup_repo_test.go
@@ -23,7 +23,7 @@ func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
func TestNewUsageCleanupRepository(t *testing.T) {
db, _ := newSQLMock(t)
- repo := NewUsageCleanupRepository(db)
+ repo := NewUsageCleanupRepository(nil, db)
require.NotNil(t, repo)
}
@@ -146,6 +146,21 @@ func TestUsageCleanupRepositoryListTasks(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestUsageCleanupRepositoryListTasksQueryError(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks").
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(2)))
+ mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message").
+ WithArgs(20, 0).
+ WillReturnError(sql.ErrConnDone)
+
+ _, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20})
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
@@ -320,6 +335,19 @@ func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestUsageCleanupRepositoryGetTaskStatusQueryError(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks").
+ WithArgs(int64(9)).
+ WillReturnError(sql.ErrConnDone)
+
+ _, err := repo.GetTaskStatus(context.Background(), 9)
+ require.Error(t, err)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) {
db, mock := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
@@ -347,6 +375,20 @@ func TestUsageCleanupRepositoryCancelTask(t *testing.T) {
require.NoError(t, mock.ExpectationsWereMet())
}
+func TestUsageCleanupRepositoryCancelTaskNoRows(t *testing.T) {
+ db, mock := newSQLMock(t)
+ repo := &usageCleanupRepository{sql: db}
+
+ mock.ExpectQuery("UPDATE usage_cleanup_tasks").
+ WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning).
+ WillReturnRows(sqlmock.NewRows([]string{"id"}))
+
+ ok, err := repo.CancelTask(context.Background(), 6, 9)
+ require.NoError(t, err)
+ require.False(t, ok)
+ require.NoError(t, mock.ExpectationsWereMet())
+}
+
func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) {
db, _ := newSQLMock(t)
repo := &usageCleanupRepository{sql: db}
diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go
index 8f7e8144..10c68868 100644
--- a/backend/internal/service/dashboard_aggregation_service.go
+++ b/backend/internal/service/dashboard_aggregation_service.go
@@ -20,7 +20,7 @@ var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
- ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
+ ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
)
diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go
index 8ca02cfc..37f6d375 100644
--- a/backend/internal/service/usage_cleanup_service.go
+++ b/backend/internal/service/usage_cleanup_service.go
@@ -151,20 +151,24 @@ func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageClean
}
func (s *UsageCleanupService) runOnce() {
- if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
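+	// Guard the nil receiver up front: the old code dereferenced s in the
+	// CompareAndSwap before ever reaching its nil check.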
+ svc := s
+ if svc == nil {
+ return
+ }
+ if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) {
log.Printf("[UsageCleanup] run_once skipped: already_running=true")
return
}
- defer atomic.StoreInt32(&s.running, 0)
+ defer atomic.StoreInt32(&svc.running, 0)
parent := context.Background()
- if s != nil && s.workerCtx != nil {
- parent = s.workerCtx
+ if svc.workerCtx != nil {
+ parent = svc.workerCtx
}
- ctx, cancel := context.WithTimeout(parent, s.taskTimeout())
+ ctx, cancel := context.WithTimeout(parent, svc.taskTimeout())
defer cancel()
- task, err := s.repo.ClaimNextPendingTask(ctx, int64(s.taskTimeout().Seconds()))
+ task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds()))
if err != nil {
log.Printf("[UsageCleanup] claim pending task failed: %v", err)
return
@@ -175,7 +179,7 @@ func (s *UsageCleanupService) runOnce() {
}
log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters))
- s.executeTask(ctx, task)
+ svc.executeTask(ctx, task)
}
func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) {
diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go
index 37d3eb19..05c423bc 100644
--- a/backend/internal/service/usage_cleanup_service_test.go
+++ b/backend/internal/service/usage_cleanup_service_test.go
@@ -46,8 +46,45 @@ type cleanupRepoStub struct {
markSucceeded []cleanupMarkCall
markFailed []cleanupMarkCall
statusByID map[int64]string
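+	// The error and result knobs added below let individual tests force
+	// specific failure paths in the stub.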
+ statusErr error
progressCalls []cleanupMarkCall
+ updateErr error
cancelCalls []int64
+ cancelErr error
+ cancelResult *bool
+ markFailedErr error
+}
+
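+// dashboardRepoStub is a minimal stub of the dashboard aggregation repository;
+// only RecomputeRange's error is configurable.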
+type dashboardRepoStub struct {
+ recomputeErr error
+}
+
+func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
+ return nil
+}
+
+func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
+ return s.recomputeErr
+}
+
+func (s *dashboardRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
+ return time.Time{}, nil
+}
+
+func (s *dashboardRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
+ return nil
+}
+
+func (s *dashboardRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
+ return nil
+}
+
+func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
+ return nil
+}
+
+func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
+ return nil
}
func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error {
@@ -100,6 +137,9 @@ func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunning
func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) {
s.mu.Lock()
defer s.mu.Unlock()
+ if s.statusErr != nil {
+ return "", s.statusErr
+ }
if s.statusByID == nil {
return "", sql.ErrNoRows
}
@@ -114,6 +154,9 @@ func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64,
s.mu.Lock()
defer s.mu.Unlock()
s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows})
+ if s.updateErr != nil {
+ return s.updateErr
+ }
return nil
}
@@ -121,6 +164,19 @@ func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceled
s.mu.Lock()
defer s.mu.Unlock()
s.cancelCalls = append(s.cancelCalls, taskID)
+ if s.cancelErr != nil {
+ return false, s.cancelErr
+ }
+ if s.cancelResult != nil {
+ ok := *s.cancelResult
+ if ok {
+ if s.statusByID == nil {
+ s.statusByID = map[int64]string{}
+ }
+ s.statusByID[taskID] = UsageCleanupStatusCanceled
+ }
+ return ok, nil
+ }
if s.statusByID == nil {
s.statusByID = map[int64]string{}
}
@@ -151,6 +207,9 @@ func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, dele
s.statusByID = map[int64]string{}
}
s.statusByID[taskID] = UsageCleanupStatusFailed
+ if s.markFailedErr != nil {
+ return s.markFailedErr
+ }
return nil
}
@@ -266,9 +325,11 @@ func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) {
}
func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
+ start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
+ end := start.Add(2 * time.Hour)
repo := &cleanupRepoStub{
claimQueue: []*UsageCleanupTask{
- {ID: 5, Filters: UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(2 * time.Hour)}},
+ {ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}},
},
deleteQueue: []cleanupDeleteResponse{
{deleted: 2},
@@ -288,6 +349,9 @@ func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) {
require.Empty(t, repo.markFailed)
require.Equal(t, int64(5), repo.markSucceeded[0].taskID)
require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows)
+ require.Equal(t, 2, repo.deleteCalls[0].limit)
+ require.Equal(t, start, repo.deleteCalls[0].filters.StartTime)
+ require.Equal(t, end, repo.deleteCalls[0].filters.EndTime)
}
func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) {
@@ -336,6 +400,293 @@ func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) {
require.Equal(t, 500, len(repo.markFailed[0].errMsg))
}
+func TestUsageCleanupServiceExecuteTaskProgressError(t *testing.T) {
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {deleted: 2},
+ {deleted: 0},
+ },
+ updateErr: errors.New("update failed"),
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 8,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.markSucceeded, 1)
+ require.Empty(t, repo.markFailed)
+ require.Len(t, repo.progressCalls, 1)
+}
+
+func TestUsageCleanupServiceExecuteTaskDeleteCanceled(t *testing.T) {
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {err: context.Canceled},
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 12,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Empty(t, repo.markSucceeded)
+ require.Empty(t, repo.markFailed)
+}
+
+func TestUsageCleanupServiceExecuteTaskContextCanceled(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 9,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ svc.executeTask(ctx, task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Empty(t, repo.markSucceeded)
+ require.Empty(t, repo.markFailed)
+ require.Empty(t, repo.deleteCalls)
+}
+
+func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) {
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {err: errors.New("boom")},
+ },
+ markFailedErr: errors.New("update failed"),
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 13,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.markFailed, 1)
+ require.Equal(t, int64(13), repo.markFailed[0].taskID)
+}
+
+func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) {
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {deleted: 0},
+ },
+ }
+ dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
+ DashboardAgg: config.DashboardAggregationConfig{Enabled: false},
+ })
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
+ task := &UsageCleanupTask{
+ ID: 14,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.markSucceeded, 1)
+}
+
+func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{
+ deleteQueue: []cleanupDeleteResponse{
+ {deleted: 0},
+ },
+ }
+ dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{
+ DashboardAgg: config.DashboardAggregationConfig{Enabled: true},
+ })
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, dashboard, cfg)
+ task := &UsageCleanupTask{
+ ID: 15,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Len(t, repo.markSucceeded, 1)
+}
+
+func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) {
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 3: UsageCleanupStatusCanceled,
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+ task := &UsageCleanupTask{
+ ID: 3,
+ Filters: UsageCleanupFilters{
+ StartTime: time.Now().UTC(),
+ EndTime: time.Now().UTC().Add(time.Hour),
+ },
+ }
+
+ svc.executeTask(context.Background(), task)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Empty(t, repo.deleteCalls)
+ require.Empty(t, repo.markSucceeded)
+ require.Empty(t, repo.markFailed)
+}
+
+func TestUsageCleanupServiceCancelTaskSuccess(t *testing.T) {
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 5: UsageCleanupStatusPending,
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 5, 9)
+ require.NoError(t, err)
+
+ repo.mu.Lock()
+ defer repo.mu.Unlock()
+ require.Equal(t, UsageCleanupStatusCanceled, repo.statusByID[5])
+ require.Len(t, repo.cancelCalls, 1)
+}
+
+func TestUsageCleanupServiceCancelTaskDisabled(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 1, 2)
+ require.Error(t, err)
+ require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err))
+ require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCancelTaskNotFound(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 999, 1)
+ require.Error(t, err)
+ require.Equal(t, http.StatusNotFound, infraerrors.Code(err))
+ require.Equal(t, "USAGE_CLEANUP_TASK_NOT_FOUND", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCancelTaskStatusError(t *testing.T) {
+ repo := &cleanupRepoStub{statusErr: errors.New("status broken")}
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 7, 1)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status broken")
+}
+
+func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) {
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 7: UsageCleanupStatusSucceeded,
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 7, 1)
+ require.Error(t, err)
+ require.Equal(t, http.StatusConflict, infraerrors.Code(err))
+ require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) {
+ shouldCancel := false
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 7: UsageCleanupStatusPending,
+ },
+ cancelResult: &shouldCancel,
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 7, 1)
+ require.Error(t, err)
+ require.Equal(t, http.StatusConflict, infraerrors.Code(err))
+ require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err))
+}
+
+func TestUsageCleanupServiceCancelTaskRepoError(t *testing.T) {
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 7: UsageCleanupStatusPending,
+ },
+ cancelErr: errors.New("cancel failed"),
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 7, 1)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "cancel failed")
+}
+
+func TestUsageCleanupServiceCancelTaskInvalidCanceller(t *testing.T) {
+ repo := &cleanupRepoStub{
+ statusByID: map[int64]string{
+ 7: UsageCleanupStatusRunning,
+ },
+ }
+ cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}
+ svc := NewUsageCleanupService(repo, nil, nil, cfg)
+
+ err := svc.CancelTask(context.Background(), 7, 0)
+ require.Error(t, err)
+ require.Equal(t, "USAGE_CLEANUP_INVALID_CANCELLER", infraerrors.Reason(err))
+}
+
func TestUsageCleanupServiceListTasks(t *testing.T) {
repo := &cleanupRepoStub{
listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}},
@@ -418,3 +769,47 @@ func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) {
require.Nil(t, filters.GroupID)
require.Nil(t, filters.Model)
}
+
+func TestDescribeUsageCleanupFiltersAllFields(t *testing.T) {
+ start := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC)
+ end := start.Add(2 * time.Hour)
+ userID := int64(1)
+ apiKeyID := int64(2)
+ accountID := int64(3)
+ groupID := int64(4)
+ model := " gpt-4 "
+ stream := true
+ billingType := int8(2)
+ filters := UsageCleanupFilters{
+ StartTime: start,
+ EndTime: end,
+ UserID: &userID,
+ APIKeyID: &apiKeyID,
+ AccountID: &accountID,
+ GroupID: &groupID,
+ Model: &model,
+ Stream: &stream,
+ BillingType: &billingType,
+ }
+
+ desc := describeUsageCleanupFilters(filters)
+ require.Equal(t, "start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2", desc)
+}
+
+func TestUsageCleanupServiceIsTaskCanceledNotFound(t *testing.T) {
+ repo := &cleanupRepoStub{}
+ svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
+
+ canceled, err := svc.isTaskCanceled(context.Background(), 9)
+ require.NoError(t, err)
+ require.False(t, canceled)
+}
+
+func TestUsageCleanupServiceIsTaskCanceledError(t *testing.T) {
+ repo := &cleanupRepoStub{statusErr: errors.New("status err")}
+ svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}})
+
+ _, err := svc.isTaskCanceled(context.Background(), 9)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "status err")
+}
diff --git a/frontend/src/components/admin/usage/UsageCleanupDialog.vue b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
index 4cd562e8..91a43ecd 100644
--- a/frontend/src/components/admin/usage/UsageCleanupDialog.vue
+++ b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
@@ -219,7 +219,7 @@ const loadTasks = async () => {
if (!props.show) return
tasksLoading.value = true
try {
- const res = await adminUsageAPI.listCleanupTasks({ page: 1, page_size: 10 })
+ const res = await adminUsageAPI.listCleanupTasks({ page: 1, page_size: 5 })
tasks.value = res.items || []
} catch (error) {
console.error('Failed to load cleanup tasks:', error)
From 771baa66ee34812691b8a28047e702113aeada42 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Sun, 18 Jan 2026 14:31:22 +0800
Subject: [PATCH 024/155] feat(UI): improve pagination jump navigation and page-size display
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
The pagination component can now hide its per-page size selector and gains a jump-to-page input
The cleanup task list enables jump-to-page and pins the page size at 5
Add the Chinese and English pagination copy
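
A rough sketch of the Pagination.vue prop/emit surface this implies (the
names hidePageSize / showJump and the emit signatures are illustrative
assumptions, not the component's actual API, since the template hunks
below were lost in extraction):

    // <script setup lang="ts"> excerpt; Vue compiler macros need no imports
    const props = withDefaults(defineProps<{
      page: number
      pageSize: number
      total: number
      hidePageSize?: boolean // new: hide the per-page selector
      showJump?: boolean     // new: render the jump-to-page input
    }>(), { hidePageSize: false, showJump: false })

    const emit = defineEmits<{
      (e: 'page-change', page: number): void
      (e: 'page-size-change', size: number): void
    }>()

    // Clamp whatever was typed into the jump field to [1, lastPage]
    const jumpTo = (raw: string) => {
      const n = Math.trunc(Number(raw))
      if (!Number.isFinite(n) || n < 1) return
      const last = Math.max(1, Math.ceil(props.total / props.pageSize))
      emit('page-change', Math.min(n, last))
    }

The dialog consumes these events through handleTaskPageChange and
handleTaskPageSizeChange, shown in the hunks below.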
---
.../admin/usage/UsageCleanupDialog.vue | 43 ++++++++++++++++++-
frontend/src/components/common/Pagination.vue | 38 ++++++++++++++--
frontend/src/i18n/locales/en.ts | 5 ++-
frontend/src/i18n/locales/zh.ts | 5 ++-
4 files changed, 85 insertions(+), 6 deletions(-)
diff --git a/frontend/src/components/admin/usage/UsageCleanupDialog.vue b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
index 91a43ecd..d5e81e72 100644
--- a/frontend/src/components/admin/usage/UsageCleanupDialog.vue
+++ b/frontend/src/components/admin/usage/UsageCleanupDialog.vue
@@ -66,6 +66,19 @@
+ [13 added template lines lost in extraction: most likely the <Pagination> block mounted under the cleanup task list, wired to the handlers added below]
@@ -108,6 +121,7 @@ import { useI18n } from 'vue-i18n'
import { useAppStore } from '@/stores/app'
import BaseDialog from '@/components/common/BaseDialog.vue'
import ConfirmDialog from '@/components/common/ConfirmDialog.vue'
+import Pagination from '@/components/common/Pagination.vue'
import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import { adminUsageAPI } from '@/api/admin/usage'
import type { AdminUsageQueryParams, UsageCleanupTask, CreateUsageCleanupTaskRequest } from '@/api/admin/usage'
@@ -131,6 +145,9 @@ const localEndDate = ref('')
const tasks = ref<UsageCleanupTask[]>([])
const tasksLoading = ref(false)
+const tasksPage = ref(1)
+const tasksPageSize = ref(5)
+const tasksTotal = ref(0)
const submitting = ref(false)
const confirmVisible = ref(false)
const cancelConfirmVisible = ref(false)
@@ -146,6 +163,8 @@ const resetFilters = () => {
localEndDate.value = props.endDate
localFilters.value.start_date = localStartDate.value
localFilters.value.end_date = localEndDate.value
+ tasksPage.value = 1
+ tasksTotal.value = 0
}
const startPolling = () => {
@@ -219,8 +238,18 @@ const loadTasks = async () => {
if (!props.show) return
tasksLoading.value = true
try {
- const res = await adminUsageAPI.listCleanupTasks({ page: 1, page_size: 5 })
+ const res = await adminUsageAPI.listCleanupTasks({
+ page: tasksPage.value,
+ page_size: tasksPageSize.value
+ })
tasks.value = res.items || []
+ tasksTotal.value = res.total || 0
+ if (res.page) {
+ tasksPage.value = res.page
+ }
+ if (res.page_size) {
+ tasksPageSize.value = res.page_size
+ }
} catch (error) {
console.error('Failed to load cleanup tasks:', error)
appStore.showError(t('admin.usage.cleanup.loadFailed'))
@@ -229,6 +258,18 @@ const loadTasks = async () => {
}
}
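+// Jump to the requested page and reload the task list.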
+const handleTaskPageChange = (page: number) => {
+ tasksPage.value = page
+ loadTasks()
+}
+
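+// Validate the incoming size, reset to the first page, and reload.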
+const handleTaskPageSizeChange = (size: number) => {
+ if (!Number.isFinite(size) || size <= 0) return
+ tasksPageSize.value = size
+ tasksPage.value = 1
+ loadTasks()
+}
+
const openConfirm = () => {
confirmVisible.value = true
}
diff --git a/frontend/src/components/common/Pagination.vue b/frontend/src/components/common/Pagination.vue
index 728bc0d3..3365a186 100644
--- a/frontend/src/components/common/Pagination.vue
+++ b/frontend/src/components/common/Pagination.vue
@@ -37,7 +37,7 @@
+ [one modified template line lost in extraction: likely gates the per-page selector ("{{ t('pagination.perPage') }}:") behind the new hide toggle]
@@ -49,6 +49,22 @@
+ [16 added template lines lost in extraction: the jump-to-page input, labelled "{{ t('pagination.jumpTo') }}"]
@@ -102,7 +118,7 @@
+ [one modified script line lost in extraction]
+ [remainder of this patch, including the en.ts / zh.ts locale hunks, lost in extraction]

[The headers of a later patch were also stripped; all that survives is the body copy of a password-reset HTML email template, markup removed (English glosses in parentheses):]

密码重置请求 (Password reset request)
您已请求重置密码。请点击下方按钮设置新密码: (You have requested a password reset. Click the button below to set a new password:)
重置密码 (Reset password)
此链接将在 30 分钟后失效。 (This link expires in 30 minutes.)
如果您没有请求重置密码,请忽略此邮件。您的密码将保持不变。 (If you did not request a password reset, please ignore this email; your password will remain unchanged.)
如果按钮无法点击,请复制以下链接到浏览器中打开: (If the button cannot be clicked, copy the link below into your browser:)
%s