From 7a06c4873edf63c90387e00fded8adf352899b95 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Fri, 9 Jan 2026 18:35:58 +0800 Subject: [PATCH 01/23] Fix Codex OAuth tool mapping --- backend/internal/service/codex_prompts.go | 9 + .../service/openai_codex_transform.go | 1085 +++++++++++++++++ .../service/openai_gateway_service.go | 267 ++-- .../service/openai_gateway_service_test.go | 2 +- .../service/prompts/codex_opencode_bridge.txt | 122 ++ .../service/prompts/tool_remap_message.txt | 63 + 6 files changed, 1434 insertions(+), 114 deletions(-) create mode 100644 backend/internal/service/codex_prompts.go create mode 100644 backend/internal/service/openai_codex_transform.go create mode 100644 backend/internal/service/prompts/codex_opencode_bridge.txt create mode 100644 backend/internal/service/prompts/tool_remap_message.txt diff --git a/backend/internal/service/codex_prompts.go b/backend/internal/service/codex_prompts.go new file mode 100644 index 00000000..6f83eac2 --- /dev/null +++ b/backend/internal/service/codex_prompts.go @@ -0,0 +1,9 @@ +package service + +import _ "embed" + +//go:embed prompts/codex_opencode_bridge.txt +var codexOpenCodeBridge string + +//go:embed prompts/tool_remap_message.txt +var codexToolRemapMessage string diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go new file mode 100644 index 00000000..a52c88b5 --- /dev/null +++ b/backend/internal/service/openai_codex_transform.go @@ -0,0 +1,1085 @@ +package service + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "strings" + "time" + "unicode" +) + +const ( + codexReleaseAPIURL = "https://api.github.com/repos/openai/codex/releases/latest" + codexReleaseHTMLURL = "https://github.com/openai/codex/releases/latest" + codexPromptURLFmt = "https://raw.githubusercontent.com/openai/codex/%s/codex-rs/core/%s" + opencodeCodexURL = 
"https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex.txt" + codexCacheTTL = 15 * time.Minute +) + +type codexModelFamily string + +const ( + codexFamilyGpt52Codex codexModelFamily = "gpt-5.2-codex" + codexFamilyCodexMax codexModelFamily = "codex-max" + codexFamilyCodex codexModelFamily = "codex" + codexFamilyGpt52 codexModelFamily = "gpt-5.2" + codexFamilyGpt51 codexModelFamily = "gpt-5.1" +) + +var codexPromptFiles = map[codexModelFamily]string{ + codexFamilyGpt52Codex: "gpt-5.2-codex_prompt.md", + codexFamilyCodexMax: "gpt-5.1-codex-max_prompt.md", + codexFamilyCodex: "gpt_5_codex_prompt.md", + codexFamilyGpt52: "gpt_5_2_prompt.md", + codexFamilyGpt51: "gpt_5_1_prompt.md", +} + +var codexCacheFiles = map[codexModelFamily]string{ + codexFamilyGpt52Codex: "gpt-5.2-codex-instructions.md", + codexFamilyCodexMax: "codex-max-instructions.md", + codexFamilyCodex: "codex-instructions.md", + codexFamilyGpt52: "gpt-5.2-instructions.md", + codexFamilyGpt51: "gpt-5.1-instructions.md", +} + +var codexModelMap = map[string]string{ + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-low": "gpt-5.1-codex", + "gpt-5.1-codex-medium": "gpt-5.1-codex", + "gpt-5.1-codex-high": "gpt-5.1-codex", + "gpt-5.1-codex-max": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", + "gpt-5.2": "gpt-5.2", + "gpt-5.2-none": "gpt-5.2", + "gpt-5.2-low": "gpt-5.2", + "gpt-5.2-medium": "gpt-5.2", + "gpt-5.2-high": "gpt-5.2", + "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5.2-codex": "gpt-5.2-codex", + "gpt-5.2-codex-low": "gpt-5.2-codex", + "gpt-5.2-codex-medium": "gpt-5.2-codex", + "gpt-5.2-codex-high": "gpt-5.2-codex", + "gpt-5.2-codex-xhigh": "gpt-5.2-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-high": 
"gpt-5.1-codex-mini", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-none": "gpt-5.1", + "gpt-5.1-low": "gpt-5.1", + "gpt-5.1-medium": "gpt-5.1", + "gpt-5.1-high": "gpt-5.1", + "gpt-5.1-chat-latest": "gpt-5.1", + "gpt-5-codex": "gpt-5.1-codex", + "codex-mini-latest": "gpt-5.1-codex-mini", + "gpt-5-codex-mini": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5": "gpt-5.1", + "gpt-5-mini": "gpt-5.1", + "gpt-5-nano": "gpt-5.1", +} + +var opencodePromptSignatures = []string{ + "you are a coding agent running in the opencode", + "you are opencode, an agent", + "you are opencode, an interactive cli agent", + "you are opencode, an interactive cli tool", + "you are opencode, the best coding agent on the planet", +} + +var opencodeContextMarkers = []string{ + "here is some useful information about the environment you are running in:", + "", + "instructions from:", + "", +} + +type codexTransformResult struct { + Modified bool + NormalizedModel string + PromptCacheKey string +} + +type codexCacheMetadata struct { + ETag string `json:"etag"` + Tag string `json:"tag"` + LastChecked int64 `json:"lastChecked"` + URL string `json:"url"` +} + +type opencodeCacheMetadata struct { + ETag string `json:"etag"` + LastFetch string `json:"lastFetch,omitempty"` + LastChecked int64 `json:"lastChecked"` +} + +func codexModeEnabled() bool { + value := strings.TrimSpace(os.Getenv("CODEX_MODE")) + if value == "" { + return true + } + switch strings.ToLower(value) { + case "0", "false", "no", "off": + return false + case "1", "true", "yes", "on": + return true + default: + return true + } +} + +func applyCodexOAuthTransform(reqBody map[string]any, codexMode bool) codexTransformResult { + result := codexTransformResult{} + + model := "" + if v, ok := reqBody["model"].(string); ok { + model = v + } + normalizedModel := normalizeCodexModel(model) + if normalizedModel != "" { + if model != normalizedModel { + reqBody["model"] = 
normalizedModel + result.Modified = true + } + result.NormalizedModel = normalizedModel + } + + reqBody["store"] = false + reqBody["stream"] = true + result.Modified = true + + instructions := getCodexInstructions(normalizedModel) + if instructions != "" { + if existing, ok := reqBody["instructions"].(string); !ok || existing != instructions { + reqBody["instructions"] = instructions + result.Modified = true + } + } + + if _, ok := reqBody["max_output_tokens"]; ok { + delete(reqBody, "max_output_tokens") + result.Modified = true + } + if _, ok := reqBody["max_completion_tokens"]; ok { + delete(reqBody, "max_completion_tokens") + result.Modified = true + } + + if normalizeCodexTools(reqBody) { + result.Modified = true + } + + if v, ok := reqBody["prompt_cache_key"].(string); ok { + result.PromptCacheKey = strings.TrimSpace(v) + } + + if input, ok := reqBody["input"].([]any); ok { + input = filterCodexInput(input) + if codexMode { + cachedPrompt := getOpenCodeCodexPrompt() + input = filterOpenCodeSystemPromptsWithCachedPrompt(input, cachedPrompt) + if hasTools(reqBody) { + input = addCodexBridgeMessage(input) + } + } else if hasTools(reqBody) { + input = addToolRemapMessage(input) + } + input = normalizeOrphanedToolOutputs(input) + reqBody["input"] = input + result.Modified = true + } + + effort, summary := resolveCodexReasoning(reqBody, normalizedModel) + if effort != "" || summary != "" { + reasoning := ensureMap(reqBody["reasoning"]) + if effort != "" { + reasoning["effort"] = effort + } + if summary != "" { + reasoning["summary"] = summary + } + reqBody["reasoning"] = reasoning + result.Modified = true + } + + textVerbosity := resolveTextVerbosity(reqBody) + if textVerbosity != "" { + text := ensureMap(reqBody["text"]) + text["verbosity"] = textVerbosity + reqBody["text"] = text + result.Modified = true + } + + include := resolveInclude(reqBody) + if include != nil { + reqBody["include"] = include + result.Modified = true + } + + return result +} + +func 
normalizeCodexModel(model string) string { + if model == "" { + return "gpt-5.1" + } + + modelID := model + if strings.Contains(modelID, "/") { + parts := strings.Split(modelID, "/") + modelID = parts[len(parts)-1] + } + + if mapped := getNormalizedCodexModel(modelID); mapped != "" { + return mapped + } + + normalized := strings.ToLower(modelID) + + if strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") { + return "gpt-5.2-codex" + } + if strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2") { + return "gpt-5.2" + } + if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { + return "gpt-5.1-codex-max" + } + if strings.Contains(normalized, "gpt-5.1-codex-mini") || strings.Contains(normalized, "gpt 5.1 codex mini") { + return "gpt-5.1-codex-mini" + } + if strings.Contains(normalized, "codex-mini-latest") || + strings.Contains(normalized, "gpt-5-codex-mini") || + strings.Contains(normalized, "gpt 5 codex mini") { + return "codex-mini-latest" + } + if strings.Contains(normalized, "gpt-5.1-codex") || strings.Contains(normalized, "gpt 5.1 codex") { + return "gpt-5.1-codex" + } + if strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1") { + return "gpt-5.1" + } + if strings.Contains(normalized, "codex") { + return "gpt-5.1-codex" + } + if strings.Contains(normalized, "gpt-5") || strings.Contains(normalized, "gpt 5") { + return "gpt-5.1" + } + + return "gpt-5.1" +} + +func getNormalizedCodexModel(modelID string) string { + if modelID == "" { + return "" + } + if mapped, ok := codexModelMap[modelID]; ok { + return mapped + } + lower := strings.ToLower(modelID) + for key, value := range codexModelMap { + if strings.ToLower(key) == lower { + return value + } + } + return "" +} + +func getCodexModelFamily(normalizedModel string) codexModelFamily { + model := strings.ToLower(normalizedModel) + if strings.Contains(model, 
"gpt-5.2-codex") || strings.Contains(model, "gpt 5.2 codex") { + return codexFamilyGpt52Codex + } + if strings.Contains(model, "codex-max") { + return codexFamilyCodexMax + } + if strings.Contains(model, "codex") || strings.HasPrefix(model, "codex-") { + return codexFamilyCodex + } + if strings.Contains(model, "gpt-5.2") { + return codexFamilyGpt52 + } + return codexFamilyGpt51 +} + +func getCodexInstructions(normalizedModel string) string { + if normalizedModel == "" { + normalizedModel = "gpt-5.1-codex" + } + + modelFamily := getCodexModelFamily(normalizedModel) + promptFile := codexPromptFiles[modelFamily] + cacheFile := codexCachePath(codexCacheFiles[modelFamily]) + metaFile := codexCachePath(strings.TrimSuffix(codexCacheFiles[modelFamily], ".md") + "-meta.json") + + var meta codexCacheMetadata + if loadJSON(metaFile, &meta) && meta.LastChecked > 0 { + if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { + if cached, ok := readFile(cacheFile); ok { + return cached + } + } + } + + latestTag, err := getLatestCodexReleaseTag() + if err != nil { + if cached, ok := readFile(cacheFile); ok { + return cached + } + return "" + } + + if meta.Tag != latestTag { + meta.ETag = "" + } + + promptURL := fmt.Sprintf(codexPromptURLFmt, latestTag, promptFile) + content, etag, status, err := fetchWithETag(promptURL, meta.ETag) + if err == nil && status == http.StatusNotModified { + if cached, ok := readFile(cacheFile); ok { + return cached + } + } + if err == nil && status >= 200 && status < 300 { + if content != "" { + if err := writeFile(cacheFile, content); err == nil { + meta = codexCacheMetadata{ + ETag: etag, + Tag: latestTag, + LastChecked: time.Now().UnixMilli(), + URL: promptURL, + } + _ = writeJSON(metaFile, meta) + } + return content + } + } + + if cached, ok := readFile(cacheFile); ok { + return cached + } + + return "" +} + +func getLatestCodexReleaseTag() (string, error) { + body, _, status, err := fetchWithETag(codexReleaseAPIURL, "") + if err == nil 
&& status >= 200 && status < 300 && body != "" { + var data struct { + TagName string `json:"tag_name"` + } + if json.Unmarshal([]byte(body), &data) == nil && data.TagName != "" { + return data.TagName, nil + } + } + + resp, err := http.Get(codexReleaseHTMLURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + + finalURL := "" + if resp.Request != nil && resp.Request.URL != nil { + finalURL = resp.Request.URL.String() + } + if finalURL != "" { + if tag := parseReleaseTagFromURL(finalURL); tag != "" { + return tag, nil + } + } + + html, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + return parseReleaseTagFromHTML(string(html)) +} + +func parseReleaseTagFromURL(url string) string { + parts := strings.Split(url, "/tag/") + if len(parts) < 2 { + return "" + } + tag := parts[len(parts)-1] + if tag == "" || strings.Contains(tag, "/") { + return "" + } + return tag +} + +func parseReleaseTagFromHTML(html string) (string, error) { + const marker = "/openai/codex/releases/tag/" + idx := strings.Index(html, marker) + if idx == -1 { + return "", fmt.Errorf("release tag not found") + } + rest := html[idx+len(marker):] + for i, r := range rest { + if r == '"' || r == '\'' { + return rest[:i], nil + } + } + return "", fmt.Errorf("release tag not found") +} + +func getOpenCodeCodexPrompt() string { + cacheDir := codexCachePath("") + if cacheDir == "" { + return "" + } + cacheFile := filepath.Join(cacheDir, "opencode-codex.txt") + metaFile := filepath.Join(cacheDir, "opencode-codex-meta.json") + + var cachedContent string + if content, ok := readFile(cacheFile); ok { + cachedContent = content + } + + var meta opencodeCacheMetadata + if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { + if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { + return cachedContent + } + } + + content, etag, status, err := fetchWithETag(opencodeCodexURL, meta.ETag) + if err == nil && status == http.StatusNotModified && 
cachedContent != "" { + return cachedContent + } + if err == nil && status >= 200 && status < 300 && content != "" { + _ = writeFile(cacheFile, content) + meta = opencodeCacheMetadata{ + ETag: etag, + LastFetch: time.Now().UTC().Format(time.RFC3339), + LastChecked: time.Now().UnixMilli(), + } + _ = writeJSON(metaFile, meta) + return content + } + + return cachedContent +} + +func filterCodexInput(input []any) []any { + filtered := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + filtered = append(filtered, item) + continue + } + if typ, ok := m["type"].(string); ok && typ == "item_reference" { + continue + } + if _, ok := m["id"]; ok { + delete(m, "id") + } + filtered = append(filtered, m) + } + return filtered +} + +func filterOpenCodeSystemPromptsWithCachedPrompt(input []any, cachedPrompt string) []any { + if len(input) == 0 { + return input + } + cachedPrompt = strings.TrimSpace(cachedPrompt) + + result := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + result = append(result, item) + continue + } + role, _ := m["role"].(string) + if role == "user" { + result = append(result, item) + continue + } + if !isOpenCodeSystemPrompt(m, cachedPrompt) { + result = append(result, item) + continue + } + contentText := getContentText(m) + if contentText == "" { + continue + } + if preserved := extractOpenCodeContext(contentText); preserved != "" { + result = append(result, replaceContentText(m, preserved)) + } + } + return result +} + +func isOpenCodeSystemPrompt(item map[string]any, cachedPrompt string) bool { + role, _ := item["role"].(string) + if role != "developer" && role != "system" { + return false + } + + contentText := getContentText(item) + if contentText == "" { + return false + } + + if cachedPrompt != "" { + contentTrimmed := strings.TrimSpace(contentText) + cachedTrimmed := strings.TrimSpace(cachedPrompt) + if contentTrimmed == cachedTrimmed { + 
return true + } + if strings.HasPrefix(contentTrimmed, cachedTrimmed) { + return true + } + contentPrefix := contentTrimmed + if len(contentPrefix) > 200 { + contentPrefix = contentPrefix[:200] + } + cachedPrefix := cachedTrimmed + if len(cachedPrefix) > 200 { + cachedPrefix = cachedPrefix[:200] + } + if contentPrefix == cachedPrefix { + return true + } + } + + normalized := strings.ToLower(strings.TrimLeftFunc(contentText, unicode.IsSpace)) + for _, signature := range opencodePromptSignatures { + if strings.HasPrefix(normalized, signature) { + return true + } + } + return false +} + +func getContentText(item map[string]any) string { + content := item["content"] + if content == nil { + return "" + } + switch v := content.(type) { + case string: + return v + case []any: + var parts []string + for _, part := range v { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + typ, _ := partMap["type"].(string) + if typ != "input_text" { + continue + } + if text, ok := partMap["text"].(string); ok && text != "" { + parts = append(parts, text) + } + } + return strings.Join(parts, "\n") + default: + return "" + } +} + +func replaceContentText(item map[string]any, contentText string) map[string]any { + content := item["content"] + switch content.(type) { + case string: + item["content"] = contentText + case []any: + item["content"] = []any{map[string]any{ + "type": "input_text", + "text": contentText, + }} + default: + item["content"] = contentText + } + return item +} + +func extractOpenCodeContext(contentText string) string { + lower := strings.ToLower(contentText) + earliest := -1 + for _, marker := range opencodeContextMarkers { + idx := strings.Index(lower, marker) + if idx >= 0 && (earliest == -1 || idx < earliest) { + earliest = idx + } + } + if earliest == -1 { + return "" + } + return strings.TrimLeftFunc(contentText[earliest:], unicode.IsSpace) +} + +func addCodexBridgeMessage(input []any) []any { + message := map[string]any{ + "type": "message", + 
"role": "developer", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": codexOpenCodeBridge, + }, + }, + } + return append([]any{message}, input...) +} + +func addToolRemapMessage(input []any) []any { + message := map[string]any{ + "type": "message", + "role": "developer", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": codexToolRemapMessage, + }, + }, + } + return append([]any{message}, input...) +} + +func hasTools(reqBody map[string]any) bool { + tools, ok := reqBody["tools"] + if !ok || tools == nil { + return false + } + if list, ok := tools.([]any); ok { + return len(list) > 0 + } + return true +} + +func normalizeCodexTools(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + + modified := false + for idx, tool := range tools { + toolMap, ok := tool.(map[string]any) + if !ok { + continue + } + + toolType, _ := toolMap["type"].(string) + if strings.TrimSpace(toolType) != "function" { + continue + } + + function, ok := toolMap["function"].(map[string]any) + if !ok { + continue + } + + if _, ok := toolMap["name"]; !ok { + if name, ok := function["name"].(string); ok && strings.TrimSpace(name) != "" { + toolMap["name"] = name + modified = true + } + } + if _, ok := toolMap["description"]; !ok { + if desc, ok := function["description"].(string); ok && strings.TrimSpace(desc) != "" { + toolMap["description"] = desc + modified = true + } + } + if _, ok := toolMap["parameters"]; !ok { + if params, ok := function["parameters"]; ok { + toolMap["parameters"] = params + modified = true + } + } + if _, ok := toolMap["strict"]; !ok { + if strict, ok := function["strict"]; ok { + toolMap["strict"] = strict + modified = true + } + } + + tools[idx] = toolMap + } + + if modified { + reqBody["tools"] = tools + } + + return modified +} + +func normalizeOrphanedToolOutputs(input []any) []any { + functionCallIDs 
:= map[string]bool{} + localShellCallIDs := map[string]bool{} + customToolCallIDs := map[string]bool{} + + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + continue + } + callID := getCallID(m) + if callID == "" { + continue + } + switch m["type"] { + case "function_call": + functionCallIDs[callID] = true + case "local_shell_call": + localShellCallIDs[callID] = true + case "custom_tool_call": + customToolCallIDs[callID] = true + } + } + + output := make([]any, 0, len(input)) + for _, item := range input { + m, ok := item.(map[string]any) + if !ok { + output = append(output, item) + continue + } + switch m["type"] { + case "function_call_output": + callID := getCallID(m) + if callID == "" || !(functionCallIDs[callID] || localShellCallIDs[callID]) { + output = append(output, convertOrphanedOutputToMessage(m, callID)) + continue + } + case "custom_tool_call_output": + callID := getCallID(m) + if callID == "" || !customToolCallIDs[callID] { + output = append(output, convertOrphanedOutputToMessage(m, callID)) + continue + } + case "local_shell_call_output": + callID := getCallID(m) + if callID == "" || !localShellCallIDs[callID] { + output = append(output, convertOrphanedOutputToMessage(m, callID)) + continue + } + } + output = append(output, m) + } + return output +} + +func getCallID(item map[string]any) string { + raw, ok := item["call_id"] + if !ok { + return "" + } + callID, ok := raw.(string) + if !ok { + return "" + } + callID = strings.TrimSpace(callID) + if callID == "" { + return "" + } + return callID +} + +func convertOrphanedOutputToMessage(item map[string]any, callID string) map[string]any { + toolName := "tool" + if name, ok := item["name"].(string); ok && name != "" { + toolName = name + } + labelID := callID + if labelID == "" { + labelID = "unknown" + } + text := stringifyOutput(item["output"]) + if len(text) > 16000 { + text = text[:16000] + "\n...[truncated]" + } + return map[string]any{ + "type": "message", + "role": 
"assistant", + "content": fmt.Sprintf("[Previous %s result; call_id=%s]: %s", toolName, labelID, text), + } +} + +func stringifyOutput(output any) string { + switch v := output.(type) { + case string: + return v + default: + if data, err := json.Marshal(v); err == nil { + return string(data) + } + return fmt.Sprintf("%v", v) + } +} + +func resolveCodexReasoning(reqBody map[string]any, modelName string) (string, string) { + existingEffort := getReasoningValue(reqBody, "effort", "reasoningEffort") + existingSummary := getReasoningValue(reqBody, "summary", "reasoningSummary") + return getReasoningConfig(modelName, existingEffort, existingSummary) +} + +func getReasoningValue(reqBody map[string]any, field, providerField string) string { + if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { + if value, ok := reasoning[field].(string); ok && value != "" { + return value + } + } + if provider := getProviderOpenAI(reqBody); provider != nil { + if value, ok := provider[providerField].(string); ok && value != "" { + return value + } + } + return "" +} + +func resolveTextVerbosity(reqBody map[string]any) string { + if text, ok := reqBody["text"].(map[string]any); ok { + if value, ok := text["verbosity"].(string); ok && value != "" { + return value + } + } + if provider := getProviderOpenAI(reqBody); provider != nil { + if value, ok := provider["textVerbosity"].(string); ok && value != "" { + return value + } + } + return "medium" +} + +func resolveInclude(reqBody map[string]any) []any { + include := toStringSlice(reqBody["include"]) + if len(include) == 0 { + if provider := getProviderOpenAI(reqBody); provider != nil { + include = toStringSlice(provider["include"]) + } + } + if len(include) == 0 { + include = []string{"reasoning.encrypted_content"} + } + + unique := make(map[string]struct{}, len(include)+1) + for _, value := range include { + if value == "" { + continue + } + unique[value] = struct{}{} + } + if _, ok := unique["reasoning.encrypted_content"]; !ok { 
+ include = append(include, "reasoning.encrypted_content") + unique["reasoning.encrypted_content"] = struct{}{} + } + + final := make([]any, 0, len(unique)) + for _, value := range include { + if value == "" { + continue + } + if _, ok := unique[value]; ok { + final = append(final, value) + delete(unique, value) + } + } + for value := range unique { + final = append(final, value) + } + return final +} + +func getReasoningConfig(modelName, effortOverride, summaryOverride string) (string, string) { + normalized := strings.ToLower(modelName) + + isGpt52Codex := strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") + isGpt52General := (strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2")) && !isGpt52Codex + isCodexMax := strings.Contains(normalized, "codex-max") || strings.Contains(normalized, "codex max") + isCodexMini := strings.Contains(normalized, "codex-mini") || + strings.Contains(normalized, "codex mini") || + strings.Contains(normalized, "codex_mini") || + strings.Contains(normalized, "codex-mini-latest") + isCodex := strings.Contains(normalized, "codex") && !isCodexMini + isLightweight := !isCodexMini && (strings.Contains(normalized, "nano") || strings.Contains(normalized, "mini")) + isGpt51General := (strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1")) && + !isCodex && !isCodexMax && !isCodexMini + + supportsXhigh := isGpt52General || isGpt52Codex || isCodexMax + supportsNone := isGpt52General || isGpt51General + + defaultEffort := "medium" + if isCodexMini { + defaultEffort = "medium" + } else if supportsXhigh { + defaultEffort = "high" + } else if isLightweight { + defaultEffort = "minimal" + } + + effort := effortOverride + if effort == "" { + effort = defaultEffort + } + + if isCodexMini { + if effort == "minimal" || effort == "low" || effort == "none" { + effort = "medium" + } + if effort == "xhigh" { + effort = "high" + } + if effort != "high" && 
effort != "medium" { + effort = "medium" + } + } + + if !supportsXhigh && effort == "xhigh" { + effort = "high" + } + if !supportsNone && effort == "none" { + effort = "low" + } + if effort == "minimal" { + effort = "low" + } + + summary := summaryOverride + if summary == "" { + summary = "auto" + } + + return effort, summary +} + +func getProviderOpenAI(reqBody map[string]any) map[string]any { + providerOptions, ok := reqBody["providerOptions"].(map[string]any) + if !ok || providerOptions == nil { + return nil + } + openaiOptions, ok := providerOptions["openai"].(map[string]any) + if !ok || openaiOptions == nil { + return nil + } + return openaiOptions +} + +func ensureMap(value any) map[string]any { + if value == nil { + return map[string]any{} + } + if m, ok := value.(map[string]any); ok { + return m + } + return map[string]any{} +} + +func toStringSlice(value any) []string { + if value == nil { + return nil + } + switch v := value.(type) { + case []string: + return append([]string{}, v...) 
+ case []any: + out := make([]string, 0, len(v)) + for _, item := range v { + if text, ok := item.(string); ok { + out = append(out, text) + } + } + return out + default: + return nil + } +} + +func codexCachePath(filename string) string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + cacheDir := filepath.Join(home, ".opencode", "cache") + if filename == "" { + return cacheDir + } + return filepath.Join(cacheDir, filename) +} + +func readFile(path string) (string, bool) { + if path == "" { + return "", false + } + data, err := os.ReadFile(path) + if err != nil { + return "", false + } + return string(data), true +} + +func writeFile(path, content string) error { + if path == "" { + return fmt.Errorf("empty cache path") + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + return os.WriteFile(path, []byte(content), 0o644) +} + +func loadJSON(path string, target any) bool { + data, err := os.ReadFile(path) + if err != nil { + return false + } + if err := json.Unmarshal(data, target); err != nil { + return false + } + return true +} + +func writeJSON(path string, value any) error { + if path == "" { + return fmt.Errorf("empty json path") + } + if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { + return err + } + data, err := json.Marshal(value) + if err != nil { + return err + } + return os.WriteFile(path, data, 0o644) +} + +func fetchWithETag(url, etag string) (string, string, int, error) { + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + return "", "", 0, err + } + req.Header.Set("User-Agent", "sub2api-codex") + if etag != "" { + req.Header.Set("If-None-Match", etag) + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", 0, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", resp.StatusCode, err + } + return string(body), resp.Header.Get("etag"), resp.StatusCode, nil +} diff --git 
a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 42e98585..8f59110d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -12,6 +12,7 @@ import ( "io" "log" "net/http" + "os" "regexp" "sort" "strconv" @@ -528,6 +529,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Extract model and stream from parsed body reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) + promptCacheKey := "" // Track if body needs re-serialization bodyModified := false @@ -540,19 +542,17 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // For OAuth accounts using ChatGPT internal API: - // 1. Add store: false - // 2. Normalize input format for Codex API compatibility if account.Type == AccountTypeOAuth { - reqBody["store"] = false - bodyModified = true - - // Normalize input format: convert AI SDK multi-part content format to simplified format - // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} - // Codex API expects: {"content": "..."} - if normalizeInputForCodexAPI(reqBody) { + codexResult := applyCodexOAuthTransform(reqBody, codexModeEnabled()) + if codexResult.Modified { bodyModified = true } + if codexResult.NormalizedModel != "" { + mappedModel = codexResult.NormalizedModel + } + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } } // Re-serialize body only if modified @@ -571,7 +571,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey) if err != nil { return nil, err } @@ -632,7 +632,7 @@ func (s *OpenAIGatewayService) Forward(ctx 
context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool) (*http.Request, error) { +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { @@ -672,12 +672,6 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if chatgptAccountID != "" { req.Header.Set("chatgpt-account-id", chatgptAccountID) } - // Set accept header based on stream mode - if isStream { - req.Header.Set("accept", "text/event-stream") - } else { - req.Header.Set("accept", "application/json") - } } // Whitelist passthrough headers @@ -689,6 +683,18 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } } } + if account.Type == AccountTypeOAuth { + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", "codex_cli_rs") + req.Header.Set("accept", "text/event-stream") + if promptCacheKey != "" { + req.Header.Set("conversation_id", promptCacheKey) + req.Header.Set("session_id", promptCacheKey) + } else { + req.Header.Del("conversation_id") + req.Header.Del("session_id") + } + } // Apply custom User-Agent if configured customUA := account.GetOpenAIUserAgent() @@ -706,6 +712,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. 
func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(resp.Body) + logUpstreamErrorBody(account.ID, resp.StatusCode, body) // Check custom error codes if !account.ShouldHandleErrorCode(resp.StatusCode) { @@ -764,6 +771,24 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) } +func logUpstreamErrorBody(accountID int64, statusCode int, body []byte) { + if strings.ToLower(strings.TrimSpace(os.Getenv("GATEWAY_LOG_UPSTREAM_ERROR_BODY"))) != "true" { + return + } + + maxBytes := 2048 + if rawMax := strings.TrimSpace(os.Getenv("GATEWAY_LOG_UPSTREAM_ERROR_BODY_MAX_BYTES")); rawMax != "" { + if parsed, err := strconv.Atoi(rawMax); err == nil && parsed > 0 { + maxBytes = parsed + } + } + if len(body) > maxBytes { + body = body[:maxBytes] + } + + log.Printf("Upstream error body: account=%d status=%d body=%q", accountID, statusCode, string(body)) +} + // openaiStreamingResult streaming response result type openaiStreamingResult struct { usage *OpenAIUsage @@ -1016,6 +1041,13 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r return nil, err } + if account.Type == AccountTypeOAuth { + bodyLooksLikeSSE := bytes.Contains(body, []byte("data:")) || bytes.Contains(body, []byte("event:")) + if isEventStreamResponse(resp.Header) || bodyLooksLikeSSE { + return s.handleOAuthSSEToJSON(resp, c, body, originalModel, mappedModel) + } + } + // Parse usage var response struct { Usage struct { @@ -1055,6 +1087,110 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r return usage, nil } +func isEventStreamResponse(header http.Header) bool { + contentType := strings.ToLower(header.Get("Content-Type")) + return strings.Contains(contentType, "text/event-stream") +} + +func (s *OpenAIGatewayService) 
handleOAuthSSEToJSON(resp *http.Response, c *gin.Context, body []byte, originalModel, mappedModel string) (*OpenAIUsage, error) { + bodyText := string(body) + finalResponse, ok := extractCodexFinalResponse(bodyText) + + usage := &OpenAIUsage{} + if ok { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(finalResponse, &response); err == nil { + usage.InputTokens = response.Usage.InputTokens + usage.OutputTokens = response.Usage.OutputTokens + usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + } + body = finalResponse + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + } else { + usage = s.parseSSEUsageFromBody(bodyText) + if originalModel != mappedModel { + bodyText = s.replaceModelInSSEBody(bodyText, mappedModel, originalModel) + } + body = []byte(bodyText) + } + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + + contentType := "application/json; charset=utf-8" + if !ok { + contentType = resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "text/event-stream" + } + } + c.Data(resp.StatusCode, contentType, body) + + return usage, nil +} + +func extractCodexFinalResponse(body string) ([]byte, bool) { + lines := strings.Split(body, "\n") + for _, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + data := openaiSSEDataRe.ReplaceAllString(line, "") + if data == "" || data == "[DONE]" { + continue + } + var event struct { + Type string `json:"type"` + Response json.RawMessage `json:"response"` + } + if json.Unmarshal([]byte(data), &event) != nil { + continue + } + if event.Type == "response.done" || event.Type == "response.completed" { + if 
len(event.Response) > 0 { + return event.Response, true + } + } + } + return nil, false +} + +func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { + usage := &OpenAIUsage{} + lines := strings.Split(body, "\n") + for _, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + data := openaiSSEDataRe.ReplaceAllString(line, "") + if data == "" || data == "[DONE]" { + continue + } + s.parseSSEUsage(data, usage) + } + return usage +} + +func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { + lines := strings.Split(body, "\n") + for i, line := range lines { + if !openaiSSEDataRe.MatchString(line) { + continue + } + lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) + } + return strings.Join(lines, "\n") +} + func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) @@ -1094,101 +1230,6 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } -// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format -// that the ChatGPT internal Codex API expects. -// -// AI SDK sends content as an array of typed objects: -// -// {"content": [{"type": "input_text", "text": "hello"}]} -// -// ChatGPT Codex API expects content as a simple string: -// -// {"content": "hello"} -// -// This function modifies reqBody in-place and returns true if any modification was made. 
-func normalizeInputForCodexAPI(reqBody map[string]any) bool { - input, ok := reqBody["input"] - if !ok { - return false - } - - // Handle case where input is a simple string (already compatible) - if _, isString := input.(string); isString { - return false - } - - // Handle case where input is an array of messages - inputArray, ok := input.([]any) - if !ok { - return false - } - - modified := false - for _, item := range inputArray { - message, ok := item.(map[string]any) - if !ok { - continue - } - - content, ok := message["content"] - if !ok { - continue - } - - // If content is already a string, no conversion needed - if _, isString := content.(string); isString { - continue - } - - // If content is an array (AI SDK format), convert to string - contentArray, ok := content.([]any) - if !ok { - continue - } - - // Extract text from content array - var textParts []string - for _, part := range contentArray { - partMap, ok := part.(map[string]any) - if !ok { - continue - } - - // Handle different content types - partType, _ := partMap["type"].(string) - switch partType { - case "input_text", "text": - // Extract text from input_text or text type - if text, ok := partMap["text"].(string); ok { - textParts = append(textParts, text) - } - case "input_image", "image": - // For images, we need to preserve the original format - // as ChatGPT Codex API may support images in a different way - // For now, skip image parts (they will be lost in conversion) - // TODO: Consider preserving image data or handling it separately - continue - case "input_file", "file": - // Similar to images, file inputs may need special handling - continue - default: - // For unknown types, try to extract text if available - if text, ok := partMap["text"].(string); ok { - textParts = append(textParts, text) - } - } - } - - // Convert content array to string - if len(textParts) > 0 { - message["content"] = strings.Join(textParts, "\n") - modified = true - } - } - - return modified -} - // 
OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 8562d940..c30fba7e 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { Credentials: map[string]any{"base_url": "://invalid-url"}, } - _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false) + _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "") if err == nil { t.Fatalf("expected error for invalid base_url when allowlist disabled") } diff --git a/backend/internal/service/prompts/codex_opencode_bridge.txt b/backend/internal/service/prompts/codex_opencode_bridge.txt new file mode 100644 index 00000000..093aa0f2 --- /dev/null +++ b/backend/internal/service/prompts/codex_opencode_bridge.txt @@ -0,0 +1,122 @@ +# Codex Running in OpenCode + +You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles. 
+ +## CRITICAL: Tool Replacements + + +❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD +- NEVER use: apply_patch, applyPatch +- ALWAYS use: edit tool for ALL file modifications +- Before modifying files: Verify you're using "edit", NOT "apply_patch" + + + +❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD +- NEVER use: update_plan, updatePlan, read_plan, readPlan +- ALWAYS use: todowrite for task/plan updates, todoread to read plans +- Before plan operations: Verify you're using "todowrite", NOT "update_plan" + + +## Available OpenCode Tools + +**File Operations:** +- `write` - Create new files + - Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode. +- `edit` - Modify existing files (REPLACES apply_patch) + - Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing. +- `read` - Read file contents + +**Search/Discovery:** +- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`. +- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set. +- `list` - List directories (requires absolute paths) + +**Execution:** +- `bash` - Run shell commands + - No workdir parameter; do not include it in tool calls. + - Always include a short description for the command. + - Do not use cd; use absolute paths in commands. + - Quote paths containing spaces with double quotes. + - Chain multiple commands with ';' or '&&'; avoid newlines. + - Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features. + - Do not use `ls`/`cat` in bash; use `list`/`read` tools instead. + - For deletions (rm), verify by listing parent dir with `list`. 
+ +**Network:** +- `webfetch` - Fetch web content + - Use fully-formed URLs (http/https; http auto-upgrades to https). + - Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required. + - Read-only; short cache window. + +**Task Management:** +- `todowrite` - Manage tasks/plans (REPLACES update_plan) +- `todoread` - Read current plan + +## Substitution Rules + +Base instruction says: You MUST use instead: +apply_patch → edit +update_plan → todowrite +read_plan → todoread + +**Path Usage:** Use per-tool conventions to avoid conflicts: +- Tool calls: `read`, `edit`, `write`, `list` require absolute paths. +- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed. +- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls. +- Tool schema overrides general path preferences—do not convert required absolute paths to relative. + +## Verification Checklist + +Before file/plan modifications: +1. Am I using "edit" NOT "apply_patch"? +2. Am I using "todowrite" NOT "update_plan"? +3. Is this tool in the approved list above? +4. Am I following each tool's path requirements? + +If ANY answer is NO → STOP and correct before proceeding. 
+
+## OpenCode Working Style
+
+**Communication:**
+- Send brief preambles (8-12 words) before tool calls, building on prior context
+- Provide progress updates during longer tasks
+
+**Execution:**
+- Keep working autonomously until query is fully resolved before yielding
+- Don't return to user with partial solutions
+
+**Code Approach:**
+- New projects: Be ambitious and creative
+- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
+
+**Testing:**
+- If tests exist: Start specific to your changes, then broader validation
+
+## Advanced Tools
+
+**Task Tool (Sub-Agents):**
+- Use the Task tool (functions.task) to launch sub-agents
+- Check the Task tool description for current agent types and their capabilities
+- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
+- The agent list is dynamically generated - refer to tool schema for available agents
+
+**Parallelization:**
+- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
+- Reserve sequential calls for ordered or data-dependent steps.
+
+**MCP Tools:**
+- Model Context Protocol servers provide additional capabilities
+- MCP tools are prefixed: `mcp_<server>_<tool>`
+- Check your available tools for MCP integrations
+- Use when the tool's functionality matches your task needs
+
+## What Remains from Codex
+
+Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
+
+## Approvals & Safety
+- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
+- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
+- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval. +- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`). \ No newline at end of file diff --git a/backend/internal/service/prompts/tool_remap_message.txt b/backend/internal/service/prompts/tool_remap_message.txt new file mode 100644 index 00000000..4ff986e1 --- /dev/null +++ b/backend/internal/service/prompts/tool_remap_message.txt @@ -0,0 +1,63 @@ + + +YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references. + + + + +❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD +- NEVER use: apply_patch, applyPatch +- ALWAYS use: edit tool for ALL file modifications +- Before modifying files: Verify you're using "edit", NOT "apply_patch" + + + +❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD +- NEVER use: update_plan, updatePlan +- ALWAYS use: todowrite for ALL task/plan operations +- Use todoread to read current plan +- Before plan operations: Verify you're using "todowrite", NOT "update_plan" + + + + +File Operations: + • write - Create new files + • edit - Modify existing files (REPLACES apply_patch) + • patch - Apply diff patches + • read - Read file contents + +Search/Discovery: + • grep - Search file contents + • glob - Find files by pattern + • list - List directories (use relative paths) + +Execution: + • bash - Run shell commands + +Network: + • webfetch - Fetch web content + +Task Management: + • todowrite - Manage tasks/plans (REPLACES update_plan) + • todoread - Read current plan + + + +Base instruction says: You MUST use instead: +apply_patch → edit +update_plan → todowrite +read_plan → todoread +absolute paths → relative paths + + + +Before file/plan modifications: +1. Am I using "edit" NOT "apply_patch"? +2. Am I using "todowrite" NOT "update_plan"? +3. Is this tool in the approved list above? +4. Am I using relative paths? 
+ +If ANY answer is NO → STOP and correct before proceeding. + + \ No newline at end of file From eb06006d6ca24767a6f0d111f7b98892361e60b0 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Sat, 10 Jan 2026 03:12:56 +0800 Subject: [PATCH 02/23] Make Codex CLI passthrough --- .../handler/openai_gateway_handler.go | 11 --- .../service/openai_codex_transform.go | 85 +++++++++---------- .../service/openai_gateway_service.go | 23 +++-- 3 files changed, 54 insertions(+), 65 deletions(-) diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 70131417..5400da3f 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -11,7 +11,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -92,17 +91,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - // For non-Codex CLI requests, set default instructions userAgent := c.GetHeader("User-Agent") - if !openai.IsCodexCLIRequest(userAgent) { - reqBody["instructions"] = openai.DefaultInstructions - // Re-serialize body - body, err = json.Marshal(reqBody) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } // Track if we've started streaming (for error handling) streamStarted := false diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a52c88b5..e6c71775 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -135,7 +135,7 @@ func codexModeEnabled() bool { } } -func applyCodexOAuthTransform(reqBody map[string]any, codexMode bool) codexTransformResult { +func applyCodexOAuthTransform(reqBody 
map[string]any) codexTransformResult { result := codexTransformResult{} model := "" @@ -151,16 +151,13 @@ func applyCodexOAuthTransform(reqBody map[string]any, codexMode bool) codexTrans result.NormalizedModel = normalizedModel } - reqBody["store"] = false - reqBody["stream"] = true - result.Modified = true - - instructions := getCodexInstructions(normalizedModel) - if instructions != "" { - if existing, ok := reqBody["instructions"].(string); !ok || existing != instructions { - reqBody["instructions"] = instructions - result.Modified = true - } + if v, ok := reqBody["store"].(bool); !ok || v { + reqBody["store"] = false + result.Modified = true + } + if v, ok := reqBody["stream"].(bool); !ok || !v { + reqBody["stream"] = true + result.Modified = true } if _, ok := reqBody["max_output_tokens"]; ok { @@ -180,49 +177,30 @@ func applyCodexOAuthTransform(reqBody map[string]any, codexMode bool) codexTrans result.PromptCacheKey = strings.TrimSpace(v) } + instructions := strings.TrimSpace(getCodexInstructions(normalizedModel)) + existingInstructions, _ := reqBody["instructions"].(string) + existingInstructions = strings.TrimSpace(existingInstructions) + + if instructions != "" { + if existingInstructions != "" && existingInstructions != instructions { + if input, ok := reqBody["input"].([]any); ok { + reqBody["input"] = prependSystemInstruction(input, existingInstructions) + result.Modified = true + } + } + if existingInstructions != instructions { + reqBody["instructions"] = instructions + result.Modified = true + } + } + if input, ok := reqBody["input"].([]any); ok { input = filterCodexInput(input) - if codexMode { - cachedPrompt := getOpenCodeCodexPrompt() - input = filterOpenCodeSystemPromptsWithCachedPrompt(input, cachedPrompt) - if hasTools(reqBody) { - input = addCodexBridgeMessage(input) - } - } else if hasTools(reqBody) { - input = addToolRemapMessage(input) - } input = normalizeOrphanedToolOutputs(input) reqBody["input"] = input result.Modified = true } - 
effort, summary := resolveCodexReasoning(reqBody, normalizedModel) - if effort != "" || summary != "" { - reasoning := ensureMap(reqBody["reasoning"]) - if effort != "" { - reasoning["effort"] = effort - } - if summary != "" { - reasoning["summary"] = summary - } - reqBody["reasoning"] = reasoning - result.Modified = true - } - - textVerbosity := resolveTextVerbosity(reqBody) - if textVerbosity != "" { - text := ensureMap(reqBody["text"]) - text["verbosity"] = textVerbosity - reqBody["text"] = text - result.Modified = true - } - - include := resolveInclude(reqBody) - if include != nil { - reqBody["include"] = include - result.Modified = true - } - return result } @@ -487,6 +465,19 @@ func filterCodexInput(input []any) []any { return filtered } +func prependSystemInstruction(input []any, instructions string) []any { + message := map[string]any{ + "role": "system", + "content": []any{ + map[string]any{ + "type": "input_text", + "text": instructions, + }, + }, + } + return append([]any{message}, input...) 
+} + func filterOpenCodeSystemPromptsWithCachedPrompt(input []any, cachedPrompt string) []any { if len(input) == 0 { return input diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 8f59110d..33244330 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -21,6 +21,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" @@ -530,20 +531,28 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqModel, _ := reqBody["model"].(string) reqStream, _ := reqBody["stream"].(bool) promptCacheKey := "" + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } // Track if body needs re-serialization bodyModified := false originalModel := reqModel - // Apply model mapping - mappedModel := account.GetMappedModel(reqModel) - if mappedModel != reqModel { - reqBody["model"] = mappedModel - bodyModified = true + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + + // Apply model mapping (skip for Codex CLI for transparent forwarding) + mappedModel := reqModel + if !isCodexCLI { + mappedModel = account.GetMappedModel(reqModel) + if mappedModel != reqModel { + reqBody["model"] = mappedModel + bodyModified = true + } } - if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, codexModeEnabled()) + if account.Type == AccountTypeOAuth && !isCodexCLI { + codexResult := applyCodexOAuthTransform(reqBody) if codexResult.Modified { bodyModified = true } From 36b817d008e7ad719a892424bc8b0529bb8b9d2c Mon Sep 17 00:00:00 2001 From: cyhhao Date: Sat, 10 Jan 2026 20:53:16 +0800 Subject: [PATCH 03/23] Align OAuth transform with OpenCode 
instructions --- .../service/openai_codex_transform.go | 29 ++++++++++++------- .../service/openai_gateway_service.go | 10 +++++-- .../service/openai_gateway_service_test.go | 2 +- 3 files changed, 27 insertions(+), 14 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index e6c71775..fc9d30cd 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -13,11 +13,12 @@ import ( ) const ( - codexReleaseAPIURL = "https://api.github.com/repos/openai/codex/releases/latest" - codexReleaseHTMLURL = "https://github.com/openai/codex/releases/latest" - codexPromptURLFmt = "https://raw.githubusercontent.com/openai/codex/%s/codex-rs/core/%s" - opencodeCodexURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex.txt" - codexCacheTTL = 15 * time.Minute + codexReleaseAPIURL = "https://api.github.com/repos/openai/codex/releases/latest" + codexReleaseHTMLURL = "https://github.com/openai/codex/releases/latest" + codexPromptURLFmt = "https://raw.githubusercontent.com/openai/codex/%s/codex-rs/core/%s" + opencodeCodexURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex.txt" + opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" + codexCacheTTL = 15 * time.Minute ) type codexModelFamily string @@ -177,7 +178,7 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result.PromptCacheKey = strings.TrimSpace(v) } - instructions := strings.TrimSpace(getCodexInstructions(normalizedModel)) + instructions := strings.TrimSpace(getOpenCodeCodexHeader()) existingInstructions, _ := reqBody["instructions"].(string) existingInstructions = strings.TrimSpace(existingInstructions) @@ -408,13 +409,13 @@ func parseReleaseTagFromHTML(html string) 
(string, error) { return "", fmt.Errorf("release tag not found") } -func getOpenCodeCodexPrompt() string { +func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { cacheDir := codexCachePath("") if cacheDir == "" { return "" } - cacheFile := filepath.Join(cacheDir, "opencode-codex.txt") - metaFile := filepath.Join(cacheDir, "opencode-codex-meta.json") + cacheFile := filepath.Join(cacheDir, cacheFileName) + metaFile := filepath.Join(cacheDir, metaFileName) var cachedContent string if content, ok := readFile(cacheFile); ok { @@ -428,7 +429,7 @@ func getOpenCodeCodexPrompt() string { } } - content, etag, status, err := fetchWithETag(opencodeCodexURL, meta.ETag) + content, etag, status, err := fetchWithETag(url, meta.ETag) if err == nil && status == http.StatusNotModified && cachedContent != "" { return cachedContent } @@ -446,6 +447,14 @@ func getOpenCodeCodexPrompt() string { return cachedContent } +func getOpenCodeCodexPrompt() string { + return getOpenCodeCachedPrompt(opencodeCodexURL, "opencode-codex.txt", "opencode-codex-meta.json") +} + +func getOpenCodeCodexHeader() string { + return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") +} + func filterCodexInput(input []any) []any { filtered := make([]any, 0, len(input)) for _, item := range input { diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 33244330..76aaa6cd 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -580,7 +580,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Build upstream request - upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey) + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { return nil, err } 
@@ -641,7 +641,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } -func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string) (*http.Request, error) { +func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string switch account.Type { @@ -694,7 +694,11 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. } if account.Type == AccountTypeOAuth { req.Header.Set("OpenAI-Beta", "responses=experimental") - req.Header.Set("originator", "codex_cli_rs") + if isCodexCLI { + req.Header.Set("originator", "codex_cli_rs") + } else { + req.Header.Set("originator", "opencode") + } req.Header.Set("accept", "text/event-stream") if promptCacheKey != "" { req.Header.Set("conversation_id", promptCacheKey) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index c30fba7e..55e11b30 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -220,7 +220,7 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { Credentials: map[string]any{"base_url": "://invalid-url"}, } - _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "") + _, err := svc.buildUpstreamRequest(c.Request.Context(), c, account, []byte("{}"), "token", false, "", false) if err == nil { t.Fatalf("expected error for invalid base_url when allowlist disabled") } From 9d0a4f3d68ea9aea079cf2292e968550561dd3a9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:23:51 +0800 Subject: [PATCH 04/23] 
=?UTF-8?q?perf(=E8=AE=A4=E8=AF=81):=20=E5=BC=95?= =?UTF-8?q?=E5=85=A5=20API=20Key=20=E8=AE=A4=E8=AF=81=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E4=B8=8E=E8=BD=BB=E9=87=8F=E5=88=A0=E9=99=A4=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 增加 L1/L2 缓存、负缓存与单飞回源 使用 key+owner 轻量查询替代全量加载并清理旧接口 补充缓存失效与余额更新测试,修复随机抖动 lint 测试: make test --- backend/cmd/server/wire_gen.go | 9 +- backend/go.mod | 2 + backend/go.sum | 4 + backend/internal/config/config.go | 57 ++- backend/internal/repository/api_key_cache.go | 33 ++ backend/internal/repository/api_key_repo.go | 88 +++- backend/internal/server/api_contract_test.go | 32 +- .../middleware/api_key_auth_google_test.go | 13 +- .../server/middleware/api_key_auth_test.go | 16 +- backend/internal/service/admin_service.go | 66 ++- .../admin_service_update_balance_test.go | 97 ++++ .../internal/service/api_key_auth_cache.go | 46 ++ .../service/api_key_auth_cache_impl.go | 269 +++++++++++ .../service/api_key_auth_cache_invalidate.go | 48 ++ backend/internal/service/api_key_service.go | 92 +++- .../service/api_key_service_cache_test.go | 417 ++++++++++++++++++ .../service/api_key_service_delete_test.go | 78 +++- backend/internal/service/group_service.go | 14 +- backend/internal/service/user_service.go | 24 +- backend/internal/service/wire.go | 6 + config.yaml | 24 + deploy/config.example.yaml | 24 + 22 files changed, 1360 insertions(+), 99 deletions(-) create mode 100644 backend/internal/service/admin_service_update_balance_test.go create mode 100644 backend/internal/service/api_key_auth_cache.go create mode 100644 backend/internal/service/api_key_auth_cache_impl.go create mode 100644 backend/internal/service/api_key_auth_cache_invalidate.go create mode 100644 backend/internal/service/api_key_service_cache_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 30ea0fdb..a372f673 100644 --- a/backend/cmd/server/wire_gen.go +++ 
b/backend/cmd/server/wire_gen.go @@ -57,13 +57,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) - userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) + userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client) @@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, 
proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() diff --git a/backend/go.mod b/backend/go.mod index 9ac48305..82a8e88e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -44,11 +44,13 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 38e2b53e..0fd47498 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= +github.com/dgraph-io/ristretto 
v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 2cc11967..29eaa42e 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -36,25 +36,26 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing 
PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -361,6 +362,16 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +// APIKeyAuthCacheConfig API Key 认证缓存配置 +type APIKeyAuthCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` + NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` + Singleflight bool `mapstructure:"singleflight"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -655,6 +666,14 @@ func setDefaults() { // Timezone (default to Asia/Shanghai for Chinese users) viper.SetDefault("timezone", "Asia/Shanghai") + // API Key auth cache + viper.SetDefault("api_key_auth_cache.l1_size", 65535) + viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) + viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) + viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30) + viper.SetDefault("api_key_auth_cache.jitter_percent", 10) + viper.SetDefault("api_key_auth_cache.singleflight", true) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 73a929c5..6d834b40 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -13,6 +14,7 @@ import ( const ( apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" ) // 
apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. @@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string { return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) } +func apiKeyAuthCacheKey(key string) string { + return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key) +} + type apiKeyCache struct { rdb *redis.Client } @@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { return c.rdb.Expire(ctx, apiKey, ttl).Err() } + +func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes() + if err != nil { + return nil, err + } + var entry service.APIKeyAuthCacheEntry + if err := json.Unmarshal(val, &entry); err != nil { + return nil, err + } + return &entry, nil +} + +func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + payload, err := json.Marshal(entry) + if err != nil { + return err + } + return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err() +} + +func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 6b8cd40d..77a3f233 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -6,7 +6,9 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" 
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK return apiKeyEntityToService(m), nil } -// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。 +// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。 // 相比 GetByID,此方法性能更优,因为: -// - 使用 Select() 只查询 user_id 字段,减少数据传输量 +// - 使用 Select() 只查询必要字段,减少数据传输量 // - 不加载完整的 API Key 实体及其关联数据(User、Group 等) -// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) -func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { +// - 适用于删除等只需 key 与用户 ID 的场景 +func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { m, err := r.activeQuery(). Where(apikey.IDEQ(id)). - Select(apikey.FieldUserID). + Select(apikey.FieldKey, apikey.FieldUserID). Only(ctx) if err != nil { if dbent.IsNotFound(err) { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return 0, err + return "", 0, err } - return m.UserID, nil + return m.Key, m.UserID, nil } func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A return apiKeyEntityToService(m), nil } +func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + m, err := r.activeQuery(). + Where(apikey.KeyEQ(key)). + Select( + apikey.FieldID, + apikey.FieldUserID, + apikey.FieldGroupID, + apikey.FieldStatus, + apikey.FieldIPWhitelist, + apikey.FieldIPBlacklist, + ). + WithUser(func(q *dbent.UserQuery) { + q.Select( + user.FieldID, + user.FieldStatus, + user.FieldRole, + user.FieldBalance, + user.FieldConcurrency, + ) + }). 
+ WithGroup(func(q *dbent.GroupQuery) { + q.Select( + group.FieldID, + group.FieldName, + group.FieldPlatform, + group.FieldStatus, + group.FieldSubscriptionType, + group.FieldRateMultiplier, + group.FieldDailyLimitUsd, + group.FieldWeeklyLimitUsd, + group.FieldMonthlyLimitUsd, + group.FieldImagePrice1k, + group.FieldImagePrice2k, + group.FieldImagePrice4k, + group.FieldClaudeCodeOnly, + group.FieldFallbackGroupID, + ) + }). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return apiKeyEntityToService(m), nil +} + func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error { // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, @@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i return int64(count), err } +func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.UserIDEQ(userID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + +func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.GroupIDEQ(groupID)). + Select(apikey.FieldKey). 
+ Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 41d8bfdb..bd02f47d 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -389,7 +389,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo) + userService := service.NewUserService(userRepo, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() @@ -565,6 +565,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t return nil } +func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return nil +} + type stubGroupRepo struct{} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { @@ -737,12 +749,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return &clone, nil } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { key, ok := r.byID[id] if !ok { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return key.UserID, nil + return key.Key, key.UserID, nil } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -754,6 +766,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx 
context.Context, key string) (*service.API return &clone, nil } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { if key == nil { return errors.New("nil key") @@ -868,6 +884,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 07b8e370..6f09469b 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } -func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if f.getByKey == nil { @@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK } return f.getByKey(ctx, key) } +func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, 
error) { + return f.GetByKey(ctx, key) +} func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64 func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} +func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 182ea5f8..84398093 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return nil, errors.New("not implemented") } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx 
context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 14bb6daf..75b57852 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -244,14 +244,15 @@ type ProxyExitInfoProber interface { // adminServiceImpl implements AdminService type adminServiceImpl struct { - userRepo UserRepository - groupRepo GroupRepository - accountRepo AccountRepository - proxyRepo ProxyRepository - apiKeyRepo APIKeyRepository - redeemCodeRepo RedeemCodeRepository - billingCacheService *BillingCacheService - proxyProber ProxyExitInfoProber + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + proxyRepo ProxyRepository + apiKeyRepo APIKeyRepository + redeemCodeRepo RedeemCodeRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewAdminService creates a new AdminService @@ -264,16 +265,18 @@ func NewAdminService( redeemCodeRepo RedeemCodeRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) AdminService { return &adminServiceImpl{ - userRepo: userRepo, - groupRepo: groupRepo, - accountRepo: accountRepo, - proxyRepo: proxyRepo, - apiKeyRepo: apiKeyRepo, - redeemCodeRepo: redeemCodeRepo, - billingCacheService: 
billingCacheService, - proxyProber: proxyProber, + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + authCacheInvalidator: authCacheInvalidator, } } @@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role if input.Email != "" { user.Email = input.Email @@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { @@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { log.Printf("delete user failed: user_id=%d err=%v", id, err) return err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } return nil } @@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + balanceDiff := user.Balance - oldBalance + if s.authCacheInvalidator != nil && balanceDiff != 0 { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService != nil { go func() { @@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, }() } - balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { @@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx 
context.Context, id int64, input *Upd if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } + } + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err @@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { } }() } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } return nil } diff --git a/backend/internal/service/admin_service_update_balance_test.go b/backend/internal/service/admin_service_update_balance_test.go new file mode 100644 index 00000000..d3b3c700 --- /dev/null +++ b/backend/internal/service/admin_service_update_balance_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type balanceUserRepoStub struct { + *userRepoStub + updateErr error + updated []*User +} + +func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error { + if s.updateErr != nil { + return s.updateErr + } + if user == nil { + return nil + } + clone := *user + s.updated = append(s.updated, &clone) + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +type balanceRedeemRepoStub struct { + *redeemRepoStub + created []*RedeemCode +} + +func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + clone := *code + s.created = append(s.created, &clone) + return nil +} + +type authCacheInvalidatorStub struct { + userIDs []int64 + groupIDs []int64 + keys []string +} 
+ +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) { + s.keys = append(s.keys, key) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + s.userIDs = append(s.userIDs, userID) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + s.groupIDs = append(s.groupIDs, groupID) +} + +func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "") + require.NoError(t, err) + require.Equal(t, []int64{7}, invalidator.userIDs) + require.Len(t, redeemRepo.created, 1) +} + +func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "") + require.NoError(t, err) + require.Empty(t, invalidator.userIDs) + require.Empty(t, redeemRepo.created) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go new file mode 100644 index 00000000..7ce9a8a2 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache.go @@ -0,0 +1,46 @@ +package service + +// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) +type 
APIKeyAuthSnapshot struct { + APIKeyID int64 `json:"api_key_id"` + UserID int64 `json:"user_id"` + GroupID *int64 `json:"group_id,omitempty"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist,omitempty"` + IPBlacklist []string `json:"ip_blacklist,omitempty"` + User APIKeyAuthUserSnapshot `json:"user"` + Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` +} + +// APIKeyAuthUserSnapshot 用户快照 +type APIKeyAuthUserSnapshot struct { + ID int64 `json:"id"` + Status string `json:"status"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` +} + +// APIKeyAuthGroupSnapshot 分组快照 +type APIKeyAuthGroupSnapshot struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` +} + +// APIKeyAuthCacheEntry 缓存条目,支持负缓存 +type APIKeyAuthCacheEntry struct { + NotFound bool `json:"not_found"` + Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"` +} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go new file mode 100644 index 00000000..dfc55eeb --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -0,0 +1,269 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/dgraph-io/ristretto" +) + +type apiKeyAuthCacheConfig struct { + l1Size int + l1TTL time.Duration + l2TTL time.Duration + negativeTTL time.Duration + jitterPercent int + singleflight bool +} + +var ( + jitterRandMu sync.Mutex + // 认证缓存抖动使用独立随机源,避免全局 Seed + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { + if cfg == nil { + return apiKeyAuthCacheConfig{} + } + auth := cfg.APIKeyAuth + return apiKeyAuthCacheConfig{ + l1Size: auth.L1Size, + l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second, + l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second, + negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second, + jitterPercent: auth.JitterPercent, + singleflight: auth.Singleflight, + } +} + +func (c apiKeyAuthCacheConfig) l1Enabled() bool { + return c.l1Size > 0 && c.l1TTL > 0 +} + +func (c apiKeyAuthCacheConfig) l2Enabled() bool { + return c.l2TTL > 0 +} + +func (c apiKeyAuthCacheConfig) negativeEnabled() bool { + return c.negativeTTL > 0 +} + +func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return ttl + } + if c.jitterPercent <= 0 { + return ttl + } + percent := c.jitterPercent + if percent > 100 { + percent = 100 + } + delta := float64(percent) / 100 + jitterRandMu.Lock() + randVal := jitterRand.Float64() + jitterRandMu.Unlock() + factor := 1 - delta + randVal*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +func (s *APIKeyService) initAuthCache(cfg *config.Config) { + s.authCfg = newAPIKeyAuthCacheConfig(cfg) + if !s.authCfg.l1Enabled() { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(s.authCfg.l1Size) * 10, + MaxCost: int64(s.authCfg.l1Size), + BufferItems: 64, + }) + if err != nil { + return + } + s.authCacheL1 = cache +} + +func (s *APIKeyService) authCacheKey(key string) string { + sum := sha256.Sum256([]byte(key)) + 
return hex.EncodeToString(sum[:]) +} + +func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) { + if s.authCacheL1 != nil { + if val, ok := s.authCacheL1.Get(cacheKey); ok { + if entry, ok := val.(*APIKeyAuthCacheEntry); ok { + return entry, true + } + } + } + if s.cache == nil || !s.authCfg.l2Enabled() { + return nil, false + } + entry, err := s.cache.GetAuthCache(ctx, cacheKey) + if err != nil { + return nil, false + } + s.setAuthCacheL1(cacheKey, entry) + return entry, true +} + +func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) { + if s.authCacheL1 == nil || entry == nil { + return + } + ttl := s.authCfg.l1TTL + if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl { + ttl = s.authCfg.negativeTTL + } + ttl = s.authCfg.jitterTTL(ttl) + _ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl) +} + +func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) { + if entry == nil { + return + } + s.setAuthCacheL1(cacheKey, entry) + if s.cache == nil || !s.authCfg.l2Enabled() { + return + } + _ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl)) +} + +func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { + if s.authCacheL1 != nil { + s.authCacheL1.Del(cacheKey) + } + if s.cache == nil { + return + } + _ = s.cache.DeleteAuthCache(ctx, cacheKey) +} + +func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) + if err != nil { + if errors.Is(err, ErrAPIKeyNotFound) { + entry := &APIKeyAuthCacheEntry{NotFound: true} + if s.authCfg.negativeEnabled() { + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL) + } + return entry, nil + } + return nil, fmt.Errorf("get api key: %w", err) + } + apiKey.Key = key + snapshot := 
s.snapshotFromAPIKey(apiKey) + if snapshot == nil { + return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) + } + entry := &APIKeyAuthCacheEntry{Snapshot: snapshot} + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL) + return entry, nil +} + +func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) { + if entry == nil { + return nil, false, nil + } + if entry.NotFound { + return nil, true, ErrAPIKeyNotFound + } + if entry.Snapshot == nil { + return nil, false, nil + } + return s.snapshotToAPIKey(key, entry.Snapshot), true, nil +} + +func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { + if apiKey == nil || apiKey.User == nil { + return nil + } + snapshot := &APIKeyAuthSnapshot{ + APIKeyID: apiKey.ID, + UserID: apiKey.UserID, + GroupID: apiKey.GroupID, + Status: apiKey.Status, + IPWhitelist: apiKey.IPWhitelist, + IPBlacklist: apiKey.IPBlacklist, + User: APIKeyAuthUserSnapshot{ + ID: apiKey.User.ID, + Status: apiKey.User.Status, + Role: apiKey.User.Role, + Balance: apiKey.User.Balance, + Concurrency: apiKey.User.Concurrency, + }, + } + if apiKey.Group != nil { + snapshot.Group = &APIKeyAuthGroupSnapshot{ + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + } + } + return snapshot +} + +func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey { + if snapshot == nil { + return nil + } + apiKey := &APIKey{ + ID: 
snapshot.APIKeyID, + UserID: snapshot.UserID, + GroupID: snapshot.GroupID, + Key: key, + Status: snapshot.Status, + IPWhitelist: snapshot.IPWhitelist, + IPBlacklist: snapshot.IPBlacklist, + User: &User{ + ID: snapshot.User.ID, + Status: snapshot.User.Status, + Role: snapshot.User.Role, + Balance: snapshot.User.Balance, + Concurrency: snapshot.User.Concurrency, + }, + } + if snapshot.Group != nil { + apiKey.Group = &Group{ + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + } + } + return apiKey +} diff --git a/backend/internal/service/api_key_auth_cache_invalidate.go b/backend/internal/service/api_key_auth_cache_invalidate.go new file mode 100644 index 00000000..aeb58bcc --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_invalidate.go @@ -0,0 +1,48 @@ +package service + +import "context" + +// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) { + if key == "" { + return + } + cacheKey := s.authCacheKey(key) + s.deleteAuthCache(ctx, cacheKey) +} + +// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + if userID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存 +func (s 
*APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + if groupID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) { + if len(keys) == 0 { + return + } + for _, key := range keys { + if key == "" { + continue + } + s.deleteAuthCache(ctx, s.authCacheKey(key)) + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 578afc1a..ecc570c7 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -12,6 +12,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) var ( @@ -31,9 +33,11 @@ const ( type APIKeyRepository interface { Create(ctx context.Context, key *APIKey) error GetByID(ctx context.Context, id int64) (*APIKey, error) - // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 - GetOwnerID(ctx context.Context, id int64) (int64, error) + // GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景 + GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) GetByKey(ctx context.Context, key string) (*APIKey, error) + // GetByKeyForAuth 认证专用查询,返回最小字段集 + GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error @@ -45,6 +49,8 @@ type APIKeyRepository interface { SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) + ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) + 
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) } // APIKeyCache defines cache operations for API key service @@ -55,6 +61,17 @@ type APIKeyCache interface { IncrementDailyUsage(ctx context.Context, apiKey string) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error + + GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error + DeleteAuthCache(ctx context.Context, key string) error +} + +// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 +type APIKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) + InvalidateAuthCacheByUserID(ctx context.Context, userID int64) + InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) } // CreateAPIKeyRequest 创建API Key请求 @@ -83,6 +100,9 @@ type APIKeyService struct { userSubRepo UserSubscriptionRepository cache APIKeyCache cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -94,7 +114,7 @@ func NewAPIKeyService( cache APIKeyCache, cfg *config.Config, ) *APIKeyService { - return &APIKeyService{ + svc := &APIKeyService{ apiKeyRepo: apiKeyRepo, userRepo: userRepo, groupRepo: groupRepo, @@ -102,6 +122,8 @@ func NewAPIKeyService( cache: cache, cfg: cfg, } + svc.initAuthCache(cfg) + return svc } // GenerateKey 生成随机API Key @@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("create api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } @@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) // GetByKey 根据Key字符串获取API Key(用于认证) func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { - // 尝试从Redis缓存获取 - cacheKey := fmt.Sprintf("apikey:%s", key) + 
cacheKey := s.authCacheKey(key) - // 这里可以添加Redis缓存逻辑,暂时直接查询数据库 - apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) + if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok { + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + if s.authCfg.singleflight { + value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) { + return s.loadAuthCacheEntry(ctx, key, cacheKey) + }) + if err != nil { + return nil, err + } + entry, _ := value.(*APIKeyAuthCacheEntry) + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } else { + entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey) + if err != nil { + return nil, err + } + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } - - // 缓存到Redis(可选,TTL设置为5分钟) - if s.cache != nil { - // 这里可以序列化并缓存API Key - _ = cacheKey // 使用变量避免未使用错误 - } - + apiKey.Key = key return apiKey, nil } @@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, fmt.Errorf("update api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } // Delete 删除API Key -// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, -// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能 func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { - // 仅获取所有者 ID 用于权限验证,而非加载完整对象 - ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) + key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id) if err != nil { return fmt.Errorf("get api key: %w", err) } @@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, 
userID int64) erro return ErrInsufficientPerms } - // 清除Redis缓存(使用 ownerID 而非 apiKey.UserID) + // 清除Redis缓存(使用 userID 而非 apiKey.UserID) if s.cache != nil { - _ = s.cache.DeleteCreateAttemptCount(ctx, ownerID) + _ = s.cache.DeleteCreateAttemptCount(ctx, userID) } + s.InvalidateAuthCacheByKey(ctx, key) if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go new file mode 100644 index 00000000..3314ca8d --- /dev/null +++ b/backend/internal/service/api_key_service_cache_test.go @@ -0,0 +1,417 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +type authRepoStub struct { + getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error) + listKeysByUserID func(ctx context.Context, userID int64) ([]string, error) + listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error) +} + +func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error { + panic("unexpected Create call") +} + +func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + panic("unexpected GetByID call") +} + +func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} + +func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKey call") +} + +func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + if s.getByKeyForAuth == nil { + panic("unexpected GetByKeyForAuth call") + } + return s.getByKeyForAuth(ctx, key) +} + +func (s *authRepoStub) Update(ctx context.Context, 
key *APIKey) error { + panic("unexpected Update call") +} + +func (s *authRepoStub) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} + +func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} + +func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) { + panic("unexpected CountByUserID call") +} + +func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) { + panic("unexpected ExistsByKey call") +} + +func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} + +func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} + +func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} + +func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected CountByGroupID call") +} + +func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + if s.listKeysByUserID == nil { + panic("unexpected ListKeysByUserID call") + } + return s.listKeysByUserID(ctx, userID) +} + +func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + if s.listKeysByGroupID == nil { + panic("unexpected ListKeysByGroupID call") + } + return s.listKeysByGroupID(ctx, groupID) +} + +type authCacheStub struct { + getAuthCache func(ctx context.Context, key string) 
(*APIKeyAuthCacheEntry, error) + setAuthKeys []string + deleteAuthKeys []string +} + +func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return nil +} + +func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return nil +} + +func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + if s.getAuthCache == nil { + return nil, redis.Nil + } + return s.getAuthCache(ctx, key) +} + +func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + s.setAuthKeys = append(s.setAuthKeys, key) + return nil +} + +func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + groupID := int64(9) + cacheEntry := &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "g", + Platform: 
PlatformAnthropic, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + } + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return cacheEntry, nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k1") + require.NoError(t, err) + require.Equal(t, int64(1), apiKey.ID) + require.Equal(t, int64(2), apiKey.User.ID) + require.Equal(t, groupID, apiKey.Group.ID) +} + +func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return &APIKeyAuthCacheEntry{NotFound: true}, nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return &APIKey{ + ID: 5, + UserID: 7, + Status: StatusActive, + User: &User{ + ID: 7, + Status: StatusActive, + Role: RoleUser, + Balance: 12, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k2") + require.NoError(t, err) + require.Equal(t, int64(5), apiKey.ID) + require.Len(t, 
cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + return &APIKey{ + ID: 21, + UserID: 3, + Status: StatusActive, + User: &User{ + ID: 3, + Status: StatusActive, + Role: RoleUser, + Balance: 5, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L1Size: 1000, + L1TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + require.NotNil(t, svc.authCacheL1) + + _, err := svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + svc.authCacheL1.Wait() + cacheKey := svc.authCacheKey("k-l1") + _, ok := svc.authCacheL1.Get(cacheKey) + require.True(t, ok) + _, err = svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} + +func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByUserID(context.Background(), 7) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + 
svc.InvalidateAuthCacheByGroupID(context.Background(), 9) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return nil, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByKey(context.Background(), "k1") + require.Len(t, cache.deleteAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, ErrAPIKeyNotFound + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + time.Sleep(50 * time.Millisecond) + return &APIKey{ + ID: 11, + UserID: 2, + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 1, + Concurrency: 1, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + Singleflight: true, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + start := make(chan struct{}) + wg := sync.WaitGroup{} + errs := 
make([]error, 5) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, err := svc.GetByKey(context.Background(), "k1") + errs[idx] = err + }(i) + } + close(start) + wg.Wait() + + for _, err := range errs { + require.NoError(t, err) + } + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 7d04c5ac..32ae884e 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -20,13 +20,12 @@ import ( // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。 // // 设计说明: -// - ownerID: 模拟 GetOwnerID 返回的所有者 ID -// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound) +// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误 // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - ownerID int64 // GetOwnerID 的返回值 - ownerErr error // GetOwnerID 的错误返回值 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 deleteErr error // Delete 的错误返回值 deletedIDs []int64 // 记录已删除的 API Key ID 列表 } @@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error { } func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + if s.getByIDErr != nil { + return nil, s.getByIDErr + } + if s.apiKey != nil { + clone := *s.apiKey + return &clone, nil + } panic("unexpected GetByID call") } -// GetOwnerID 返回预设的所有者 ID 或错误。 -// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。 -func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return s.ownerID, s.ownerErr +func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + if s.getByIDErr != nil { + return "", 0, s.getByIDErr + } + if s.apiKey != nil { + return s.apiKey.Key, s.apiKey.UserID, nil + } + return "", 0, 
ErrAPIKeyNotFound } func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { panic("unexpected GetByKey call") } +func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} + func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error { panic("unexpected Update call") } @@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int panic("unexpected CountByGroupID call") } +func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} + +func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // // 设计说明: // - invalidated: 记录被清除缓存的用户 ID 列表 type apiKeyCacheStub struct { - invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key } // GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制 @@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string return nil } +func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 1 +// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - 调用者 userID 为 2(不匹配) // - 返回 ErrInsufficientPerms 错误 // - Delete 方法不被调用 // - 缓存不被清除 func 
TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 1} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { require.ErrorIs(t, err, ErrInsufficientPerms) require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用 require.Empty(t, cache.invalidated) // 验证缓存未被清除 + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 7 +// - GetKeyAndOwnerID 返回所有者 ID 为 7 // - 调用者 userID 为 7(匹配) // - Delete 成功执行 // - 缓存被正确清除(使用 ownerID) // - 返回 nil 错误 func TestApiKeyService_Delete_Success(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 7} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // 预期行为: -// - GetOwnerID 返回 ErrAPIKeyNotFound 错误 +// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误 // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装) // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_NotFound(t *testing.T) { - repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound} + repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound} cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) { require.ErrorIs(t, err, ErrAPIKeyNotFound) require.Empty(t, repo.deletedIDs) require.Empty(t, cache.invalidated) + require.Empty(t, 
cache.deleteAuthKeys) } // TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // 预期行为: -// - GetOwnerID 返回正确的所有者 ID +// - GetKeyAndOwnerID 返回正确的所有者 ID // - 所有权验证通过 // - 缓存被清除(在删除之前) // - Delete 被调用但返回错误 // - 返回包含 "delete api key" 的错误信息 func TestApiKeyService_Delete_DeleteFails(t *testing.T) { repo := &apiKeyRepoStub{ - ownerID: 3, + apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"}, deleteErr: errors.New("delete failed"), } cache := &apiKeyCacheStub{} @@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { require.ErrorContains(t, err, "delete api key") require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用 require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败) + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 2f0f4975..a9214c82 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -50,13 +50,15 @@ type UpdateGroupRequest struct { // GroupService 分组管理服务 type GroupService struct { - groupRepo GroupRepository + groupRepo GroupRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewGroupService 创建分组服务实例 -func NewGroupService(groupRepo GroupRepository) *GroupService { +func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService { return &GroupService{ - groupRepo: groupRepo, + groupRepo: groupRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } @@ -170,6 +175,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { if err := s.groupRepo.Delete(ctx, 
id); err != nil { return fmt.Errorf("delete group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 08fa40b5..a7a36760 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -55,13 +55,15 @@ type ChangePasswordRequest struct { // UserService 用户服务 type UserService struct { - userRepo UserRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { return &UserService{ - userRepo: userRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err != nil { return nil, fmt.Errorf("get user: %w", err) } + oldConcurrency := user.Concurrency // 更新字段 if req.Email != nil { @@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return user, nil } @@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil { return fmt.Errorf("update balance: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu if err := s.userRepo.UpdateConcurrency(ctx, 
userID, concurrency); err != nil { return fmt.Errorf("update concurrency: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -192,6 +204,9 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str if err := s.userRepo.Update(ctx, user); err != nil { return fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -201,5 +216,8 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { if err := s.userRepo.Delete(ctx, userID); err != nil { return fmt.Errorf("delete user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 512d2550..54c37b54 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi return svc } +// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 +func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + return apiKeyService +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, NewAPIKeyService, + ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService, NewProxyService, diff --git a/config.yaml b/config.yaml index 54b591f3..ecd7dfc2 100644 --- a/config.yaml +++ b/config.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # 
L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 87ff3148..87abffa0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 From 80c1cdf02499c0beccb8434e40853f9c43290000 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Sat, 10 Jan 2026 22:45:29 +0800 Subject: [PATCH 05/23] fix(lint): trim unused codex helpers --- backend/internal/service/codex_prompts.go | 9 - 
.../service/openai_codex_transform.go | 571 +----------------- 2 files changed, 5 insertions(+), 575 deletions(-) delete mode 100644 backend/internal/service/codex_prompts.go diff --git a/backend/internal/service/codex_prompts.go b/backend/internal/service/codex_prompts.go deleted file mode 100644 index 6f83eac2..00000000 --- a/backend/internal/service/codex_prompts.go +++ /dev/null @@ -1,9 +0,0 @@ -package service - -import _ "embed" - -//go:embed prompts/codex_opencode_bridge.txt -var codexOpenCodeBridge string - -//go:embed prompts/tool_remap_message.txt -var codexToolRemapMessage string diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 3514dc7a..965fb770 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -9,44 +9,13 @@ import ( "path/filepath" "strings" "time" - "unicode" ) const ( - codexReleaseAPIURL = "https://api.github.com/repos/openai/codex/releases/latest" - codexReleaseHTMLURL = "https://github.com/openai/codex/releases/latest" - codexPromptURLFmt = "https://raw.githubusercontent.com/openai/codex/%s/codex-rs/core/%s" - opencodeCodexURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex.txt" opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" codexCacheTTL = 15 * time.Minute ) -type codexModelFamily string - -const ( - codexFamilyGpt52Codex codexModelFamily = "gpt-5.2-codex" - codexFamilyCodexMax codexModelFamily = "codex-max" - codexFamilyCodex codexModelFamily = "codex" - codexFamilyGpt52 codexModelFamily = "gpt-5.2" - codexFamilyGpt51 codexModelFamily = "gpt-5.1" -) - -var codexPromptFiles = map[codexModelFamily]string{ - codexFamilyGpt52Codex: "gpt-5.2-codex_prompt.md", - codexFamilyCodexMax: "gpt-5.1-codex-max_prompt.md", - codexFamilyCodex: "gpt_5_codex_prompt.md", - 
codexFamilyGpt52: "gpt_5_2_prompt.md", - codexFamilyGpt51: "gpt_5_1_prompt.md", -} - -var codexCacheFiles = map[codexModelFamily]string{ - codexFamilyGpt52Codex: "gpt-5.2-codex-instructions.md", - codexFamilyCodexMax: "codex-max-instructions.md", - codexFamilyCodex: "codex-instructions.md", - codexFamilyGpt52: "gpt-5.2-instructions.md", - codexFamilyGpt51: "gpt-5.1-instructions.md", -} - var codexModelMap = map[string]string{ "gpt-5.1-codex": "gpt-5.1-codex", "gpt-5.1-codex-low": "gpt-5.1-codex", @@ -87,55 +56,18 @@ var codexModelMap = map[string]string{ "gpt-5-nano": "gpt-5.1", } -var opencodePromptSignatures = []string{ - "you are a coding agent running in the opencode", - "you are opencode, an agent", - "you are opencode, an interactive cli agent", - "you are opencode, an interactive cli tool", - "you are opencode, the best coding agent on the planet", -} - -var opencodeContextMarkers = []string{ - "here is some useful information about the environment you are running in:", - "", - "instructions from:", - "", -} - type codexTransformResult struct { Modified bool NormalizedModel string PromptCacheKey string } -type codexCacheMetadata struct { - ETag string `json:"etag"` - Tag string `json:"tag"` - LastChecked int64 `json:"lastChecked"` - URL string `json:"url"` -} - type opencodeCacheMetadata struct { ETag string `json:"etag"` LastFetch string `json:"lastFetch,omitempty"` LastChecked int64 `json:"lastChecked"` } -func codexModeEnabled() bool { - value := strings.TrimSpace(os.Getenv("CODEX_MODE")) - if value == "" { - return true - } - switch strings.ToLower(value) { - case "0", "false", "no", "off": - return false - case "1", "true", "yes", "on": - return true - default: - return true - } -} - func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { result := codexTransformResult{} @@ -271,144 +203,6 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getCodexModelFamily(normalizedModel string) codexModelFamily { - model 
:= strings.ToLower(normalizedModel) - if strings.Contains(model, "gpt-5.2-codex") || strings.Contains(model, "gpt 5.2 codex") { - return codexFamilyGpt52Codex - } - if strings.Contains(model, "codex-max") { - return codexFamilyCodexMax - } - if strings.Contains(model, "codex") || strings.HasPrefix(model, "codex-") { - return codexFamilyCodex - } - if strings.Contains(model, "gpt-5.2") { - return codexFamilyGpt52 - } - return codexFamilyGpt51 -} - -func getCodexInstructions(normalizedModel string) string { - if normalizedModel == "" { - normalizedModel = "gpt-5.1-codex" - } - - modelFamily := getCodexModelFamily(normalizedModel) - promptFile := codexPromptFiles[modelFamily] - cacheFile := codexCachePath(codexCacheFiles[modelFamily]) - metaFile := codexCachePath(strings.TrimSuffix(codexCacheFiles[modelFamily], ".md") + "-meta.json") - - var meta codexCacheMetadata - if loadJSON(metaFile, &meta) && meta.LastChecked > 0 { - if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { - if cached, ok := readFile(cacheFile); ok { - return cached - } - } - } - - latestTag, err := getLatestCodexReleaseTag() - if err != nil { - if cached, ok := readFile(cacheFile); ok { - return cached - } - return "" - } - - if meta.Tag != latestTag { - meta.ETag = "" - } - - promptURL := fmt.Sprintf(codexPromptURLFmt, latestTag, promptFile) - content, etag, status, err := fetchWithETag(promptURL, meta.ETag) - if err == nil && status == http.StatusNotModified { - if cached, ok := readFile(cacheFile); ok { - return cached - } - } - if err == nil && status >= 200 && status < 300 { - if content != "" { - if err := writeFile(cacheFile, content); err == nil { - meta = codexCacheMetadata{ - ETag: etag, - Tag: latestTag, - LastChecked: time.Now().UnixMilli(), - URL: promptURL, - } - _ = writeJSON(metaFile, meta) - } - return content - } - } - - if cached, ok := readFile(cacheFile); ok { - return cached - } - - return "" -} - -func getLatestCodexReleaseTag() (string, error) { - body, _, 
status, err := fetchWithETag(codexReleaseAPIURL, "") - if err == nil && status >= 200 && status < 300 && body != "" { - var data struct { - TagName string `json:"tag_name"` - } - if json.Unmarshal([]byte(body), &data) == nil && data.TagName != "" { - return data.TagName, nil - } - } - - resp, err := http.Get(codexReleaseHTMLURL) - if err != nil { - return "", err - } - defer resp.Body.Close() - - finalURL := "" - if resp.Request != nil && resp.Request.URL != nil { - finalURL = resp.Request.URL.String() - } - if finalURL != "" { - if tag := parseReleaseTagFromURL(finalURL); tag != "" { - return tag, nil - } - } - - html, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - return parseReleaseTagFromHTML(string(html)) -} - -func parseReleaseTagFromURL(url string) string { - parts := strings.Split(url, "/tag/") - if len(parts) < 2 { - return "" - } - tag := parts[len(parts)-1] - if tag == "" || strings.Contains(tag, "/") { - return "" - } - return tag -} - -func parseReleaseTagFromHTML(html string) (string, error) { - const marker = "/openai/codex/releases/tag/" - idx := strings.Index(html, marker) - if idx == -1 { - return "", fmt.Errorf("release tag not found") - } - rest := html[idx+len(marker):] - for i, r := range rest { - if r == '"' || r == '\'' { - return rest[:i], nil - } - } - return "", fmt.Errorf("release tag not found") -} - func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { cacheDir := codexCachePath("") if cacheDir == "" { @@ -447,10 +241,6 @@ func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { return cachedContent } -func getOpenCodeCodexPrompt() string { - return getOpenCodeCachedPrompt(opencodeCodexURL, "opencode-codex.txt", "opencode-codex-meta.json") -} - func getOpenCodeCodexHeader() string { return getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") } @@ -470,9 +260,7 @@ func filterCodexInput(input []any) []any { if 
typ, ok := m["type"].(string); ok && typ == "item_reference" { continue } - if _, ok := m["id"]; ok { - delete(m, "id") - } + delete(m, "id") filtered = append(filtered, m) } return filtered @@ -491,180 +279,6 @@ func prependSystemInstruction(input []any, instructions string) []any { return append([]any{message}, input...) } -func filterOpenCodeSystemPromptsWithCachedPrompt(input []any, cachedPrompt string) []any { - if len(input) == 0 { - return input - } - cachedPrompt = strings.TrimSpace(cachedPrompt) - - result := make([]any, 0, len(input)) - for _, item := range input { - m, ok := item.(map[string]any) - if !ok { - result = append(result, item) - continue - } - role, _ := m["role"].(string) - if role == "user" { - result = append(result, item) - continue - } - if !isOpenCodeSystemPrompt(m, cachedPrompt) { - result = append(result, item) - continue - } - contentText := getContentText(m) - if contentText == "" { - continue - } - if preserved := extractOpenCodeContext(contentText); preserved != "" { - result = append(result, replaceContentText(m, preserved)) - } - } - return result -} - -func isOpenCodeSystemPrompt(item map[string]any, cachedPrompt string) bool { - role, _ := item["role"].(string) - if role != "developer" && role != "system" { - return false - } - - contentText := getContentText(item) - if contentText == "" { - return false - } - - if cachedPrompt != "" { - contentTrimmed := strings.TrimSpace(contentText) - cachedTrimmed := strings.TrimSpace(cachedPrompt) - if contentTrimmed == cachedTrimmed { - return true - } - if strings.HasPrefix(contentTrimmed, cachedTrimmed) { - return true - } - contentPrefix := contentTrimmed - if len(contentPrefix) > 200 { - contentPrefix = contentPrefix[:200] - } - cachedPrefix := cachedTrimmed - if len(cachedPrefix) > 200 { - cachedPrefix = cachedPrefix[:200] - } - if contentPrefix == cachedPrefix { - return true - } - } - - normalized := strings.ToLower(strings.TrimLeftFunc(contentText, unicode.IsSpace)) - for _, 
signature := range opencodePromptSignatures { - if strings.HasPrefix(normalized, signature) { - return true - } - } - return false -} - -func getContentText(item map[string]any) string { - content := item["content"] - if content == nil { - return "" - } - switch v := content.(type) { - case string: - return v - case []any: - var parts []string - for _, part := range v { - partMap, ok := part.(map[string]any) - if !ok { - continue - } - typ, _ := partMap["type"].(string) - if typ != "input_text" { - continue - } - if text, ok := partMap["text"].(string); ok && text != "" { - parts = append(parts, text) - } - } - return strings.Join(parts, "\n") - default: - return "" - } -} - -func replaceContentText(item map[string]any, contentText string) map[string]any { - content := item["content"] - switch content.(type) { - case string: - item["content"] = contentText - case []any: - item["content"] = []any{map[string]any{ - "type": "input_text", - "text": contentText, - }} - default: - item["content"] = contentText - } - return item -} - -func extractOpenCodeContext(contentText string) string { - lower := strings.ToLower(contentText) - earliest := -1 - for _, marker := range opencodeContextMarkers { - idx := strings.Index(lower, marker) - if idx >= 0 && (earliest == -1 || idx < earliest) { - earliest = idx - } - } - if earliest == -1 { - return "" - } - return strings.TrimLeftFunc(contentText[earliest:], unicode.IsSpace) -} - -func addCodexBridgeMessage(input []any) []any { - message := map[string]any{ - "type": "message", - "role": "developer", - "content": []any{ - map[string]any{ - "type": "input_text", - "text": codexOpenCodeBridge, - }, - }, - } - return append([]any{message}, input...) -} - -func addToolRemapMessage(input []any) []any { - message := map[string]any{ - "type": "message", - "role": "developer", - "content": []any{ - map[string]any{ - "type": "input_text", - "text": codexToolRemapMessage, - }, - }, - } - return append([]any{message}, input...) 
-} - -func hasTools(reqBody map[string]any) bool { - tools, ok := reqBody["tools"] - if !ok || tools == nil { - return false - } - if list, ok := tools.([]any); ok { - return len(list) > 0 - } - return true -} - func normalizeCodexTools(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -761,7 +375,7 @@ func normalizeOrphanedToolOutputs(input []any) []any { switch m["type"] { case "function_call_output": callID := getCallID(m) - if callID == "" || !(functionCallIDs[callID] || localShellCallIDs[callID]) { + if callID == "" || (!functionCallIDs[callID] && !localShellCallIDs[callID]) { output = append(output, convertOrphanedOutputToMessage(m, callID)) continue } @@ -831,183 +445,6 @@ func stringifyOutput(output any) string { } } -func resolveCodexReasoning(reqBody map[string]any, modelName string) (string, string) { - existingEffort := getReasoningValue(reqBody, "effort", "reasoningEffort") - existingSummary := getReasoningValue(reqBody, "summary", "reasoningSummary") - return getReasoningConfig(modelName, existingEffort, existingSummary) -} - -func getReasoningValue(reqBody map[string]any, field, providerField string) string { - if reasoning, ok := reqBody["reasoning"].(map[string]any); ok { - if value, ok := reasoning[field].(string); ok && value != "" { - return value - } - } - if provider := getProviderOpenAI(reqBody); provider != nil { - if value, ok := provider[providerField].(string); ok && value != "" { - return value - } - } - return "" -} - -func resolveTextVerbosity(reqBody map[string]any) string { - if text, ok := reqBody["text"].(map[string]any); ok { - if value, ok := text["verbosity"].(string); ok && value != "" { - return value - } - } - if provider := getProviderOpenAI(reqBody); provider != nil { - if value, ok := provider["textVerbosity"].(string); ok && value != "" { - return value - } - } - return "medium" -} - -func resolveInclude(reqBody map[string]any) []any { - include := 
toStringSlice(reqBody["include"]) - if len(include) == 0 { - if provider := getProviderOpenAI(reqBody); provider != nil { - include = toStringSlice(provider["include"]) - } - } - if len(include) == 0 { - include = []string{"reasoning.encrypted_content"} - } - - unique := make(map[string]struct{}, len(include)+1) - for _, value := range include { - if value == "" { - continue - } - unique[value] = struct{}{} - } - if _, ok := unique["reasoning.encrypted_content"]; !ok { - include = append(include, "reasoning.encrypted_content") - unique["reasoning.encrypted_content"] = struct{}{} - } - - final := make([]any, 0, len(unique)) - for _, value := range include { - if value == "" { - continue - } - if _, ok := unique[value]; ok { - final = append(final, value) - delete(unique, value) - } - } - for value := range unique { - final = append(final, value) - } - return final -} - -func getReasoningConfig(modelName, effortOverride, summaryOverride string) (string, string) { - normalized := strings.ToLower(modelName) - - isGpt52Codex := strings.Contains(normalized, "gpt-5.2-codex") || strings.Contains(normalized, "gpt 5.2 codex") - isGpt52General := (strings.Contains(normalized, "gpt-5.2") || strings.Contains(normalized, "gpt 5.2")) && !isGpt52Codex - isCodexMax := strings.Contains(normalized, "codex-max") || strings.Contains(normalized, "codex max") - isCodexMini := strings.Contains(normalized, "codex-mini") || - strings.Contains(normalized, "codex mini") || - strings.Contains(normalized, "codex_mini") || - strings.Contains(normalized, "codex-mini-latest") - isCodex := strings.Contains(normalized, "codex") && !isCodexMini - isLightweight := !isCodexMini && (strings.Contains(normalized, "nano") || strings.Contains(normalized, "mini")) - isGpt51General := (strings.Contains(normalized, "gpt-5.1") || strings.Contains(normalized, "gpt 5.1")) && - !isCodex && !isCodexMax && !isCodexMini - - supportsXhigh := isGpt52General || isGpt52Codex || isCodexMax - supportsNone := isGpt52General 
|| isGpt51General - - defaultEffort := "medium" - if isCodexMini { - defaultEffort = "medium" - } else if supportsXhigh { - defaultEffort = "high" - } else if isLightweight { - defaultEffort = "minimal" - } - - effort := effortOverride - if effort == "" { - effort = defaultEffort - } - - if isCodexMini { - if effort == "minimal" || effort == "low" || effort == "none" { - effort = "medium" - } - if effort == "xhigh" { - effort = "high" - } - if effort != "high" && effort != "medium" { - effort = "medium" - } - } - - if !supportsXhigh && effort == "xhigh" { - effort = "high" - } - if !supportsNone && effort == "none" { - effort = "low" - } - if effort == "minimal" { - effort = "low" - } - - summary := summaryOverride - if summary == "" { - summary = "auto" - } - - return effort, summary -} - -func getProviderOpenAI(reqBody map[string]any) map[string]any { - providerOptions, ok := reqBody["providerOptions"].(map[string]any) - if !ok || providerOptions == nil { - return nil - } - openaiOptions, ok := providerOptions["openai"].(map[string]any) - if !ok || openaiOptions == nil { - return nil - } - return openaiOptions -} - -func ensureMap(value any) map[string]any { - if value == nil { - return map[string]any{} - } - if m, ok := value.(map[string]any); ok { - return m - } - return map[string]any{} -} - -func toStringSlice(value any) []string { - if value == nil { - return nil - } - switch v := value.(type) { - case []string: - return append([]string{}, v...) 
- case []any: - out := make([]string, 0, len(v)) - for _, item := range v { - if text, ok := item.(string); ok { - out = append(out, text) - } - } - return out - default: - return nil - } -} - func codexCachePath(filename string) string { home, err := os.UserHomeDir() if err != nil { @@ -1079,7 +516,9 @@ func fetchWithETag(url, etag string) (string, string, int, error) { if err != nil { return "", "", 0, err } - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() body, err := io.ReadAll(resp.Body) if err != nil { From 99e2391b2ac7c4927ed69a82025f30ca48ac0a92 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:52:13 +0800 Subject: [PATCH 06/23] =?UTF-8?q?fix(=E8=AE=A4=E8=AF=81):=20=E8=A1=A5?= =?UTF-8?q?=E9=BD=90=E4=BD=99=E9=A2=9D=E4=B8=8E=E5=88=A0=E9=99=A4=E5=9C=BA?= =?UTF-8?q?=E6=99=AF=E7=BC=93=E5=AD=98=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 Usage/Promo/Redeem 注入认证缓存失效逻辑 删除用户与分组前先失效认证缓存降低窗口 补充回归测试验证失效调用 测试: make test --- backend/cmd/server/wire_gen.go | 8 ++-- backend/internal/server/api_contract_test.go | 2 +- .../service/auth_cache_invalidation_test.go | 31 ++++++++++++ backend/internal/service/group_service.go | 6 +-- backend/internal/service/promo_service.go | 28 +++++++---- backend/internal/service/redeem_service.go | 47 ++++++++++++------- backend/internal/service/usage_service.go | 27 ++++++++--- backend/internal/service/user_service.go | 6 +-- 8 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 backend/internal/service/auth_cache_invalidation_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index a372f673..95a7b30b 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -55,24 +55,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := 
repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository, client) + usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := 
repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index bd02f47d..4949f14b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -393,7 +393,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() - usageService := service.NewUsageService(usageRepo, userRepo, nil) + usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go new file mode 100644 index 00000000..3b4217c6 --- /dev/null +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUsageService_InvalidateUsageCaches(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &UsageService{authCacheInvalidator: invalidator} + + svc.invalidateUsageCaches(context.Background(), 7, false) + require.Empty(t, invalidator.userIDs) + + svc.invalidateUsageCaches(context.Background(), 7, true) + require.Equal(t, []int64{7}, 
invalidator.userIDs) +} + +func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &RedeemService{authCacheInvalidator: invalidator} + + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + + require.Equal(t, []int64{11, 11}, invalidator.userIDs) +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index a9214c82..324f347b 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -172,12 +172,12 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { return fmt.Errorf("get group: %w", err) } - if err := s.groupRepo.Delete(ctx, id); err != nil { - return fmt.Errorf("delete group: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) } + if err := s.groupRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete group: %w", err) + } return nil } diff --git a/backend/internal/service/promo_service.go b/backend/internal/service/promo_service.go index 9acd5868..5ff63bdc 100644 --- a/backend/internal/service/promo_service.go +++ b/backend/internal/service/promo_service.go @@ -24,10 +24,11 @@ var ( // PromoService 优惠码服务 type PromoService struct { - promoRepo PromoCodeRepository - userRepo UserRepository - billingCacheService *BillingCacheService - entClient *dbent.Client + promoRepo PromoCodeRepository + userRepo UserRepository + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewPromoService 创建优惠码服务实例 @@ -36,12 +37,14 @@ func NewPromoService( userRepo UserRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *PromoService { return &PromoService{ - 
promoRepo: promoRepo, - userRepo: userRepo, - billingCacheService: billingCacheService, - entClient: entClient, + promoRepo: promoRepo, + userRepo: userRepo, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return fmt.Errorf("commit transaction: %w", err) } + s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount) + // 失效余额缓存 if s.billingCacheService != nil { go func() { @@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return nil } +func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID int64, bonusAmount float64) { + if bonusAmount == 0 || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GenerateRandomCode 生成随机优惠码 func (s *PromoService) GenerateRandomCode() (string, error) { bytes := make([]byte, 8) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b6324235..81767aa9 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -68,12 +68,13 @@ type RedeemCodeResponse struct { // RedeemService 兑换码服务 type RedeemService struct { - redeemRepo RedeemCodeRepository - userRepo UserRepository - subscriptionService *SubscriptionService - cache RedeemCache - billingCacheService *BillingCacheService - entClient *dbent.Client + redeemRepo RedeemCodeRepository + userRepo UserRepository + subscriptionService *SubscriptionService + cache RedeemCache + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewRedeemService 创建兑换码服务实例 @@ -84,14 +85,16 @@ func NewRedeemService( cache RedeemCache, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator 
APIKeyAuthCacheInvalidator, ) *RedeemService { return &RedeemService{ - redeemRepo: redeemRepo, - userRepo: userRepo, - subscriptionService: subscriptionService, - cache: cache, - billingCacheService: billingCacheService, - entClient: entClient, + redeemRepo: redeemRepo, + userRepo: userRepo, + subscriptionService: subscriptionService, + cache: cache, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -324,18 +327,30 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // invalidateRedeemCaches 失效兑换相关的缓存 func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { - if s.billingCacheService == nil { - return - } - switch redeemCode.Type { case RedeemTypeBalance: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() + case RedeemTypeConcurrency: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } case RedeemTypeSubscription: + if s.billingCacheService == nil { + return + } if redeemCode.GroupID != nil { groupID := *redeemCode.GroupID go func() { diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 10a294ae..aa0a5b87 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -54,17 +54,19 @@ type UsageStats struct { // UsageService 使用统计服务 type UsageService struct { - usageRepo UsageLogRepository - userRepo UserRepository - entClient *dbent.Client + usageRepo UsageLogRepository + userRepo UserRepository + entClient *dbent.Client + authCacheInvalidator 
APIKeyAuthCacheInvalidator } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { +func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService { return &UsageService{ - usageRepo: usageRepo, - userRepo: userRepo, - entClient: entClient, + usageRepo: usageRepo, + userRepo: userRepo, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // 扣除用户余额 + balanceUpdated := false if inserted && req.ActualCost > 0 { if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } + balanceUpdated = true } if tx != nil { @@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } } + s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated) + return usageLog, nil } +func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) { + if !balanceUpdated || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GetByID 根据ID获取使用日志 func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7a36760..1734914a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -213,11 +213,11 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str // Delete 删除用户(管理员功能) func (s *UserService) Delete(ctx context.Context, userID int64) error { - if err := s.userRepo.Delete(ctx, userID); err != nil { - return 
fmt.Errorf("delete user: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete user: %w", err) + } return nil } From a16f72f52e90d3cb27d75a5679eb090eb7bc6c60 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 23:14:20 +0800 Subject: [PATCH 07/23] =?UTF-8?q?fix(=E8=AE=A4=E8=AF=81):=20=E8=AE=A2?= =?UTF-8?q?=E9=98=85=E5=85=91=E6=8D=A2=E5=A4=B1=E6=95=88=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 订阅兑换后同步失效认证缓存避免授权快照滞后 补充单测覆盖订阅兑换的失效场景 测试: go test ./... -tags=unit --- backend/internal/service/auth_cache_invalidation_test.go | 4 +++- backend/internal/service/redeem_service.go | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go index 3b4217c6..b6e56177 100644 --- a/backend/internal/service/auth_cache_invalidation_test.go +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -26,6 +26,8 @@ func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + groupID := int64(3) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID}) - require.Equal(t, []int64{11, 11}, invalidator.userIDs) + require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs) } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 81767aa9..ff52dc47 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -348,6 +348,9 @@ func (s *RedeemService) 
invalidateRedeemCaches(ctx context.Context, userID int64 return } case RedeemTypeSubscription: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService == nil { return } From c2c865b0cbace797a295d4ded6b0806d328b815a Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 10:07:03 +0800 Subject: [PATCH 08/23] =?UTF-8?q?perf(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E7=BB=9F=E8=AE=A1=E7=BC=93=E5=AD=98=E4=B8=8E?= =?UTF-8?q?=E9=9A=94=E7=A6=BB=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 新增仪表盘缓存开关与 TTL 配置,支持 Redis key 前缀隔离,并补充单测与校验。 测试: make test-backend --- backend/cmd/server/wire_gen.go | 3 +- backend/internal/config/config.go | 47 +++++ backend/internal/config/config_test.go | 64 ++++++ .../internal/repository/dashboard_cache.go | 51 +++++ backend/internal/repository/wire.go | 1 + backend/internal/service/dashboard_service.go | 146 +++++++++++++- .../service/dashboard_service_test.go | 189 ++++++++++++++++++ config.yaml | 21 ++ deploy/config.example.yaml | 21 ++ 9 files changed, 536 insertions(+), 7 deletions(-) create mode 100644 backend/internal/repository/dashboard_cache.go create mode 100644 backend/internal/service/dashboard_service_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 95a7b30b..4fb8351e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -75,7 +75,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) - dashboardService := service.NewDashboardService(usageLogRepository) + 
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) + dashboardService := service.NewDashboardService(usageLogRepository, dashboardStatsCache, configConfig) dashboardHandler := admin.NewDashboardHandler(dashboardService) accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 29eaa42e..677d0c6e 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -50,6 +50,7 @@ type Config struct { Pricing PricingConfig `mapstructure:"pricing"` Gateway GatewayConfig `mapstructure:"gateway"` APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` @@ -372,6 +373,20 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// DashboardCacheConfig 仪表盘统计缓存配置 +type DashboardCacheConfig struct { + // Enabled: 是否启用仪表盘缓存 + Enabled bool `mapstructure:"enabled"` + // KeyPrefix: Redis key 前缀,用于多环境隔离 + KeyPrefix string `mapstructure:"key_prefix"` + // StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒) + StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"` + // StatsTTLSeconds: Redis 缓存总 TTL(秒) + StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"` + // StatsRefreshTimeoutSeconds: 异步刷新超时(秒) + StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -437,6 +452,7 @@ func Load() (*Config, error) { cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) cfg.LinuxDo.UserInfoIDPath = 
strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) + cfg.Dashboard.KeyPrefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -674,6 +690,13 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", true) + // Dashboard cache + viper.SetDefault("dashboard_cache.enabled", true) + viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") + viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15) + viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30) + viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) @@ -832,6 +855,30 @@ func (c *Config) Validate() error { if c.Redis.MinIdleConns > c.Redis.PoolSize { return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size") } + if c.Dashboard.Enabled { + if c.Dashboard.StatsFreshTTLSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive") + } + if c.Dashboard.StatsTTLSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive") + } + if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 { + return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive") + } + if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds { + return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds") + } + } else { + if c.Dashboard.StatsFreshTTLSeconds < 0 { + return 
fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative") + } + if c.Dashboard.StatsTTLSeconds < 0 { + return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative") + } + if c.Dashboard.StatsRefreshTimeoutSeconds < 0 { + return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative") + } + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index a39d41f9..6cd95b1c 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -141,3 +141,67 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { t.Fatalf("Validate() expected use_pkce error, got: %v", err) } } + +func TestLoadDefaultDashboardCacheConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Dashboard.Enabled { + t.Fatalf("Dashboard.Enabled = false, want true") + } + if cfg.Dashboard.KeyPrefix != "sub2api:" { + t.Fatalf("Dashboard.KeyPrefix = %q, want %q", cfg.Dashboard.KeyPrefix, "sub2api:") + } + if cfg.Dashboard.StatsFreshTTLSeconds != 15 { + t.Fatalf("Dashboard.StatsFreshTTLSeconds = %d, want 15", cfg.Dashboard.StatsFreshTTLSeconds) + } + if cfg.Dashboard.StatsTTLSeconds != 30 { + t.Fatalf("Dashboard.StatsTTLSeconds = %d, want 30", cfg.Dashboard.StatsTTLSeconds) + } + if cfg.Dashboard.StatsRefreshTimeoutSeconds != 30 { + t.Fatalf("Dashboard.StatsRefreshTimeoutSeconds = %d, want 30", cfg.Dashboard.StatsRefreshTimeoutSeconds) + } +} + +func TestValidateDashboardCacheConfigEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Dashboard.Enabled = true + cfg.Dashboard.StatsFreshTTLSeconds = 10 + cfg.Dashboard.StatsTTLSeconds = 5 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for 
stats_fresh_ttl_seconds > stats_ttl_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_cache.stats_fresh_ttl_seconds") { + t.Fatalf("Validate() expected stats_fresh_ttl_seconds error, got: %v", err) + } +} + +func TestValidateDashboardCacheConfigDisabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Dashboard.Enabled = false + cfg.Dashboard.StatsTTLSeconds = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for negative stats_ttl_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_cache.stats_ttl_seconds") { + t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err) + } +} diff --git a/backend/internal/repository/dashboard_cache.go b/backend/internal/repository/dashboard_cache.go new file mode 100644 index 00000000..ec6ef25c --- /dev/null +++ b/backend/internal/repository/dashboard_cache.go @@ -0,0 +1,51 @@ +package repository + +import ( + "context" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const dashboardStatsCacheKey = "dashboard:stats:v1" + +type dashboardCache struct { + rdb *redis.Client + keyPrefix string +} + +func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache { + prefix := "sub2api:" + if cfg != nil { + prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) + } + return &dashboardCache{ + rdb: rdb, + keyPrefix: prefix, + } +} + +func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) { + val, err := c.rdb.Get(ctx, c.buildKey()).Result() + if err != nil { + if err == redis.Nil { + return "", service.ErrDashboardStatsCacheMiss + } + return "", err + } + return val, nil +} + +func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error { + return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err() 
+} + +func (c *dashboardCache) buildKey() string { + if c.keyPrefix == "" { + return dashboardStatsCacheKey + } + return c.keyPrefix + dashboardStatsCacheKey +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 0a6118e2..1b6a7b91 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -58,6 +58,7 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, ProvideConcurrencyCache, + NewDashboardCache, NewEmailCache, NewIdentityCache, NewRedeemCache, diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index f0b1f2a0..f56480d3 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -2,25 +2,89 @@ package service import ( "context" + "encoding/json" + "errors" "fmt" + "log" + "sync/atomic" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) -// DashboardService provides aggregated statistics for admin dashboard. -type DashboardService struct { - usageRepo UsageLogRepository +const ( + defaultDashboardStatsFreshTTL = 15 * time.Second + defaultDashboardStatsCacheTTL = 30 * time.Second + defaultDashboardStatsRefreshTimeout = 30 * time.Second +) + +// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。 +var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中") + +// DashboardStatsCache 定义仪表盘统计缓存接口。 +type DashboardStatsCache interface { + GetDashboardStats(ctx context.Context) (string, error) + SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error } -func NewDashboardService(usageRepo UsageLogRepository) *DashboardService { +type dashboardStatsCacheEntry struct { + Stats *usagestats.DashboardStats `json:"stats"` + UpdatedAt int64 `json:"updated_at"` +} + +// DashboardService provides aggregated statistics for admin dashboard. 
+type DashboardService struct { + usageRepo UsageLogRepository + cache DashboardStatsCache + cacheFreshTTL time.Duration + cacheTTL time.Duration + refreshTimeout time.Duration + refreshing int32 +} + +func NewDashboardService(usageRepo UsageLogRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService { + freshTTL := defaultDashboardStatsFreshTTL + cacheTTL := defaultDashboardStatsCacheTTL + refreshTimeout := defaultDashboardStatsRefreshTimeout + if cfg != nil { + if !cfg.Dashboard.Enabled { + cache = nil + } + if cfg.Dashboard.StatsFreshTTLSeconds > 0 { + freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second + } + if cfg.Dashboard.StatsTTLSeconds > 0 { + cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second + } + if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 { + refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second + } + } return &DashboardService{ - usageRepo: usageRepo, + usageRepo: usageRepo, + cache: cache, + cacheFreshTTL: freshTTL, + cacheTTL: cacheTTL, + refreshTimeout: refreshTimeout, } } func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { - stats, err := s.usageRepo.GetDashboardStats(ctx) + if s.cache != nil { + cached, fresh, err := s.getCachedDashboardStats(ctx) + if err == nil && cached != nil { + if !fresh { + s.refreshDashboardStatsAsync() + } + return cached, nil + } + if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) { + log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err) + } + } + + stats, err := s.refreshDashboardStats(ctx) if err != nil { return nil, fmt.Errorf("get dashboard stats: %w", err) } @@ -43,6 +107,76 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } +func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { + data, err := s.cache.GetDashboardStats(ctx) + if err != nil { + return 
nil, false, err + } + + var entry dashboardStatsCacheEntry + if err := json.Unmarshal([]byte(data), &entry); err != nil { + return nil, false, err + } + if entry.Stats == nil { + return nil, false, errors.New("仪表盘缓存缺少统计数据") + } + + age := time.Since(time.Unix(entry.UpdatedAt, 0)) + return entry.Stats, age <= s.cacheFreshTTL, nil +} + +func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + stats, err := s.usageRepo.GetDashboardStats(ctx) + if err != nil { + return nil, err + } + s.saveDashboardStatsCache(ctx, stats) + return stats, nil +} + +func (s *DashboardService) refreshDashboardStatsAsync() { + if s.cache == nil { + return + } + if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) { + return + } + + go func() { + defer atomic.StoreInt32(&s.refreshing, 0) + + ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout) + defer cancel() + + stats, err := s.usageRepo.GetDashboardStats(ctx) + if err != nil { + log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) + return + } + s.saveDashboardStatsCache(ctx, stats) + }() +} + +func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) { + if s.cache == nil || stats == nil { + return + } + + entry := dashboardStatsCacheEntry{ + Stats: stats, + UpdatedAt: time.Now().Unix(), + } + data, err := json.Marshal(entry) + if err != nil { + log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err) + return + } + + if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil { + log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err) + } +} + func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go 
b/backend/internal/service/dashboard_service_test.go new file mode 100644 index 00000000..21d7b580 --- /dev/null +++ b/backend/internal/service/dashboard_service_test.go @@ -0,0 +1,189 @@ +package service + +import ( + "context" + "encoding/json" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/stretchr/testify/require" +) + +type usageRepoStub struct { + UsageLogRepository + stats *usagestats.DashboardStats + err error + calls int32 + onCall chan struct{} +} + +func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + atomic.AddInt32(&s.calls, 1) + if s.onCall != nil { + select { + case s.onCall <- struct{}{}: + default: + } + } + if s.err != nil { + return nil, s.err + } + return s.stats, nil +} + +type dashboardCacheStub struct { + get func(ctx context.Context) (string, error) + set func(ctx context.Context, data string, ttl time.Duration) error + getCalls int32 + setCalls int32 + lastSetMu sync.Mutex + lastSet string +} + +func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) { + atomic.AddInt32(&c.getCalls, 1) + if c.get != nil { + return c.get(ctx) + } + return "", ErrDashboardStatsCacheMiss +} + +func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error { + atomic.AddInt32(&c.setCalls, 1) + c.lastSetMu.Lock() + c.lastSet = data + c.lastSetMu.Unlock() + if c.set != nil { + return c.set(ctx, data, ttl) + } + return nil +} + +func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry { + t.Helper() + c.lastSetMu.Lock() + data := c.lastSet + c.lastSetMu.Unlock() + + var entry dashboardStatsCacheEntry + err := json.Unmarshal([]byte(data), &entry) + require.NoError(t, err) + return entry +} + +func TestDashboardService_CacheHitFresh(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 10, + 
} + entry := dashboardStatsCacheEntry{ + Stats: stats, + UpdatedAt: time.Now().Unix(), + } + payload, err := json.Marshal(entry) + require.NoError(t, err) + + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return string(payload), nil + }, + } + repo := &usageRepoStub{ + stats: &usagestats.DashboardStats{TotalUsers: 99}, + } + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + svc := NewDashboardService(repo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls)) +} + +func TestDashboardService_CacheMiss_StoresCache(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 7, + } + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "", ErrDashboardStatsCacheMiss + }, + } + repo := &usageRepoStub{stats: stats} + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + svc := NewDashboardService(repo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls)) + entry := cache.readLastEntry(t) + require.Equal(t, stats, entry.Stats) + require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second) +} + +func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { + stats := &usagestats.DashboardStats{ + TotalUsers: 3, + } + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "", nil + }, + } + repo := &usageRepoStub{stats: stats} + cfg := &config.Config{Dashboard: 
config.DashboardCacheConfig{Enabled: false}} + svc := NewDashboardService(repo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls)) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls)) +} + +func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { + staleStats := &usagestats.DashboardStats{ + TotalUsers: 11, + } + entry := dashboardStatsCacheEntry{ + Stats: staleStats, + UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(), + } + payload, err := json.Marshal(entry) + require.NoError(t, err) + + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return string(payload), nil + }, + } + refreshCh := make(chan struct{}, 1) + repo := &usageRepoStub{ + stats: &usagestats.DashboardStats{TotalUsers: 22}, + onCall: refreshCh, + } + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + svc := NewDashboardService(repo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, staleStats, got) + + select { + case <-refreshCh: + case <-time.After(1 * time.Second): + t.Fatal("等待异步刷新超时") + } + require.Eventually(t, func() bool { + return atomic.LoadInt32(&cache.setCalls) >= 1 + }, 1*time.Second, 10*time.Millisecond) +} diff --git a/config.yaml b/config.yaml index ecd7dfc2..ffc070a0 100644 --- a/config.yaml +++ b/config.yaml @@ -194,6 +194,27 @@ api_key_auth_cache: # 缓存未命中时启用 singleflight 合并回源 singleflight: true +# ============================================================================= +# Dashboard Cache Configuration +# 仪表盘缓存配置 +# ============================================================================= +dashboard_cache: + # Enable dashboard cache + # 启用仪表盘缓存 + enabled: true + # Redis key prefix for 
multi-environment isolation + # Redis key 前缀,用于多环境隔离 + key_prefix: "sub2api:" + # Fresh TTL (seconds); within this window cached stats are considered fresh + # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 + stats_fresh_ttl_seconds: 15 + # Cache TTL (seconds) stored in Redis + # Redis 缓存 TTL(秒) + stats_ttl_seconds: 30 + # Async refresh timeout (seconds) + # 异步刷新超时(秒) + stats_refresh_timeout_seconds: 30 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 87abffa0..7083f9e9 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -194,6 +194,27 @@ api_key_auth_cache: # 缓存未命中时启用 singleflight 合并回源 singleflight: true +# ============================================================================= +# Dashboard Cache Configuration +# 仪表盘缓存配置 +# ============================================================================= +dashboard_cache: + # Enable dashboard cache + # 启用仪表盘缓存 + enabled: true + # Redis key prefix for multi-environment isolation + # Redis key 前缀,用于多环境隔离 + key_prefix: "sub2api:" + # Fresh TTL (seconds); within this window cached stats are considered fresh + # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 + stats_fresh_ttl_seconds: 15 + # Cache TTL (seconds) stored in Redis + # Redis 缓存 TTL(秒) + stats_ttl_seconds: 30 + # Async refresh timeout (seconds) + # 异步刷新超时(秒) + stats_refresh_timeout_seconds: 30 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 From 44a93c1922fedcd6979955383b4a9530735580ed Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:23:51 +0800 Subject: [PATCH 09/23] =?UTF-8?q?perf(=E8=AE=A4=E8=AF=81):=20=E5=BC=95?= =?UTF-8?q?=E5=85=A5=20API=20Key=20=E8=AE=A4=E8=AF=81=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E4=B8=8E=E8=BD=BB=E9=87=8F=E5=88=A0=E9=99=A4=E6=9F=A5=E8=AF=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 
Content-Transfer-Encoding: 8bit 增加 L1/L2 缓存、负缓存与单飞回源 使用 key+owner 轻量查询替代全量加载并清理旧接口 补充缓存失效与余额更新测试,修复随机抖动 lint 测试: make test --- backend/cmd/server/wire_gen.go | 9 +- backend/go.mod | 2 + backend/go.sum | 4 + backend/internal/config/config.go | 57 ++- backend/internal/repository/api_key_cache.go | 33 ++ backend/internal/repository/api_key_repo.go | 88 +++- backend/internal/server/api_contract_test.go | 32 +- .../middleware/api_key_auth_google_test.go | 13 +- .../server/middleware/api_key_auth_test.go | 16 +- backend/internal/service/admin_service.go | 66 ++- .../admin_service_update_balance_test.go | 97 ++++ .../internal/service/api_key_auth_cache.go | 46 ++ .../service/api_key_auth_cache_impl.go | 269 +++++++++++ .../service/api_key_auth_cache_invalidate.go | 48 ++ backend/internal/service/api_key_service.go | 92 +++- .../service/api_key_service_cache_test.go | 417 ++++++++++++++++++ .../service/api_key_service_delete_test.go | 78 +++- backend/internal/service/group_service.go | 14 +- backend/internal/service/user_service.go | 24 +- backend/internal/service/wire.go | 6 + config.yaml | 24 + deploy/config.example.yaml | 24 + 22 files changed, 1360 insertions(+), 99 deletions(-) create mode 100644 backend/internal/service/admin_service_update_balance_test.go create mode 100644 backend/internal/service/api_key_auth_cache.go create mode 100644 backend/internal/service/api_key_auth_cache_impl.go create mode 100644 backend/internal/service/api_key_auth_cache_invalidate.go create mode 100644 backend/internal/service/api_key_service_cache_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 58f8cebf..561b0aeb 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -57,13 +57,14 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) promoService := 
service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) - userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) + userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) usageService := service.NewUsageService(usageLogRepository, userRepository, client) @@ -79,7 +80,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, 
proxyExitInfoProber, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() diff --git a/backend/go.mod b/backend/go.mod index 9ac48305..82a8e88e 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -44,11 +44,13 @@ require ( github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/dgraph-io/ristretto v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/distribution/reference v0.6.0 // indirect github.com/docker/docker v28.5.1+incompatible // indirect github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/ebitengine/purego v0.8.4 // indirect github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/backend/go.sum b/backend/go.sum index 38e2b53e..0fd47498 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -51,6 +51,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgraph-io/ristretto v0.2.0 h1:XAfl+7cmoUDWW/2Lx8TGZQjjxIQ2Ley9DSf52dru4WE= +github.com/dgraph-io/ristretto v0.2.0/go.mod h1:8uBHCU/PBV4Ag0CJrP47b9Ofby5dqWNh4FicAdoqFNU= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod 
h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= @@ -61,6 +63,8 @@ github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pM github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/ebitengine/purego v0.8.4 h1:CF7LEKg5FFOsASUj0+QwaXf8Ht6TlFxg09+S9wz0omw= github.com/ebitengine/purego v0.8.4/go.mod h1:iIjxzd6CiRiOG0UyXP+V1+jWqUXVjPKLAI0mRfJZTmQ= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 40344cd4..ad5bd403 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -36,25 +36,26 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" 
yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -375,6 +376,16 @@ type RateLimitConfig struct { OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) } +// APIKeyAuthCacheConfig API Key 认证缓存配置 +type APIKeyAuthCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` + NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` + Singleflight bool `mapstructure:"singleflight"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -669,6 +680,14 @@ func setDefaults() { // Timezone (default to Asia/Shanghai for Chinese users) viper.SetDefault("timezone", "Asia/Shanghai") + // API Key auth cache + viper.SetDefault("api_key_auth_cache.l1_size", 65535) + viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) + viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) + viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30) + viper.SetDefault("api_key_auth_cache.jitter_percent", 10) + viper.SetDefault("api_key_auth_cache.singleflight", true) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 73a929c5..6d834b40 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "encoding/json" "errors" "fmt" "time" @@ -13,6 +14,7 @@ import ( const ( apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" ) // 
apiKeyRateLimitKey generates the Redis key for API key creation rate limiting. @@ -20,6 +22,10 @@ func apiKeyRateLimitKey(userID int64) string { return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID) } +func apiKeyAuthCacheKey(key string) string { + return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key) +} + type apiKeyCache struct { rdb *redis.Client } @@ -58,3 +64,30 @@ func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) er func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { return c.rdb.Expire(ctx, apiKey, ttl).Err() } + +func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes() + if err != nil { + return nil, err + } + var entry service.APIKeyAuthCacheEntry + if err := json.Unmarshal(val, &entry); err != nil { + return nil, err + } + return &entry, nil +} + +func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + if entry == nil { + return nil + } + payload, err := json.Marshal(entry) + if err != nil { + return err + } + return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err() +} + +func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 6b8cd40d..77a3f233 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -6,7 +6,9 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" 
@@ -64,23 +66,23 @@ func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIK return apiKeyEntityToService(m), nil } -// GetOwnerID 根据 API Key ID 获取其所有者(用户)的 ID。 +// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。 // 相比 GetByID,此方法性能更优,因为: -// - 使用 Select() 只查询 user_id 字段,减少数据传输量 +// - 使用 Select() 只查询必要字段,减少数据传输量 // - 不加载完整的 API Key 实体及其关联数据(User、Group 等) -// - 适用于权限验证等只需用户 ID 的场景(如删除前的所有权检查) -func (r *apiKeyRepository) GetOwnerID(ctx context.Context, id int64) (int64, error) { +// - 适用于删除等只需 key 与用户 ID 的场景 +func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { m, err := r.activeQuery(). Where(apikey.IDEQ(id)). - Select(apikey.FieldUserID). + Select(apikey.FieldKey, apikey.FieldUserID). Only(ctx) if err != nil { if dbent.IsNotFound(err) { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return 0, err + return "", 0, err } - return m.UserID, nil + return m.Key, m.UserID, nil } func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -98,6 +100,54 @@ func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.A return apiKeyEntityToService(m), nil } +func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + m, err := r.activeQuery(). + Where(apikey.KeyEQ(key)). + Select( + apikey.FieldID, + apikey.FieldUserID, + apikey.FieldGroupID, + apikey.FieldStatus, + apikey.FieldIPWhitelist, + apikey.FieldIPBlacklist, + ). + WithUser(func(q *dbent.UserQuery) { + q.Select( + user.FieldID, + user.FieldStatus, + user.FieldRole, + user.FieldBalance, + user.FieldConcurrency, + ) + }). 
+ WithGroup(func(q *dbent.GroupQuery) { + q.Select( + group.FieldID, + group.FieldName, + group.FieldPlatform, + group.FieldStatus, + group.FieldSubscriptionType, + group.FieldRateMultiplier, + group.FieldDailyLimitUsd, + group.FieldWeeklyLimitUsd, + group.FieldMonthlyLimitUsd, + group.FieldImagePrice1k, + group.FieldImagePrice2k, + group.FieldImagePrice4k, + group.FieldClaudeCodeOnly, + group.FieldFallbackGroupID, + ) + }). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return nil, service.ErrAPIKeyNotFound + } + return nil, err + } + return apiKeyEntityToService(m), nil +} + func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error { // 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。 // 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除, @@ -283,6 +333,28 @@ func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (i return int64(count), err } +func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.UserIDEQ(userID)). + Select(apikey.FieldKey). + Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + +func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + keys, err := r.activeQuery(). + Where(apikey.GroupIDEQ(groupID)). + Select(apikey.FieldKey). 
+ Strings(ctx) + if err != nil { + return nil, err + } + return keys, nil +} + func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { if m == nil { return nil diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index aa5c6a3e..04cc5c2e 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -390,7 +390,7 @@ func newContractDeps(t *testing.T) *contractDeps { RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo) + userService := service.NewUserService(userRepo, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() @@ -566,6 +566,18 @@ func (stubApiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, t return nil } +func (stubApiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (stubApiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { + return nil +} + type stubGroupRepo struct{} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { @@ -738,12 +750,12 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return &clone, nil } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { key, ok := r.byID[id] if !ok { - return 0, service.ErrAPIKeyNotFound + return "", 0, service.ErrAPIKeyNotFound } - return key.UserID, nil + return key.Key, key.UserID, nil } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -755,6 +767,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx 
context.Context, key string) (*service.API return &clone, nil } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { if key == nil { return errors.New("nil key") @@ -869,6 +885,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 07b8e370..6f09469b 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -27,8 +27,8 @@ func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { func (f fakeAPIKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey, error) { return nil, errors.New("not implemented") } -func (f fakeAPIKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (f fakeAPIKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { if f.getByKey == nil { @@ -36,6 +36,9 @@ func (f fakeAPIKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIK } return f.getByKey(ctx, key) } +func (f fakeAPIKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, 
error) { + return f.GetByKey(ctx, key) +} func (f fakeAPIKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -66,6 +69,12 @@ func (f fakeAPIKeyRepo) ClearGroupIDByGroupID(ctx context.Context, groupID int64 func (f fakeAPIKeyRepo) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} +func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} type googleErrorResponse struct { Error struct { diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 182ea5f8..84398093 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -256,8 +256,8 @@ func (r *stubApiKeyRepo) GetByID(ctx context.Context, id int64) (*service.APIKey return nil, errors.New("not implemented") } -func (r *stubApiKeyRepo) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return 0, errors.New("not implemented") +func (r *stubApiKeyRepo) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + return "", 0, errors.New("not implemented") } func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.APIKey, error) { @@ -267,6 +267,10 @@ func (r *stubApiKeyRepo) GetByKey(ctx context.Context, key string) (*service.API return nil, errors.New("not implemented") } +func (r *stubApiKeyRepo) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) { + return r.GetByKey(ctx, key) +} + func (r *stubApiKeyRepo) Update(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -307,6 +311,14 @@ func (r *stubApiKeyRepo) CountByGroupID(ctx 
context.Context, groupID int64) (int return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct { getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 14bb6daf..75b57852 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -244,14 +244,15 @@ type ProxyExitInfoProber interface { // adminServiceImpl implements AdminService type adminServiceImpl struct { - userRepo UserRepository - groupRepo GroupRepository - accountRepo AccountRepository - proxyRepo ProxyRepository - apiKeyRepo APIKeyRepository - redeemCodeRepo RedeemCodeRepository - billingCacheService *BillingCacheService - proxyProber ProxyExitInfoProber + userRepo UserRepository + groupRepo GroupRepository + accountRepo AccountRepository + proxyRepo ProxyRepository + apiKeyRepo APIKeyRepository + redeemCodeRepo RedeemCodeRepository + billingCacheService *BillingCacheService + proxyProber ProxyExitInfoProber + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewAdminService creates a new AdminService @@ -264,16 +265,18 @@ func NewAdminService( redeemCodeRepo RedeemCodeRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) AdminService { return &adminServiceImpl{ - userRepo: userRepo, - groupRepo: groupRepo, - accountRepo: accountRepo, - proxyRepo: proxyRepo, - apiKeyRepo: apiKeyRepo, - redeemCodeRepo: redeemCodeRepo, - billingCacheService: 
billingCacheService, - proxyProber: proxyProber, + userRepo: userRepo, + groupRepo: groupRepo, + accountRepo: accountRepo, + proxyRepo: proxyRepo, + apiKeyRepo: apiKeyRepo, + redeemCodeRepo: redeemCodeRepo, + billingCacheService: billingCacheService, + proxyProber: proxyProber, + authCacheInvalidator: authCacheInvalidator, } } @@ -323,6 +326,8 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } oldConcurrency := user.Concurrency + oldStatus := user.Status + oldRole := user.Role if input.Email != "" { user.Email = input.Email @@ -355,6 +360,11 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) + } + } concurrencyDiff := user.Concurrency - oldConcurrency if concurrencyDiff != 0 { @@ -393,6 +403,9 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { log.Printf("delete user failed: user_id=%d err=%v", id, err) return err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, id) + } return nil } @@ -420,6 +433,10 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + balanceDiff := user.Balance - oldBalance + if s.authCacheInvalidator != nil && balanceDiff != 0 { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService != nil { go func() { @@ -431,7 +448,6 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, }() } - balanceDiff := user.Balance - oldBalance if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { @@ -675,10 +691,21 @@ func (s *adminServiceImpl) UpdateGroup(ctx 
context.Context, id int64, input *Upd if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { + var groupKeys []string + if s.authCacheInvalidator != nil { + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, id) + if err == nil { + groupKeys = keys + } + } + affectedUserIDs, err := s.groupRepo.DeleteCascade(ctx, id) if err != nil { return err @@ -697,6 +724,11 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { } }() } + if s.authCacheInvalidator != nil { + for _, key := range groupKeys { + s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, key) + } + } return nil } diff --git a/backend/internal/service/admin_service_update_balance_test.go b/backend/internal/service/admin_service_update_balance_test.go new file mode 100644 index 00000000..d3b3c700 --- /dev/null +++ b/backend/internal/service/admin_service_update_balance_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type balanceUserRepoStub struct { + *userRepoStub + updateErr error + updated []*User +} + +func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error { + if s.updateErr != nil { + return s.updateErr + } + if user == nil { + return nil + } + clone := *user + s.updated = append(s.updated, &clone) + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +type balanceRedeemRepoStub struct { + *redeemRepoStub + created []*RedeemCode +} + +func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error { + if code == nil { + return nil + } + clone := *code + s.created = append(s.created, &clone) + return nil +} + +type authCacheInvalidatorStub struct { + userIDs []int64 + groupIDs []int64 + keys []string +} 
+ +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) { + s.keys = append(s.keys, key) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + s.userIDs = append(s.userIDs, userID) +} + +func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + s.groupIDs = append(s.groupIDs, groupID) +} + +func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "") + require.NoError(t, err) + require.Equal(t, []int64{7}, invalidator.userIDs) + require.Len(t, redeemRepo.created, 1) +} + +func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) { + baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}} + repo := &balanceUserRepoStub{userRepoStub: baseRepo} + redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: redeemRepo, + authCacheInvalidator: invalidator, + } + + _, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "") + require.NoError(t, err) + require.Empty(t, invalidator.userIDs) + require.Empty(t, redeemRepo.created) +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go new file mode 100644 index 00000000..7ce9a8a2 --- /dev/null +++ b/backend/internal/service/api_key_auth_cache.go @@ -0,0 +1,46 @@ +package service + +// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段) +type 
APIKeyAuthSnapshot struct { + APIKeyID int64 `json:"api_key_id"` + UserID int64 `json:"user_id"` + GroupID *int64 `json:"group_id,omitempty"` + Status string `json:"status"` + IPWhitelist []string `json:"ip_whitelist,omitempty"` + IPBlacklist []string `json:"ip_blacklist,omitempty"` + User APIKeyAuthUserSnapshot `json:"user"` + Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"` +} + +// APIKeyAuthUserSnapshot 用户快照 +type APIKeyAuthUserSnapshot struct { + ID int64 `json:"id"` + Status string `json:"status"` + Role string `json:"role"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` +} + +// APIKeyAuthGroupSnapshot 分组快照 +type APIKeyAuthGroupSnapshot struct { + ID int64 `json:"id"` + Name string `json:"name"` + Platform string `json:"platform"` + Status string `json:"status"` + SubscriptionType string `json:"subscription_type"` + RateMultiplier float64 `json:"rate_multiplier"` + DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"` + WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"` + MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"` + ImagePrice1K *float64 `json:"image_price_1k,omitempty"` + ImagePrice2K *float64 `json:"image_price_2k,omitempty"` + ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` +} + +// APIKeyAuthCacheEntry 缓存条目,支持负缓存 +type APIKeyAuthCacheEntry struct { + NotFound bool `json:"not_found"` + Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"` +} diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go new file mode 100644 index 00000000..dfc55eeb --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -0,0 +1,269 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/dgraph-io/ristretto" +) + +type apiKeyAuthCacheConfig struct { + l1Size int + l1TTL time.Duration + l2TTL time.Duration + negativeTTL time.Duration + jitterPercent int + singleflight bool +} + +var ( + jitterRandMu sync.Mutex + // 认证缓存抖动使用独立随机源,避免全局 Seed + jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) +) + +func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { + if cfg == nil { + return apiKeyAuthCacheConfig{} + } + auth := cfg.APIKeyAuth + return apiKeyAuthCacheConfig{ + l1Size: auth.L1Size, + l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second, + l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second, + negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second, + jitterPercent: auth.JitterPercent, + singleflight: auth.Singleflight, + } +} + +func (c apiKeyAuthCacheConfig) l1Enabled() bool { + return c.l1Size > 0 && c.l1TTL > 0 +} + +func (c apiKeyAuthCacheConfig) l2Enabled() bool { + return c.l2TTL > 0 +} + +func (c apiKeyAuthCacheConfig) negativeEnabled() bool { + return c.negativeTTL > 0 +} + +func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return ttl + } + if c.jitterPercent <= 0 { + return ttl + } + percent := c.jitterPercent + if percent > 100 { + percent = 100 + } + delta := float64(percent) / 100 + jitterRandMu.Lock() + randVal := jitterRand.Float64() + jitterRandMu.Unlock() + factor := 1 - delta + randVal*(2*delta) + if factor <= 0 { + return ttl + } + return time.Duration(float64(ttl) * factor) +} + +func (s *APIKeyService) initAuthCache(cfg *config.Config) { + s.authCfg = newAPIKeyAuthCacheConfig(cfg) + if !s.authCfg.l1Enabled() { + return + } + cache, err := ristretto.NewCache(&ristretto.Config{ + NumCounters: int64(s.authCfg.l1Size) * 10, + MaxCost: int64(s.authCfg.l1Size), + BufferItems: 64, + }) + if err != nil { + return + } + s.authCacheL1 = cache +} + +func (s *APIKeyService) authCacheKey(key string) string { + sum := sha256.Sum256([]byte(key)) + 
return hex.EncodeToString(sum[:]) +} + +func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) { + if s.authCacheL1 != nil { + if val, ok := s.authCacheL1.Get(cacheKey); ok { + if entry, ok := val.(*APIKeyAuthCacheEntry); ok { + return entry, true + } + } + } + if s.cache == nil || !s.authCfg.l2Enabled() { + return nil, false + } + entry, err := s.cache.GetAuthCache(ctx, cacheKey) + if err != nil { + return nil, false + } + s.setAuthCacheL1(cacheKey, entry) + return entry, true +} + +func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) { + if s.authCacheL1 == nil || entry == nil { + return + } + ttl := s.authCfg.l1TTL + if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl { + ttl = s.authCfg.negativeTTL + } + ttl = s.authCfg.jitterTTL(ttl) + _ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl) +} + +func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) { + if entry == nil { + return + } + s.setAuthCacheL1(cacheKey, entry) + if s.cache == nil || !s.authCfg.l2Enabled() { + return + } + _ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl)) +} + +func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { + if s.authCacheL1 != nil { + s.authCacheL1.Del(cacheKey) + } + if s.cache == nil { + return + } + _ = s.cache.DeleteAuthCache(ctx, cacheKey) +} + +func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) + if err != nil { + if errors.Is(err, ErrAPIKeyNotFound) { + entry := &APIKeyAuthCacheEntry{NotFound: true} + if s.authCfg.negativeEnabled() { + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL) + } + return entry, nil + } + return nil, fmt.Errorf("get api key: %w", err) + } + apiKey.Key = key + snapshot := 
s.snapshotFromAPIKey(apiKey) + if snapshot == nil { + return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) + } + entry := &APIKeyAuthCacheEntry{Snapshot: snapshot} + s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL) + return entry, nil +} + +func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) { + if entry == nil { + return nil, false, nil + } + if entry.NotFound { + return nil, true, ErrAPIKeyNotFound + } + if entry.Snapshot == nil { + return nil, false, nil + } + return s.snapshotToAPIKey(key, entry.Snapshot), true, nil +} + +func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { + if apiKey == nil || apiKey.User == nil { + return nil + } + snapshot := &APIKeyAuthSnapshot{ + APIKeyID: apiKey.ID, + UserID: apiKey.UserID, + GroupID: apiKey.GroupID, + Status: apiKey.Status, + IPWhitelist: apiKey.IPWhitelist, + IPBlacklist: apiKey.IPBlacklist, + User: APIKeyAuthUserSnapshot{ + ID: apiKey.User.ID, + Status: apiKey.User.Status, + Role: apiKey.User.Role, + Balance: apiKey.User.Balance, + Concurrency: apiKey.User.Concurrency, + }, + } + if apiKey.Group != nil { + snapshot.Group = &APIKeyAuthGroupSnapshot{ + ID: apiKey.Group.ID, + Name: apiKey.Group.Name, + Platform: apiKey.Group.Platform, + Status: apiKey.Group.Status, + SubscriptionType: apiKey.Group.SubscriptionType, + RateMultiplier: apiKey.Group.RateMultiplier, + DailyLimitUSD: apiKey.Group.DailyLimitUSD, + WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD, + MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD, + ImagePrice1K: apiKey.Group.ImagePrice1K, + ImagePrice2K: apiKey.Group.ImagePrice2K, + ImagePrice4K: apiKey.Group.ImagePrice4K, + ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, + FallbackGroupID: apiKey.Group.FallbackGroupID, + } + } + return snapshot +} + +func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey { + if snapshot == nil { + return nil + } + apiKey := &APIKey{ + ID: 
snapshot.APIKeyID, + UserID: snapshot.UserID, + GroupID: snapshot.GroupID, + Key: key, + Status: snapshot.Status, + IPWhitelist: snapshot.IPWhitelist, + IPBlacklist: snapshot.IPBlacklist, + User: &User{ + ID: snapshot.User.ID, + Status: snapshot.User.Status, + Role: snapshot.User.Role, + Balance: snapshot.User.Balance, + Concurrency: snapshot.User.Concurrency, + }, + } + if snapshot.Group != nil { + apiKey.Group = &Group{ + ID: snapshot.Group.ID, + Name: snapshot.Group.Name, + Platform: snapshot.Group.Platform, + Status: snapshot.Group.Status, + Hydrated: true, + SubscriptionType: snapshot.Group.SubscriptionType, + RateMultiplier: snapshot.Group.RateMultiplier, + DailyLimitUSD: snapshot.Group.DailyLimitUSD, + WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD, + MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD, + ImagePrice1K: snapshot.Group.ImagePrice1K, + ImagePrice2K: snapshot.Group.ImagePrice2K, + ImagePrice4K: snapshot.Group.ImagePrice4K, + ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, + FallbackGroupID: snapshot.Group.FallbackGroupID, + } + } + return apiKey +} diff --git a/backend/internal/service/api_key_auth_cache_invalidate.go b/backend/internal/service/api_key_auth_cache_invalidate.go new file mode 100644 index 00000000..aeb58bcc --- /dev/null +++ b/backend/internal/service/api_key_auth_cache_invalidate.go @@ -0,0 +1,48 @@ +package service + +import "context" + +// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) { + if key == "" { + return + } + cacheKey := s.authCacheKey(key) + s.deleteAuthCache(ctx, cacheKey) +} + +// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存 +func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) { + if userID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存 +func (s 
*APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) { + if groupID <= 0 { + return + } + keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID) + if err != nil { + return + } + s.deleteAuthCacheByKeys(ctx, keys) +} + +func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) { + if len(keys) == 0 { + return + } + for _, key := range keys { + if key == "" { + continue + } + s.deleteAuthCache(ctx, s.authCacheKey(key)) + } +} diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index 578afc1a..ecc570c7 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -12,6 +12,8 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" + "github.com/dgraph-io/ristretto" + "golang.org/x/sync/singleflight" ) var ( @@ -31,9 +33,11 @@ const ( type APIKeyRepository interface { Create(ctx context.Context, key *APIKey) error GetByID(ctx context.Context, id int64) (*APIKey, error) - // GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证 - GetOwnerID(ctx context.Context, id int64) (int64, error) + // GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景 + GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) GetByKey(ctx context.Context, key string) (*APIKey, error) + // GetByKeyForAuth 认证专用查询,返回最小字段集 + GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) Update(ctx context.Context, key *APIKey) error Delete(ctx context.Context, id int64) error @@ -45,6 +49,8 @@ type APIKeyRepository interface { SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) CountByGroupID(ctx context.Context, groupID int64) (int64, error) + ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) + 
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) } // APIKeyCache defines cache operations for API key service @@ -55,6 +61,17 @@ type APIKeyCache interface { IncrementDailyUsage(ctx context.Context, apiKey string) error SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error + + GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) + SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error + DeleteAuthCache(ctx context.Context, key string) error +} + +// APIKeyAuthCacheInvalidator 提供认证缓存失效能力 +type APIKeyAuthCacheInvalidator interface { + InvalidateAuthCacheByKey(ctx context.Context, key string) + InvalidateAuthCacheByUserID(ctx context.Context, userID int64) + InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) } // CreateAPIKeyRequest 创建API Key请求 @@ -83,6 +100,9 @@ type APIKeyService struct { userSubRepo UserSubscriptionRepository cache APIKeyCache cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -94,7 +114,7 @@ func NewAPIKeyService( cache APIKeyCache, cfg *config.Config, ) *APIKeyService { - return &APIKeyService{ + svc := &APIKeyService{ apiKeyRepo: apiKeyRepo, userRepo: userRepo, groupRepo: groupRepo, @@ -102,6 +122,8 @@ func NewAPIKeyService( cache: cache, cfg: cfg, } + svc.initAuthCache(cfg) + return svc } // GenerateKey 生成随机API Key @@ -269,6 +291,8 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK return nil, fmt.Errorf("create api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } @@ -304,21 +328,49 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) // GetByKey 根据Key字符串获取API Key(用于认证) func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) { - // 尝试从Redis缓存获取 - cacheKey := fmt.Sprintf("apikey:%s", key) + 
cacheKey := s.authCacheKey(key) - // 这里可以添加Redis缓存逻辑,暂时直接查询数据库 - apiKey, err := s.apiKeyRepo.GetByKey(ctx, key) + if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok { + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + if s.authCfg.singleflight { + value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) { + return s.loadAuthCacheEntry(ctx, key, cacheKey) + }) + if err != nil { + return nil, err + } + entry, _ := value.(*APIKeyAuthCacheEntry) + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } else { + entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey) + if err != nil { + return nil, err + } + if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used { + if err != nil { + return nil, fmt.Errorf("get api key: %w", err) + } + return apiKey, nil + } + } + + apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } - - // 缓存到Redis(可选,TTL设置为5分钟) - if s.cache != nil { - // 这里可以序列化并缓存API Key - _ = cacheKey // 使用变量避免未使用错误 - } - + apiKey.Key = key return apiKey, nil } @@ -388,15 +440,14 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req return nil, fmt.Errorf("update api key: %w", err) } + s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + return apiKey, nil } // Delete 删除API Key -// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证, -// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能 func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error { - // 仅获取所有者 ID 用于权限验证,而非加载完整对象 - ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id) + key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id) if err != nil { return fmt.Errorf("get api key: %w", err) } @@ -406,10 +457,11 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, 
userID int64) erro return ErrInsufficientPerms } - // 清除Redis缓存(使用 ownerID 而非 apiKey.UserID) + // 清除Redis缓存(使用 userID 而非 apiKey.UserID) if s.cache != nil { - _ = s.cache.DeleteCreateAttemptCount(ctx, ownerID) + _ = s.cache.DeleteCreateAttemptCount(ctx, userID) } + s.InvalidateAuthCacheByKey(ctx, key) if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go new file mode 100644 index 00000000..3314ca8d --- /dev/null +++ b/backend/internal/service/api_key_service_cache_test.go @@ -0,0 +1,417 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +type authRepoStub struct { + getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error) + listKeysByUserID func(ctx context.Context, userID int64) ([]string, error) + listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error) +} + +func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error { + panic("unexpected Create call") +} + +func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + panic("unexpected GetByID call") +} + +func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + panic("unexpected GetKeyAndOwnerID call") +} + +func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKey call") +} + +func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + if s.getByKeyForAuth == nil { + panic("unexpected GetByKeyForAuth call") + } + return s.getByKeyForAuth(ctx, key) +} + +func (s *authRepoStub) Update(ctx context.Context, 
key *APIKey) error { + panic("unexpected Update call") +} + +func (s *authRepoStub) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByUserID call") +} + +func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) { + panic("unexpected VerifyOwnership call") +} + +func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) { + panic("unexpected CountByUserID call") +} + +func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) { + panic("unexpected ExistsByKey call") +} + +func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) { + panic("unexpected ListByGroupID call") +} + +func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) { + panic("unexpected SearchAPIKeys call") +} + +func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected ClearGroupIDByGroupID call") +} + +func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected CountByGroupID call") +} + +func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + if s.listKeysByUserID == nil { + panic("unexpected ListKeysByUserID call") + } + return s.listKeysByUserID(ctx, userID) +} + +func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + if s.listKeysByGroupID == nil { + panic("unexpected ListKeysByGroupID call") + } + return s.listKeysByGroupID(ctx, groupID) +} + +type authCacheStub struct { + getAuthCache func(ctx context.Context, key string) 
(*APIKeyAuthCacheEntry, error) + setAuthKeys []string + deleteAuthKeys []string +} + +func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error { + return nil +} + +func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error { + return nil +} + +func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error { + return nil +} + +func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + if s.getAuthCache == nil { + return nil, redis.Nil + } + return s.getAuthCache(ctx, key) +} + +func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + s.setAuthKeys = append(s.setAuthKeys, key) + return nil +} + +func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + +func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + groupID := int64(9) + cacheEntry := &APIKeyAuthCacheEntry{ + Snapshot: &APIKeyAuthSnapshot{ + APIKeyID: 1, + UserID: 2, + GroupID: &groupID, + Status: StatusActive, + User: APIKeyAuthUserSnapshot{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 10, + Concurrency: 3, + }, + Group: &APIKeyAuthGroupSnapshot{ + ID: groupID, + Name: "g", + Platform: 
PlatformAnthropic, + Status: StatusActive, + SubscriptionType: SubscriptionTypeStandard, + RateMultiplier: 1, + }, + }, + } + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return cacheEntry, nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k1") + require.NoError(t, err) + require.Equal(t, int64(1), apiKey.ID) + require.Equal(t, int64(2), apiKey.User.ID) + require.Equal(t, groupID, apiKey.Group.ID) +} + +func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, errors.New("unexpected repo call") + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return &APIKeyAuthCacheEntry{NotFound: true}, nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) +} + +func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return &APIKey{ + ID: 5, + UserID: 7, + Status: StatusActive, + User: &User{ + ID: 7, + Status: StatusActive, + Role: RoleUser, + Balance: 12, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + apiKey, err := svc.GetByKey(context.Background(), "k2") + require.NoError(t, err) + require.Equal(t, int64(5), apiKey.ID) + require.Len(t, 
cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + return &APIKey{ + ID: 21, + UserID: 3, + Status: StatusActive, + User: &User{ + ID: 3, + Status: StatusActive, + Role: RoleUser, + Balance: 5, + Concurrency: 2, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L1Size: 1000, + L1TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + require.NotNil(t, svc.authCacheL1) + + _, err := svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + svc.authCacheL1.Wait() + cacheKey := svc.authCacheKey("k-l1") + _, ok := svc.authCacheL1.Get(cacheKey) + require.True(t, ok) + _, err = svc.GetByKey(context.Background(), "k-l1") + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} + +func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByUserID(context.Background(), 7) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) { + return []string{"k1", "k2"}, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + 
svc.InvalidateAuthCacheByGroupID(context.Background(), 9) + require.Len(t, cache.deleteAuthKeys, 2) +} + +func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) { + return nil, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + svc.InvalidateAuthCacheByKey(context.Background(), "k1") + require.Len(t, cache.deleteAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + return nil, ErrAPIKeyNotFound + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + L2TTLSeconds: 60, + NegativeTTLSeconds: 30, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, redis.Nil + } + + _, err := svc.GetByKey(context.Background(), "missing") + require.ErrorIs(t, err, ErrAPIKeyNotFound) + require.Len(t, cache.setAuthKeys, 1) +} + +func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { + var calls int32 + cache := &authCacheStub{} + repo := &authRepoStub{ + getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) { + atomic.AddInt32(&calls, 1) + time.Sleep(50 * time.Millisecond) + return &APIKey{ + ID: 11, + UserID: 2, + Status: StatusActive, + User: &User{ + ID: 2, + Status: StatusActive, + Role: RoleUser, + Balance: 1, + Concurrency: 1, + }, + }, nil + }, + } + cfg := &config.Config{ + APIKeyAuth: config.APIKeyAuthCacheConfig{ + Singleflight: true, + }, + } + svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + + start := make(chan struct{}) + wg := sync.WaitGroup{} + errs := 
make([]error, 5) + for i := 0; i < 5; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-start + _, err := svc.GetByKey(context.Background(), "k1") + errs[idx] = err + }(i) + } + close(start) + wg.Wait() + + for _, err := range errs { + require.NoError(t, err) + } + require.Equal(t, int32(1), atomic.LoadInt32(&calls)) +} diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 7d04c5ac..32ae884e 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -20,13 +20,12 @@ import ( // 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。 // // 设计说明: -// - ownerID: 模拟 GetOwnerID 返回的所有者 ID -// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound) +// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误 // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - ownerID int64 // GetOwnerID 的返回值 - ownerErr error // GetOwnerID 的错误返回值 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 deleteErr error // Delete 的错误返回值 deletedIDs []int64 // 记录已删除的 API Key ID 列表 } @@ -38,19 +37,34 @@ func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error { } func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) { + if s.getByIDErr != nil { + return nil, s.getByIDErr + } + if s.apiKey != nil { + clone := *s.apiKey + return &clone, nil + } panic("unexpected GetByID call") } -// GetOwnerID 返回预设的所有者 ID 或错误。 -// 这是 Delete 方法调用的第一个仓储方法,用于验证调用者是否为 API Key 的所有者。 -func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error) { - return s.ownerID, s.ownerErr +func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) { + if s.getByIDErr != nil { + return "", 0, s.getByIDErr + } + if s.apiKey != nil { + return s.apiKey.Key, s.apiKey.UserID, nil + } + return "", 0, 
ErrAPIKeyNotFound } func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) { panic("unexpected GetByKey call") } +func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) { + panic("unexpected GetByKeyForAuth call") +} + func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error { panic("unexpected Update call") } @@ -96,13 +110,22 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int panic("unexpected CountByGroupID call") } +func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) { + panic("unexpected ListKeysByUserID call") +} + +func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) { + panic("unexpected ListKeysByGroupID call") +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // // 设计说明: // - invalidated: 记录被清除缓存的用户 ID 列表 type apiKeyCacheStub struct { - invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID + deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key } // GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制 @@ -132,15 +155,30 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string return nil } +func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { + return nil, nil +} + +func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error { + return nil +} + +func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error { + s.deleteAuthKeys = append(s.deleteAuthKeys, key) + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 1 +// - GetKeyAndOwnerID 返回所有者 ID 为 1 // - 调用者 userID 为 2(不匹配) // - 返回 ErrInsufficientPerms 错误 // - Delete 方法不被调用 // - 缓存不被清除 func 
TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 1} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -148,17 +186,20 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) { require.ErrorIs(t, err, ErrInsufficientPerms) require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用 require.Empty(t, cache.invalidated) // 验证缓存未被清除 + require.Empty(t, cache.deleteAuthKeys) } // TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。 // 预期行为: -// - GetOwnerID 返回所有者 ID 为 7 +// - GetKeyAndOwnerID 返回所有者 ID 为 7 // - 调用者 userID 为 7(匹配) // - Delete 成功执行 // - 缓存被正确清除(使用 ownerID) // - 返回 nil 错误 func TestApiKeyService_Delete_Success(t *testing.T) { - repo := &apiKeyRepoStub{ownerID: 7} + repo := &apiKeyRepoStub{ + apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"}, + } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -166,16 +207,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) { require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 // 预期行为: -// - GetOwnerID 返回 ErrAPIKeyNotFound 错误 +// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误 // - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装) // - Delete 方法不被调用 // - 缓存不被清除 func TestApiKeyService_Delete_NotFound(t *testing.T) { - repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound} + repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound} cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} @@ -183,18 +225,19 @@ func TestApiKeyService_Delete_NotFound(t *testing.T) { require.ErrorIs(t, err, ErrAPIKeyNotFound) require.Empty(t, repo.deletedIDs) require.Empty(t, cache.invalidated) + require.Empty(t, 
cache.deleteAuthKeys) } // TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。 // 预期行为: -// - GetOwnerID 返回正确的所有者 ID +// - GetKeyAndOwnerID 返回正确的所有者 ID // - 所有权验证通过 // - 缓存被清除(在删除之前) // - Delete 被调用但返回错误 // - 返回包含 "delete api key" 的错误信息 func TestApiKeyService_Delete_DeleteFails(t *testing.T) { repo := &apiKeyRepoStub{ - ownerID: 3, + apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"}, deleteErr: errors.New("delete failed"), } cache := &apiKeyCacheStub{} @@ -205,4 +248,5 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) { require.ErrorContains(t, err, "delete api key") require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用 require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败) + require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) } diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 2f0f4975..a9214c82 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -50,13 +50,15 @@ type UpdateGroupRequest struct { // GroupService 分组管理服务 type GroupService struct { - groupRepo GroupRepository + groupRepo GroupRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewGroupService 创建分组服务实例 -func NewGroupService(groupRepo GroupRepository) *GroupService { +func NewGroupService(groupRepo GroupRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *GroupService { return &GroupService{ - groupRepo: groupRepo, + groupRepo: groupRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -155,6 +157,9 @@ func (s *GroupService) Update(ctx context.Context, id int64, req UpdateGroupRequ if err := s.groupRepo.Update(ctx, group); err != nil { return nil, fmt.Errorf("update group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return group, nil } @@ -170,6 +175,9 @@ func (s *GroupService) Delete(ctx context.Context, id int64) error { if err := s.groupRepo.Delete(ctx, 
id); err != nil { return fmt.Errorf("delete group: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } return nil } diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index 08fa40b5..a7a36760 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -55,13 +55,15 @@ type ChangePasswordRequest struct { // UserService 用户服务 type UserService struct { - userRepo UserRepository + userRepo UserRepository + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUserService 创建用户服务实例 -func NewUserService(userRepo UserRepository) *UserService { +func NewUserService(userRepo UserRepository, authCacheInvalidator APIKeyAuthCacheInvalidator) *UserService { return &UserService{ - userRepo: userRepo, + userRepo: userRepo, + authCacheInvalidator: authCacheInvalidator, } } @@ -89,6 +91,7 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err != nil { return nil, fmt.Errorf("get user: %w", err) } + oldConcurrency := user.Concurrency // 更新字段 if req.Email != nil { @@ -114,6 +117,9 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat if err := s.userRepo.Update(ctx, user); err != nil { return nil, fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil && user.Concurrency != oldConcurrency { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return user, nil } @@ -169,6 +175,9 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl if err := s.userRepo.UpdateBalance(ctx, userID, amount); err != nil { return fmt.Errorf("update balance: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -177,6 +186,9 @@ func (s *UserService) UpdateConcurrency(ctx context.Context, userID int64, concu if err := s.userRepo.UpdateConcurrency(ctx, 
userID, concurrency); err != nil { return fmt.Errorf("update concurrency: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -192,6 +204,9 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str if err := s.userRepo.Update(ctx, user); err != nil { return fmt.Errorf("update user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } @@ -201,5 +216,8 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { if err := s.userRepo.Delete(ctx, userID); err != nil { return fmt.Errorf("delete user: %w", err) } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } return nil } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 512d2550..54c37b54 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -77,12 +77,18 @@ func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountReposi return svc } +// ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 +func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + return apiKeyService +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services NewAuthService, NewUserService, NewAPIKeyService, + ProvideAPIKeyAuthCacheInvalidator, NewGroupService, NewAccountService, NewProxyService, diff --git a/config.yaml b/config.yaml index 54b591f3..ecd7dfc2 100644 --- a/config.yaml +++ b/config.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # 
L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 87ff3148..87abffa0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -170,6 +170,30 @@ gateway: # 允许在特定 400 错误时进行故障转移(默认:关闭) failover_on_400: false +# ============================================================================= +# API Key Auth Cache Configuration +# API Key 认证缓存配置 +# ============================================================================= +api_key_auth_cache: + # L1 cache size (entries), in-process LRU/TTL cache + # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 + l1_size: 65535 + # L1 cache TTL (seconds) + # L1 缓存 TTL(秒) + l1_ttl_seconds: 15 + # L2 cache TTL (seconds), stored in Redis + # L2 缓存 TTL(秒),Redis 中存储 + l2_ttl_seconds: 300 + # Negative cache TTL (seconds) + # 负缓存 TTL(秒) + negative_ttl_seconds: 30 + # TTL jitter percent (0-100) + # TTL 抖动百分比(0-100) + jitter_percent: 10 + # Enable singleflight for cache misses + # 缓存未命中时启用 singleflight 合并回源 + singleflight: true + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 From cb3e08dda489b1e757146c1b5f64c3705fc7ed6b Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 22:52:13 +0800 Subject: [PATCH 10/23] =?UTF-8?q?fix(=E8=AE=A4=E8=AF=81):=20=E8=A1=A5?= 
=?UTF-8?q?=E9=BD=90=E4=BD=99=E9=A2=9D=E4=B8=8E=E5=88=A0=E9=99=A4=E5=9C=BA?= =?UTF-8?q?=E6=99=AF=E7=BC=93=E5=AD=98=E5=A4=B1=E6=95=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 为 Usage/Promo/Redeem 注入认证缓存失效逻辑 删除用户与分组前先失效认证缓存降低窗口 补充回归测试验证失效调用 测试: make test --- backend/cmd/server/wire_gen.go | 8 ++-- backend/internal/server/api_contract_test.go | 2 +- .../service/auth_cache_invalidation_test.go | 31 ++++++++++++ backend/internal/service/group_service.go | 6 +-- backend/internal/service/promo_service.go | 28 +++++++---- backend/internal/service/redeem_service.go | 47 ++++++++++++------- backend/internal/service/usage_service.go | 27 ++++++++--- backend/internal/service/user_service.go | 6 +-- 8 files changed, 113 insertions(+), 42 deletions(-) create mode 100644 backend/internal/service/auth_cache_invalidation_test.go diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 561b0aeb..18ec84c5 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -55,24 +55,24 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) - promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client) - authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) 
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) + promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) + authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService, promoService) userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) - usageService := service.NewUsageService(usageLogRepository, userRepository, client) + usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) - redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client) + redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardService := service.NewDashboardService(usageLogRepository) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 04cc5c2e..abcf0e6c 100644 --- a/backend/internal/server/api_contract_test.go +++ 
b/backend/internal/server/api_contract_test.go @@ -394,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() - usageService := service.NewUsageService(usageRepo, userRepo, nil) + usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go new file mode 100644 index 00000000..3b4217c6 --- /dev/null +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -0,0 +1,31 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUsageService_InvalidateUsageCaches(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &UsageService{authCacheInvalidator: invalidator} + + svc.invalidateUsageCaches(context.Background(), 7, false) + require.Empty(t, invalidator.userIDs) + + svc.invalidateUsageCaches(context.Background(), 7, true) + require.Equal(t, []int64{7}, invalidator.userIDs) +} + +func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { + invalidator := &authCacheInvalidatorStub{} + svc := &RedeemService{authCacheInvalidator: invalidator} + + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + + require.Equal(t, []int64{11, 11}, invalidator.userIDs) +} diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index a9214c82..324f347b 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -172,12 +172,12 @@ func (s *GroupService) Delete(ctx 
context.Context, id int64) error { return fmt.Errorf("get group: %w", err) } - if err := s.groupRepo.Delete(ctx, id); err != nil { - return fmt.Errorf("delete group: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) } + if err := s.groupRepo.Delete(ctx, id); err != nil { + return fmt.Errorf("delete group: %w", err) + } return nil } diff --git a/backend/internal/service/promo_service.go b/backend/internal/service/promo_service.go index 9acd5868..5ff63bdc 100644 --- a/backend/internal/service/promo_service.go +++ b/backend/internal/service/promo_service.go @@ -24,10 +24,11 @@ var ( // PromoService 优惠码服务 type PromoService struct { - promoRepo PromoCodeRepository - userRepo UserRepository - billingCacheService *BillingCacheService - entClient *dbent.Client + promoRepo PromoCodeRepository + userRepo UserRepository + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewPromoService 创建优惠码服务实例 @@ -36,12 +37,14 @@ func NewPromoService( userRepo UserRepository, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *PromoService { return &PromoService{ - promoRepo: promoRepo, - userRepo: userRepo, - billingCacheService: billingCacheService, - entClient: entClient, + promoRepo: promoRepo, + userRepo: userRepo, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -145,6 +148,8 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return fmt.Errorf("commit transaction: %w", err) } + s.invalidatePromoCaches(ctx, userID, promoCode.BonusAmount) + // 失效余额缓存 if s.billingCacheService != nil { go func() { @@ -157,6 +162,13 @@ func (s *PromoService) ApplyPromoCode(ctx context.Context, userID int64, code st return nil } +func (s *PromoService) invalidatePromoCaches(ctx context.Context, userID 
int64, bonusAmount float64) { + if bonusAmount == 0 || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GenerateRandomCode 生成随机优惠码 func (s *PromoService) GenerateRandomCode() (string, error) { bytes := make([]byte, 8) diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b6324235..81767aa9 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -68,12 +68,13 @@ type RedeemCodeResponse struct { // RedeemService 兑换码服务 type RedeemService struct { - redeemRepo RedeemCodeRepository - userRepo UserRepository - subscriptionService *SubscriptionService - cache RedeemCache - billingCacheService *BillingCacheService - entClient *dbent.Client + redeemRepo RedeemCodeRepository + userRepo UserRepository + subscriptionService *SubscriptionService + cache RedeemCache + billingCacheService *BillingCacheService + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewRedeemService 创建兑换码服务实例 @@ -84,14 +85,16 @@ func NewRedeemService( cache RedeemCache, billingCacheService *BillingCacheService, entClient *dbent.Client, + authCacheInvalidator APIKeyAuthCacheInvalidator, ) *RedeemService { return &RedeemService{ - redeemRepo: redeemRepo, - userRepo: userRepo, - subscriptionService: subscriptionService, - cache: cache, - billingCacheService: billingCacheService, - entClient: entClient, + redeemRepo: redeemRepo, + userRepo: userRepo, + subscriptionService: subscriptionService, + cache: cache, + billingCacheService: billingCacheService, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -324,18 +327,30 @@ func (s *RedeemService) Redeem(ctx context.Context, userID int64, code string) ( // invalidateRedeemCaches 失效兑换相关的缓存 func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64, redeemCode *RedeemCode) { - if s.billingCacheService == nil { - 
return - } - switch redeemCode.Type { case RedeemTypeBalance: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } go func() { cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.billingCacheService.InvalidateUserBalance(cacheCtx, userID) }() + case RedeemTypeConcurrency: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } + if s.billingCacheService == nil { + return + } case RedeemTypeSubscription: + if s.billingCacheService == nil { + return + } if redeemCode.GroupID != nil { groupID := *redeemCode.GroupID go func() { diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index 10a294ae..aa0a5b87 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -54,17 +54,19 @@ type UsageStats struct { // UsageService 使用统计服务 type UsageService struct { - usageRepo UsageLogRepository - userRepo UserRepository - entClient *dbent.Client + usageRepo UsageLogRepository + userRepo UserRepository + entClient *dbent.Client + authCacheInvalidator APIKeyAuthCacheInvalidator } // NewUsageService 创建使用统计服务实例 -func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService { +func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client, authCacheInvalidator APIKeyAuthCacheInvalidator) *UsageService { return &UsageService{ - usageRepo: usageRepo, - userRepo: userRepo, - entClient: entClient, + usageRepo: usageRepo, + userRepo: userRepo, + entClient: entClient, + authCacheInvalidator: authCacheInvalidator, } } @@ -118,10 +120,12 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } // 扣除用户余额 + balanceUpdated := false if inserted && req.ActualCost > 0 { if err := 
s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil { return nil, fmt.Errorf("update user balance: %w", err) } + balanceUpdated = true } if tx != nil { @@ -130,9 +134,18 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (* } } + s.invalidateUsageCaches(ctx, req.UserID, balanceUpdated) + return usageLog, nil } +func (s *UsageService) invalidateUsageCaches(ctx context.Context, userID int64, balanceUpdated bool) { + if !balanceUpdated || s.authCacheInvalidator == nil { + return + } + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) +} + // GetByID 根据ID获取使用日志 func (s *UsageService) GetByID(ctx context.Context, id int64) (*UsageLog, error) { log, err := s.usageRepo.GetByID(ctx, id) diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index a7a36760..1734914a 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -213,11 +213,11 @@ func (s *UserService) UpdateStatus(ctx context.Context, userID int64, status str // Delete 删除用户(管理员功能) func (s *UserService) Delete(ctx context.Context, userID int64) error { - if err := s.userRepo.Delete(ctx, userID); err != nil { - return fmt.Errorf("delete user: %w", err) - } if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) } + if err := s.userRepo.Delete(ctx, userID); err != nil { + return fmt.Errorf("delete user: %w", err) + } return nil } From d75cd820b0bc514ada6318a001ecced4fa0bf9b5 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 10 Jan 2026 23:14:20 +0800 Subject: [PATCH 11/23] =?UTF-8?q?fix(=E8=AE=A4=E8=AF=81):=20=E8=AE=A2?= =?UTF-8?q?=E9=98=85=E5=85=91=E6=8D=A2=E5=A4=B1=E6=95=88=E8=AE=A4=E8=AF=81?= =?UTF-8?q?=E7=BC=93=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 订阅兑换后同步失效认证缓存避免授权快照滞后 补充单测覆盖订阅兑换的失效场景 测试: go test ./... 
-tags=unit --- backend/internal/service/auth_cache_invalidation_test.go | 4 +++- backend/internal/service/redeem_service.go | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/auth_cache_invalidation_test.go b/backend/internal/service/auth_cache_invalidation_test.go index 3b4217c6..b6e56177 100644 --- a/backend/internal/service/auth_cache_invalidation_test.go +++ b/backend/internal/service/auth_cache_invalidation_test.go @@ -26,6 +26,8 @@ func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) { svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance}) svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency}) + groupID := int64(3) + svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID}) - require.Equal(t, []int64{11, 11}, invalidator.userIDs) + require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs) } diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index 81767aa9..ff52dc47 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -348,6 +348,9 @@ func (s *RedeemService) invalidateRedeemCaches(ctx context.Context, userID int64 return } case RedeemTypeSubscription: + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, userID) + } if s.billingCacheService == nil { return } From ab5839b461e82bb4784394a0d6be6a17243d0fab Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 15:00:16 +0800 Subject: [PATCH 12/23] =?UTF-8?q?fix(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=BC=93=E5=AD=98=E7=A8=B3=E5=AE=9A=E6=80=A7?= =?UTF-8?q?=E5=B9=B6=E8=A1=A5=E5=85=85=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/repository/dashboard_cache.go | 7 +++ 
.../repository/dashboard_cache_test.go | 28 ++++++++++++ backend/internal/service/dashboard_service.go | 34 ++++++++++++-- .../service/dashboard_service_test.go | 44 +++++++++++++++++++ 4 files changed, 109 insertions(+), 4 deletions(-) create mode 100644 backend/internal/repository/dashboard_cache_test.go diff --git a/backend/internal/repository/dashboard_cache.go b/backend/internal/repository/dashboard_cache.go index ec6ef25c..f996cd68 100644 --- a/backend/internal/repository/dashboard_cache.go +++ b/backend/internal/repository/dashboard_cache.go @@ -22,6 +22,9 @@ func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardS if cfg != nil { prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix) } + if prefix != "" && !strings.HasSuffix(prefix, ":") { + prefix += ":" + } return &dashboardCache{ rdb: rdb, keyPrefix: prefix, @@ -49,3 +52,7 @@ func (c *dashboardCache) buildKey() string { } return c.keyPrefix + dashboardStatsCacheKey } + +func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error { + return c.rdb.Del(ctx, c.buildKey()).Err() +} diff --git a/backend/internal/repository/dashboard_cache_test.go b/backend/internal/repository/dashboard_cache_test.go new file mode 100644 index 00000000..3bb0da4f --- /dev/null +++ b/backend/internal/repository/dashboard_cache_test.go @@ -0,0 +1,28 @@ +package repository + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestNewDashboardCacheKeyPrefix(t *testing.T) { + cache := NewDashboardCache(nil, &config.Config{ + Dashboard: config.DashboardCacheConfig{ + KeyPrefix: "prod", + }, + }) + impl, ok := cache.(*dashboardCache) + require.True(t, ok) + require.Equal(t, "prod:", impl.keyPrefix) + + cache = NewDashboardCache(nil, &config.Config{ + Dashboard: config.DashboardCacheConfig{ + KeyPrefix: "staging:", + }, + }) + impl, ok = cache.(*dashboardCache) + require.True(t, ok) + require.Equal(t, "staging:", impl.keyPrefix) +} 
diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index f56480d3..468135e3 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -26,6 +26,7 @@ var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中") type DashboardStatsCache interface { GetDashboardStats(ctx context.Context) (string, error) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error + DeleteDashboardStats(ctx context.Context) error } type dashboardStatsCacheEntry struct { @@ -115,10 +116,12 @@ func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usages var entry dashboardStatsCacheEntry if err := json.Unmarshal([]byte(data), &entry); err != nil { - return nil, false, err + s.evictDashboardStatsCache(err) + return nil, false, ErrDashboardStatsCacheMiss } if entry.Stats == nil { - return nil, false, errors.New("仪表盘缓存缺少统计数据") + s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据")) + return nil, false, ErrDashboardStatsCacheMiss } age := time.Since(time.Unix(entry.UpdatedAt, 0)) @@ -130,7 +133,9 @@ func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagesta if err != nil { return nil, err } - s.saveDashboardStatsCache(ctx, stats) + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + s.saveDashboardStatsCache(cacheCtx, stats) return stats, nil } @@ -153,7 +158,9 @@ func (s *DashboardService) refreshDashboardStatsAsync() { log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return } - s.saveDashboardStatsCache(ctx, stats) + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + s.saveDashboardStatsCache(cacheCtx, stats) }() } @@ -177,6 +184,25 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u } } +func (s *DashboardService) evictDashboardStatsCache(reason error) { + if s.cache == nil { + return + } + cacheCtx, cancel := s.cacheOperationContext() + defer cancel() + + if err := 
s.cache.DeleteDashboardStats(cacheCtx); err != nil { + log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err) + } + if reason != nil { + log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason) + } +} + +func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), s.refreshTimeout) +} + func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 21d7b580..17f46ead 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -3,6 +3,7 @@ package service import ( "context" "encoding/json" + "errors" "sync" "sync/atomic" "testing" @@ -38,8 +39,10 @@ func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.Dash type dashboardCacheStub struct { get func(ctx context.Context) (string, error) set func(ctx context.Context, data string, ttl time.Duration) error + del func(ctx context.Context) error getCalls int32 setCalls int32 + delCalls int32 lastSetMu sync.Mutex lastSet string } @@ -63,6 +66,14 @@ func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, return nil } +func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error { + atomic.AddInt32(&c.delCalls, 1) + if c.del != nil { + return c.del(ctx) + } + return nil +} + func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry { t.Helper() c.lastSetMu.Lock() @@ -187,3 +198,36 @@ func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { return atomic.LoadInt32(&cache.setCalls) >= 1 }, 1*time.Second, 10*time.Millisecond) } + +func 
TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) { + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "not-json", nil + }, + } + stats := &usagestats.DashboardStats{TotalUsers: 9} + repo := &usageRepoStub{stats: stats} + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + svc := NewDashboardService(repo, cache, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, stats, got) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls)) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls)) +} + +func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) { + cache := &dashboardCacheStub{ + get: func(ctx context.Context) (string, error) { + return "not-json", nil + }, + } + repo := &usageRepoStub{err: errors.New("db down")} + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + svc := NewDashboardService(repo, cache, cfg) + + _, err := svc.GetDashboardStats(context.Background()) + require.Error(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls)) +} From 1a869547d717aeb7f7c2bbd99a55370680ac616f Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 16:01:35 +0800 Subject: [PATCH 13/23] =?UTF-8?q?feat(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E5=BC=95=E5=85=A5=E9=A2=84=E8=81=9A=E5=90=88=E7=BB=9F=E8=AE=A1?= =?UTF-8?q?=E4=B8=8E=E8=81=9A=E5=90=88=E4=BD=9C=E4=B8=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 8 +- backend/internal/config/config.go | 115 +++++- backend/internal/config/config_test.go | 53 +++ .../handler/admin/dashboard_handler.go | 61 ++- .../pkg/usagestats/usage_log_types.go | 6 + .../repository/dashboard_aggregation_repo.go | 360 ++++++++++++++++++ backend/internal/repository/usage_log_repo.go | 59 ++- .../usage_log_repo_integration_test.go | 154 +++++++- 
backend/internal/repository/wire.go | 1 + backend/internal/server/routes/admin.go | 1 + .../service/dashboard_aggregation_service.go | 224 +++++++++++ backend/internal/service/dashboard_service.go | 78 +++- .../service/dashboard_service_test.go | 100 ++++- backend/internal/service/wire.go | 8 + ...034_usage_dashboard_aggregation_tables.sql | 77 ++++ .../035_usage_logs_partitioning.sql | 54 +++ config.yaml | 33 ++ deploy/config.example.yaml | 33 ++ frontend/src/types/index.ts | 3 + 19 files changed, 1366 insertions(+), 62 deletions(-) create mode 100644 backend/internal/repository/dashboard_aggregation_repo.go create mode 100644 backend/internal/service/dashboard_aggregation_service.go create mode 100644 backend/migrations/034_usage_dashboard_aggregation_tables.sql create mode 100644 backend/migrations/035_usage_logs_partitioning.sql diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 4fb8351e..e321576e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -67,6 +67,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userHandler := handler.NewUserHandler(userService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageLogRepository := repository.NewUsageLogRepository(client, db) + dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db) usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) redeemCodeRepository := repository.NewRedeemCodeRepository(client) @@ -76,8 +77,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redeemHandler := handler.NewRedeemHandler(redeemService) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig) - dashboardService := 
service.NewDashboardService(usageLogRepository, dashboardStatsCache, configConfig) - dashboardHandler := admin.NewDashboardHandler(dashboardService) + timingWheelService := service.ProvideTimingWheelService() + dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig) + dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig) + dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) accountRepository := repository.NewAccountRepository(client, db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) @@ -138,7 +141,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingService := service.NewBillingService(configConfig, pricingService) identityCache := repository.NewIdentityCache(redisClient) identityService := service.NewIdentityService(identityCache) - timingWheelService := service.ProvideTimingWheelService() deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 677d0c6e..b91a07c1 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -36,27 +36,28 @@ const ( ) type Config struct { - Server ServerConfig 
`mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + Concurrency ConcurrencyConfig 
`mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -387,6 +388,29 @@ type DashboardCacheConfig struct { StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"` } +// DashboardAggregationConfig 仪表盘预聚合配置 +type DashboardAggregationConfig struct { + // Enabled: 是否启用预聚合作业 + Enabled bool `mapstructure:"enabled"` + // IntervalSeconds: 聚合刷新间隔(秒) + IntervalSeconds int `mapstructure:"interval_seconds"` + // LookbackSeconds: 回看窗口(秒) + LookbackSeconds int `mapstructure:"lookback_seconds"` + // BackfillEnabled: 是否允许全量回填 + BackfillEnabled bool `mapstructure:"backfill_enabled"` + // Retention: 各表保留窗口(天) + Retention DashboardAggregationRetentionConfig `mapstructure:"retention"` + // RecomputeDays: 启动时重算最近 N 天 + RecomputeDays int `mapstructure:"recompute_days"` +} + +// DashboardAggregationRetentionConfig 预聚合保留窗口 +type DashboardAggregationRetentionConfig struct { + UsageLogsDays int `mapstructure:"usage_logs_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` +} + func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -697,6 +721,16 @@ func setDefaults() { viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30) viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30) + // Dashboard aggregation + viper.SetDefault("dashboard_aggregation.enabled", true) + viper.SetDefault("dashboard_aggregation.interval_seconds", 60) + viper.SetDefault("dashboard_aggregation.lookback_seconds", 120) + viper.SetDefault("dashboard_aggregation.backfill_enabled", false) + viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + 
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) + viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) + viper.SetDefault("dashboard_aggregation.recompute_days", 2) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", false) @@ -879,6 +913,45 @@ func (c *Config) Validate() error { return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative") } } + if c.DashboardAgg.Enabled { + if c.DashboardAgg.IntervalSeconds <= 0 { + return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive") + } + if c.DashboardAgg.LookbackSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") + } + if c.DashboardAgg.Retention.UsageLogsDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") + } + if c.DashboardAgg.Retention.HourlyDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") + } + if c.DashboardAgg.Retention.DailyDays <= 0 { + return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive") + } + if c.DashboardAgg.RecomputeDays < 0 { + return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") + } + } else { + if c.DashboardAgg.IntervalSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative") + } + if c.DashboardAgg.LookbackSeconds < 0 { + return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") + } + if c.DashboardAgg.Retention.UsageLogsDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") + } + if c.DashboardAgg.Retention.HourlyDays < 0 { + return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") + } + if c.DashboardAgg.Retention.DailyDays < 0 { + return 
fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative") + } + if c.DashboardAgg.RecomputeDays < 0 { + return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") + } + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 6cd95b1c..7fc34d64 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -205,3 +205,56 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { t.Fatalf("Validate() expected stats_ttl_seconds error, got: %v", err) } } + +func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.DashboardAgg.Enabled { + t.Fatalf("DashboardAgg.Enabled = false, want true") + } + if cfg.DashboardAgg.IntervalSeconds != 60 { + t.Fatalf("DashboardAgg.IntervalSeconds = %d, want 60", cfg.DashboardAgg.IntervalSeconds) + } + if cfg.DashboardAgg.LookbackSeconds != 120 { + t.Fatalf("DashboardAgg.LookbackSeconds = %d, want 120", cfg.DashboardAgg.LookbackSeconds) + } + if cfg.DashboardAgg.BackfillEnabled { + t.Fatalf("DashboardAgg.BackfillEnabled = true, want false") + } + if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { + t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) + } + if cfg.DashboardAgg.Retention.HourlyDays != 180 { + t.Fatalf("DashboardAgg.Retention.HourlyDays = %d, want 180", cfg.DashboardAgg.Retention.HourlyDays) + } + if cfg.DashboardAgg.Retention.DailyDays != 730 { + t.Fatalf("DashboardAgg.Retention.DailyDays = %d, want 730", cfg.DashboardAgg.Retention.DailyDays) + } + if cfg.DashboardAgg.RecomputeDays != 2 { + t.Fatalf("DashboardAgg.RecomputeDays = %d, want 2", cfg.DashboardAgg.RecomputeDays) + } +} + +func TestValidateDashboardAggregationConfigDisabled(t 
*testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.DashboardAgg.Enabled = false + cfg.DashboardAgg.IntervalSeconds = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for negative dashboard_aggregation.interval_seconds, got nil") + } + if !strings.Contains(err.Error(), "dashboard_aggregation.interval_seconds") { + t.Fatalf("Validate() expected interval_seconds error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 30cdd914..560d1075 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "errors" "strconv" "time" @@ -13,15 +14,17 @@ import ( // DashboardHandler handles admin dashboard statistics type DashboardHandler struct { - dashboardService *service.DashboardService - startTime time.Time // Server start time for uptime calculation + dashboardService *service.DashboardService + aggregationService *service.DashboardAggregationService + startTime time.Time // Server start time for uptime calculation } // NewDashboardHandler creates a new admin dashboard handler -func NewDashboardHandler(dashboardService *service.DashboardService) *DashboardHandler { +func NewDashboardHandler(dashboardService *service.DashboardService, aggregationService *service.DashboardAggregationService) *DashboardHandler { return &DashboardHandler{ - dashboardService: dashboardService, - startTime: time.Now(), + dashboardService: dashboardService, + aggregationService: aggregationService, + startTime: time.Now(), } } @@ -114,6 +117,54 @@ func (h *DashboardHandler) GetStats(c *gin.Context) { // 性能指标 "rpm": stats.Rpm, "tpm": stats.Tpm, + + // 预聚合新鲜度 + "hourly_active_users": stats.HourlyActiveUsers, + "stats_updated_at": stats.StatsUpdatedAt, + "stats_stale": stats.StatsStale, + }) +} + +type 
DashboardAggregationBackfillRequest struct { + Start string `json:"start"` + End string `json:"end"` +} + +// BackfillAggregation handles triggering aggregation backfill +// POST /api/v1/admin/dashboard/aggregation/backfill +func (h *DashboardHandler) BackfillAggregation(c *gin.Context) { + if h.aggregationService == nil { + response.InternalError(c, "Aggregation service not available") + return + } + + var req DashboardAggregationBackfillRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + start, err := time.Parse(time.RFC3339, req.Start) + if err != nil { + response.BadRequest(c, "Invalid start time") + return + } + end, err := time.Parse(time.RFC3339, req.End) + if err != nil { + response.BadRequest(c, "Invalid end time") + return + } + + if err := h.aggregationService.TriggerBackfill(start, end); err != nil { + if errors.Is(err, service.ErrDashboardBackfillDisabled) { + response.Forbidden(c, "Backfill is disabled") + return + } + response.InternalError(c, "Failed to trigger backfill") + return + } + + response.Success(c, gin.H{ + "status": "accepted", }) } diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 39314602..3952785b 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -9,6 +9,12 @@ type DashboardStats struct { TotalUsers int64 `json:"total_users"` TodayNewUsers int64 `json:"today_new_users"` // 今日新增用户数 ActiveUsers int64 `json:"active_users"` // 今日有请求的用户数 + // 小时活跃用户数(UTC 当前小时) + HourlyActiveUsers int64 `json:"hourly_active_users"` + + // 预聚合新鲜度 + StatsUpdatedAt string `json:"stats_updated_at"` + StatsStale bool `json:"stats_stale"` // API Key 统计 TotalAPIKeys int64 `json:"total_api_keys"` diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go new file mode 100644 index 
00000000..dbba5cdb --- /dev/null +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -0,0 +1,360 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +type dashboardAggregationRepository struct { + sql sqlExecutor +} + +// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。 +func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository { + return newDashboardAggregationRepositoryWithSQL(sqlDB) +} + +func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository { + return &dashboardAggregationRepository{sql: sqlq} +} + +func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error { + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return nil + } + + hourStart := startUTC.Truncate(time.Hour) + hourEnd := endUTC.Truncate(time.Hour) + if endUTC.After(hourEnd) { + hourEnd = hourEnd.Add(time.Hour) + } + + dayStart := truncateToDayUTC(startUTC) + dayEnd := truncateToDayUTC(endUTC) + if endUTC.After(dayEnd) { + dayEnd = dayEnd.Add(24 * time.Hour) + } + + if err := r.insertHourlyActiveUsers(ctx, startUTC, endUTC); err != nil { + return err + } + if err := r.insertDailyActiveUsers(ctx, startUTC, endUTC); err != nil { + return err + } + if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil { + return err + } + return nil +} + +func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + var ts time.Time + query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1" + if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil { + if err == sql.ErrNoRows { + return time.Unix(0, 0).UTC(), nil + } + return time.Time{}, err + } 
+ return ts.UTC(), nil +} + +func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + query := ` + INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at) + VALUES (1, $1, NOW()) + ON CONFLICT (id) + DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at + ` + _, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC()) + return err +} + +func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + _, err := r.sql.ExecContext(ctx, ` + WITH hourly_deleted AS (DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1), + hourly_users_deleted AS (DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1), + daily_deleted AS (DELETE FROM usage_dashboard_daily WHERE bucket_date < $2::date) + DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $2::date + `, hourlyCutoff.UTC(), dailyCutoff.UTC()) + return err +} + +func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + isPartitioned, err := r.isUsageLogsPartitioned(ctx) + if err != nil { + return err + } + if isPartitioned { + return r.dropUsageLogsPartitions(ctx, cutoff) + } + _, err = r.sql.ExecContext(ctx, "DELETE FROM usage_logs WHERE created_at < $1", cutoff.UTC()) + return err +} + +func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + isPartitioned, err := r.isUsageLogsPartitioned(ctx) + if err != nil || !isPartitioned { + return err + } + monthStart := truncateToMonthUTC(now) + prevMonth := monthStart.AddDate(0, -1, 0) + nextMonth := monthStart.AddDate(0, 1, 0) + + for _, m := range []time.Time{prevMonth, monthStart, nextMonth} { + if err := r.createUsageLogsPartition(ctx, m); err != nil { + return err + } + } + return nil +} + +func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error { + query := `
+ INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id) + SELECT DISTINCT + date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, + user_id + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ON CONFLICT DO NOTHING + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error { + query := ` + INSERT INTO usage_dashboard_daily_users (bucket_date, user_id) + SELECT DISTINCT + (created_at AT TIME ZONE 'UTC')::date AS bucket_date, + user_id + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ON CONFLICT DO NOTHING + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error { + query := ` + WITH hourly AS ( + SELECT + date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start, + COUNT(*) AS total_requests, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, + COALESCE(SUM(total_cost), 0) AS total_cost, + COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + GROUP BY 1 + ), + user_counts AS ( + SELECT bucket_start, COUNT(*) AS active_users + FROM usage_dashboard_hourly_users + WHERE bucket_start >= $1 AND bucket_start < $2 + GROUP BY bucket_start + ) + INSERT INTO usage_dashboard_hourly ( + bucket_start, + total_requests, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + total_duration_ms, + active_users, + computed_at + ) + SELECT + 
hourly.bucket_start, + hourly.total_requests, + hourly.input_tokens, + hourly.output_tokens, + hourly.cache_creation_tokens, + hourly.cache_read_tokens, + hourly.total_cost, + hourly.actual_cost, + hourly.total_duration_ms, + COALESCE(user_counts.active_users, 0) AS active_users, + NOW() + FROM hourly + LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start + ON CONFLICT (bucket_start) + DO UPDATE SET + total_requests = EXCLUDED.total_requests, + input_tokens = EXCLUDED.input_tokens, + output_tokens = EXCLUDED.output_tokens, + cache_creation_tokens = EXCLUDED.cache_creation_tokens, + cache_read_tokens = EXCLUDED.cache_read_tokens, + total_cost = EXCLUDED.total_cost, + actual_cost = EXCLUDED.actual_cost, + total_duration_ms = EXCLUDED.total_duration_ms, + active_users = EXCLUDED.active_users, + computed_at = EXCLUDED.computed_at + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error { + query := ` + WITH daily AS ( + SELECT + (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date, + COALESCE(SUM(total_requests), 0) AS total_requests, + COALESCE(SUM(input_tokens), 0) AS input_tokens, + COALESCE(SUM(output_tokens), 0) AS output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens, + COALESCE(SUM(total_cost), 0) AS total_cost, + COALESCE(SUM(actual_cost), 0) AS actual_cost, + COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms + FROM usage_dashboard_hourly + WHERE bucket_start >= $1 AND bucket_start < $2 + GROUP BY (bucket_start AT TIME ZONE 'UTC')::date + ), + user_counts AS ( + SELECT bucket_date, COUNT(*) AS active_users + FROM usage_dashboard_daily_users + WHERE bucket_date >= $3::date AND bucket_date < $4::date + GROUP BY bucket_date + ) + INSERT INTO usage_dashboard_daily ( + bucket_date, + total_requests, + 
input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + total_duration_ms, + active_users, + computed_at + ) + SELECT + daily.bucket_date, + daily.total_requests, + daily.input_tokens, + daily.output_tokens, + daily.cache_creation_tokens, + daily.cache_read_tokens, + daily.total_cost, + daily.actual_cost, + daily.total_duration_ms, + COALESCE(user_counts.active_users, 0) AS active_users, + NOW() + FROM daily + LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date + ON CONFLICT (bucket_date) + DO UPDATE SET + total_requests = EXCLUDED.total_requests, + input_tokens = EXCLUDED.input_tokens, + output_tokens = EXCLUDED.output_tokens, + cache_creation_tokens = EXCLUDED.cache_creation_tokens, + cache_read_tokens = EXCLUDED.cache_read_tokens, + total_cost = EXCLUDED.total_cost, + actual_cost = EXCLUDED.actual_cost, + total_duration_ms = EXCLUDED.total_duration_ms, + active_users = EXCLUDED.active_users, + computed_at = EXCLUDED.computed_at + ` + _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC()) + return err +} + +func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 + FROM pg_partitioned_table pt + JOIN pg_class c ON c.oid = pt.partrelid + WHERE c.relname = 'usage_logs' + ) + ` + var partitioned bool + if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil { + return false, err + } + return partitioned, nil +} + +func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error { + rows, err := r.sql.QueryContext(ctx, ` + SELECT c.relname + FROM pg_inherits + JOIN pg_class c ON c.oid = pg_inherits.inhrelid + JOIN pg_class p ON p.oid = pg_inherits.inhparent + WHERE p.relname = 'usage_logs' + `) + if err != nil { + return err + } + defer rows.Close() + + cutoffMonth := truncateToMonthUTC(cutoff) + for rows.Next() { + var 
name string + if err := rows.Scan(&name); err != nil { + return err + } + if !strings.HasPrefix(name, "usage_logs_") { + continue + } + suffix := strings.TrimPrefix(name, "usage_logs_") + month, err := time.Parse("200601", suffix) + if err != nil { + continue + } + month = month.UTC() + if month.Before(cutoffMonth) { + if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil { + return err + } + } + } + return rows.Err() +} + +func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error { + monthStart := truncateToMonthUTC(month) + nextMonth := monthStart.AddDate(0, 1, 0) + name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601")) + query := fmt.Sprintf( + "CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)", + pq.QuoteIdentifier(name), + pq.QuoteLiteral(monthStart.Format("2006-01-02")), + pq.QuoteLiteral(nextMonth.Format("2006-01-02")), + ) + _, err := r.sql.ExecContext(ctx, query) + return err +} + +func truncateToDayUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} + +func truncateToMonthUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 6ed8910e..be2a6d18 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -270,15 +270,14 @@ type DashboardStats = usagestats.DashboardStats func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { var stats DashboardStats - today := timezone.Today() now := time.Now() + todayUTC := truncateToDayUTC(now) // 合并用户统计查询 userStatsQuery := ` SELECT COUNT(*) as total_users, - COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users, - (SELECT 
COUNT(DISTINCT user_id) FROM usage_logs WHERE created_at >= $2) as active_users + COUNT(CASE WHEN created_at >= $1 THEN 1 END) as today_new_users FROM users WHERE deleted_at IS NULL ` @@ -286,10 +285,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS ctx, r.sql, userStatsQuery, - []any{today, today}, + []any{todayUTC}, &stats.TotalUsers, &stats.TodayNewUsers, - &stats.ActiveUsers, ); err != nil { return nil, err } @@ -341,16 +339,17 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS // 累计 Token 统计 totalStatsQuery := ` SELECT - COUNT(*) as total_requests, + COALESCE(SUM(total_requests), 0) as total_requests, COALESCE(SUM(input_tokens), 0) as total_input_tokens, COALESCE(SUM(output_tokens), 0) as total_output_tokens, COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(AVG(duration_ms), 0) as avg_duration_ms - FROM usage_logs + COALESCE(SUM(total_duration_ms), 0) as total_duration_ms + FROM usage_dashboard_daily ` + var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, @@ -363,30 +362,34 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalCacheReadTokens, &stats.TotalCost, &stats.TotalActualCost, - &stats.AverageDurationMs, + &totalDurationMs, ); err != nil { return nil, err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) + } // 今日 Token 统计 todayStatsQuery := ` SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as 
today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE created_at >= $1 + total_requests as today_requests, + input_tokens as today_input_tokens, + output_tokens as today_output_tokens, + cache_creation_tokens as today_cache_creation_tokens, + cache_read_tokens as today_cache_read_tokens, + total_cost as today_cost, + actual_cost as today_actual_cost, + active_users as active_users + FROM usage_dashboard_daily + WHERE bucket_date = $1::date ` if err := scanSingleRow( ctx, r.sql, todayStatsQuery, - []any{today}, + []any{todayUTC}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -394,11 +397,27 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TodayCacheReadTokens, &stats.TodayCost, &stats.TodayActualCost, + &stats.ActiveUsers, ); err != nil { - return nil, err + if err != sql.ErrNoRows { + return nil, err + } } stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + // 当前小时活跃用户 + hourlyActiveQuery := ` + SELECT active_users + FROM usage_dashboard_hourly + WHERE bucket_start = $1 + ` + hourStart := now.UTC().Truncate(time.Hour) + if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil { + if err != sql.ErrNoRows { + return nil, err + } + } + // 性能指标:RPM 和 TPM(最近1分钟,全局) rpm, tpm, err := r.getPerformanceStats(ctx, 0) if err != nil { diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 7193718f..09341db3 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -198,8 +198,8 @@ func (s *UsageLogRepoSuite) TestListWithFilters() { 
// --- GetDashboardStats --- func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { - now := time.Now() - todayStart := timezone.Today() + now := time.Now().UTC() + todayStart := truncateToDayUTC(now) baseStats, err := s.repo.GetDashboardStats(s.ctx) s.Require().NoError(err, "GetDashboardStats base") @@ -273,6 +273,11 @@ func (s *UsageLogRepoSuite) TestDashboardStats_TodayTotalsAndPerformance() { _, err = s.repo.Create(s.ctx, logPerf) s.Require().NoError(err, "Create logPerf") + aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx) + aggStart := todayStart.Add(-2 * time.Hour) + aggEnd := now.Add(2 * time.Minute) + s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd), "AggregateRange") + stats, err := s.repo.GetDashboardStats(s.ctx) s.Require().NoError(err, "GetDashboardStats") @@ -333,6 +338,151 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() { s.Require().Equal(int64(30), stats.Tokens) } +func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() { + now := time.Now().UTC().Truncate(time.Second) + hour1 := now.Add(-90 * time.Minute).Truncate(time.Hour) + hour2 := now.Add(-30 * time.Minute).Truncate(time.Hour) + dayStart := truncateToDayUTC(now) + + user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u1@test.com"}) + user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "agg-u2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-agg-1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-agg-2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-agg"}) + + d1, d2, d3 := 100, 200, 150 + log1 := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 2, + CacheReadTokens: 1, + TotalCost: 1.0, + ActualCost: 0.9, + DurationMs: &d1, + 
CreatedAt: hour1.Add(5 * time.Minute), + } + _, err := s.repo.Create(s.ctx, log1) + s.Require().NoError(err) + + log2 := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 5, + TotalCost: 0.5, + ActualCost: 0.5, + DurationMs: &d2, + CreatedAt: hour1.Add(20 * time.Minute), + } + _, err = s.repo.Create(s.ctx, log2) + s.Require().NoError(err) + + log3 := &service.UsageLog{ + UserID: user2.ID, + APIKeyID: apiKey2.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 7, + OutputTokens: 8, + TotalCost: 0.7, + ActualCost: 0.7, + DurationMs: &d3, + CreatedAt: hour2.Add(10 * time.Minute), + } + _, err = s.repo.Create(s.ctx, log3) + s.Require().NoError(err) + + aggRepo := newDashboardAggregationRepositoryWithSQL(s.tx) + aggStart := hour1.Add(-5 * time.Minute) + aggEnd := now.Add(5 * time.Minute) + s.Require().NoError(aggRepo.AggregateRange(s.ctx, aggStart, aggEnd)) + + type hourlyRow struct { + totalRequests int64 + inputTokens int64 + outputTokens int64 + cacheCreationTokens int64 + cacheReadTokens int64 + totalCost float64 + actualCost float64 + totalDurationMs int64 + activeUsers int64 + } + fetchHourly := func(bucketStart time.Time) hourlyRow { + var row hourlyRow + err := scanSingleRow(s.ctx, s.tx, ` + SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, + total_cost, actual_cost, total_duration_ms, active_users + FROM usage_dashboard_hourly + WHERE bucket_start = $1 + `, []any{bucketStart}, &row.totalRequests, &row.inputTokens, &row.outputTokens, + &row.cacheCreationTokens, &row.cacheReadTokens, &row.totalCost, &row.actualCost, + &row.totalDurationMs, &row.activeUsers, + ) + s.Require().NoError(err) + return row + } + + hour1Row := fetchHourly(hour1) + s.Require().Equal(int64(2), hour1Row.totalRequests) + s.Require().Equal(int64(15), hour1Row.inputTokens) + s.Require().Equal(int64(25), hour1Row.outputTokens) + 
s.Require().Equal(int64(2), hour1Row.cacheCreationTokens) + s.Require().Equal(int64(1), hour1Row.cacheReadTokens) + s.Require().Equal(1.5, hour1Row.totalCost) + s.Require().Equal(1.4, hour1Row.actualCost) + s.Require().Equal(int64(300), hour1Row.totalDurationMs) + s.Require().Equal(int64(1), hour1Row.activeUsers) + + hour2Row := fetchHourly(hour2) + s.Require().Equal(int64(1), hour2Row.totalRequests) + s.Require().Equal(int64(7), hour2Row.inputTokens) + s.Require().Equal(int64(8), hour2Row.outputTokens) + s.Require().Equal(int64(0), hour2Row.cacheCreationTokens) + s.Require().Equal(int64(0), hour2Row.cacheReadTokens) + s.Require().Equal(0.7, hour2Row.totalCost) + s.Require().Equal(0.7, hour2Row.actualCost) + s.Require().Equal(int64(150), hour2Row.totalDurationMs) + s.Require().Equal(int64(1), hour2Row.activeUsers) + + var daily struct { + totalRequests int64 + inputTokens int64 + outputTokens int64 + cacheCreationTokens int64 + cacheReadTokens int64 + totalCost float64 + actualCost float64 + totalDurationMs int64 + activeUsers int64 + } + err = scanSingleRow(s.ctx, s.tx, ` + SELECT total_requests, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, + total_cost, actual_cost, total_duration_ms, active_users + FROM usage_dashboard_daily + WHERE bucket_date = $1::date + `, []any{dayStart}, &daily.totalRequests, &daily.inputTokens, &daily.outputTokens, + &daily.cacheCreationTokens, &daily.cacheReadTokens, &daily.totalCost, &daily.actualCost, + &daily.totalDurationMs, &daily.activeUsers, + ) + s.Require().NoError(err) + s.Require().Equal(int64(3), daily.totalRequests) + s.Require().Equal(int64(22), daily.inputTokens) + s.Require().Equal(int64(33), daily.outputTokens) + s.Require().Equal(int64(2), daily.cacheCreationTokens) + s.Require().Equal(int64(1), daily.cacheReadTokens) + s.Require().Equal(2.2, daily.totalCost) + s.Require().Equal(2.1, daily.actualCost) + s.Require().Equal(int64(450), daily.totalDurationMs) + s.Require().Equal(int64(2), 
daily.activeUsers) +} + // --- GetBatchUserUsageStats --- func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 1b6a7b91..8cc937bb 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -47,6 +47,7 @@ var ProviderSet = wire.NewSet( NewRedeemCodeRepository, NewPromoCodeRepository, NewUsageLogRepository, + NewDashboardAggregationRepository, NewSettingRepository, NewUserSubscriptionRepository, NewUserAttributeDefinitionRepository, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 6f40c491..c9c5352c 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -75,6 +75,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) { dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend) dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage) dashboard.POST("/api-keys-usage", h.Admin.Dashboard.GetBatchAPIKeysUsage) + dashboard.POST("/aggregation/backfill", h.Admin.Dashboard.BackfillAggregation) } } diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go new file mode 100644 index 00000000..343c3240 --- /dev/null +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -0,0 +1,224 @@ +package service + +import ( + "context" + "errors" + "log" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + defaultDashboardAggregationTimeout = 2 * time.Minute + defaultDashboardAggregationBackfillTimeout = 30 * time.Minute + dashboardAggregationRetentionInterval = 6 * time.Hour +) + +var ( + // ErrDashboardBackfillDisabled 当配置禁用回填时返回。 + ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") +) + +// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 +type DashboardAggregationRepository interface { + 
AggregateRange(ctx context.Context, start, end time.Time) error + GetAggregationWatermark(ctx context.Context) (time.Time, error) + UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error + CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error + CleanupUsageLogs(ctx context.Context, cutoff time.Time) error + EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error +} + +// DashboardAggregationService 负责定时聚合与回填。 +type DashboardAggregationService struct { + repo DashboardAggregationRepository + timingWheel *TimingWheelService + cfg config.DashboardAggregationConfig + running int32 + lastRetentionCleanup atomic.Value // time.Time +} + +// NewDashboardAggregationService 创建聚合服务。 +func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { + var aggCfg config.DashboardAggregationConfig + if cfg != nil { + aggCfg = cfg.DashboardAgg + } + return &DashboardAggregationService{ + repo: repo, + timingWheel: timingWheel, + cfg: aggCfg, + } +} + +// Start 启动定时聚合作业(重启生效配置)。 +func (s *DashboardAggregationService) Start() { + if s == nil || s.repo == nil || s.timingWheel == nil { + return + } + if !s.cfg.Enabled { + log.Printf("[DashboardAggregation] 聚合作业已禁用") + return + } + + interval := time.Duration(s.cfg.IntervalSeconds) * time.Second + if interval <= 0 { + interval = time.Minute + } + + if s.cfg.RecomputeDays > 0 { + go s.recomputeRecentDays() + } + + s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() { + s.runScheduledAggregation() + }) + log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) +} + +// TriggerBackfill 触发回填(异步)。 +func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error { + if s == nil || s.repo == nil { + return errors.New("聚合服务未初始化") + } + if !s.cfg.BackfillEnabled { + log.Printf("[DashboardAggregation] 
回填被拒绝: backfill_enabled=false") + return ErrDashboardBackfillDisabled + } + if !end.After(start) { + return errors.New("回填时间范围无效") + } + + go func() { + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + defer cancel() + if err := s.backfillRange(ctx, start, end); err != nil { + log.Printf("[DashboardAggregation] 回填失败: %v", err) + } + }() + return nil +} + +func (s *DashboardAggregationService) recomputeRecentDays() { + days := s.cfg.RecomputeDays + if days <= 0 { + return + } + now := time.Now().UTC() + start := now.AddDate(0, 0, -days) + + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + defer cancel() + if err := s.backfillRange(ctx, start, now); err != nil { + log.Printf("[DashboardAggregation] 启动重算失败: %v", err) + return + } +} + +func (s *DashboardAggregationService) runScheduledAggregation() { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return + } + defer atomic.StoreInt32(&s.running, 0) + + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout) + defer cancel() + + now := time.Now().UTC() + last, err := s.repo.GetAggregationWatermark(ctx) + if err != nil { + log.Printf("[DashboardAggregation] 读取水位失败: %v", err) + last = time.Unix(0, 0).UTC() + } + + lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second + start := last.Add(-lookback) + epoch := time.Unix(0, 0).UTC() + if !last.After(epoch) { + start = now.Add(-lookback) + } + if start.After(now) { + start = now.Add(-lookback) + } + + if err := s.aggregateRange(ctx, start, now); err != nil { + log.Printf("[DashboardAggregation] 聚合失败: %v", err) + return + } + + if err := s.repo.UpdateAggregationWatermark(ctx, now); err != nil { + log.Printf("[DashboardAggregation] 更新水位失败: %v", err) + } + + s.maybeCleanupRetention(ctx, now) +} + +func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { + if 
!atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return errors.New("聚合作业正在运行") + } + defer atomic.StoreInt32(&s.running, 0) + + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return errors.New("回填时间范围无效") + } + + cursor := truncateToDayUTC(startUTC) + for cursor.Before(endUTC) { + windowEnd := cursor.Add(24 * time.Hour) + if windowEnd.After(endUTC) { + windowEnd = endUTC + } + if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil { + return err + } + cursor = windowEnd + } + + if err := s.repo.UpdateAggregationWatermark(ctx, endUTC); err != nil { + log.Printf("[DashboardAggregation] 更新水位失败: %v", err) + } + + s.maybeCleanupRetention(ctx, endUTC) + return nil +} + +func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error { + if !end.After(start) { + return nil + } + if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil { + log.Printf("[DashboardAggregation] 分区检查失败: %v", err) + } + return s.repo.AggregateRange(ctx, start, end) +} + +func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) { + lastAny := s.lastRetentionCleanup.Load() + if lastAny != nil { + if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval { + return + } + } + s.lastRetentionCleanup.Store(now) + + hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) + dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) + usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) + + if err := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff); err != nil { + log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", err) + } + if err := s.repo.CleanupUsageLogs(ctx, usageCutoff); err != nil { + log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", err) + } +} + +func truncateToDayUTC(t time.Time) time.Time { + t = t.UTC() + return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC) +} diff --git 
a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 468135e3..40ab877d 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -37,17 +37,24 @@ type dashboardStatsCacheEntry struct { // DashboardService provides aggregated statistics for admin dashboard. type DashboardService struct { usageRepo UsageLogRepository + aggRepo DashboardAggregationRepository cache DashboardStatsCache cacheFreshTTL time.Duration cacheTTL time.Duration refreshTimeout time.Duration refreshing int32 + aggEnabled bool + aggInterval time.Duration + aggLookback time.Duration } -func NewDashboardService(usageRepo UsageLogRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService { +func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService { freshTTL := defaultDashboardStatsFreshTTL cacheTTL := defaultDashboardStatsCacheTTL refreshTimeout := defaultDashboardStatsRefreshTimeout + aggEnabled := true + aggInterval := time.Minute + aggLookback := 2 * time.Minute if cfg != nil { if !cfg.Dashboard.Enabled { cache = nil @@ -61,13 +68,24 @@ func NewDashboardService(usageRepo UsageLogRepository, cache DashboardStatsCache if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 { refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second } + aggEnabled = cfg.DashboardAgg.Enabled + if cfg.DashboardAgg.IntervalSeconds > 0 { + aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second + } + if cfg.DashboardAgg.LookbackSeconds > 0 { + aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second + } } return &DashboardService{ usageRepo: usageRepo, + aggRepo: aggRepo, cache: cache, cacheFreshTTL: freshTTL, cacheTTL: cacheTTL, refreshTimeout: refreshTimeout, + aggEnabled: aggEnabled, + aggInterval: aggInterval, + aggLookback: 
aggLookback, } } @@ -75,6 +93,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D if s.cache != nil { cached, fresh, err := s.getCachedDashboardStats(ctx) if err == nil && cached != nil { + s.refreshAggregationStaleness(cached) if !fresh { s.refreshDashboardStatsAsync() } @@ -133,6 +152,7 @@ func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagesta if err != nil { return nil, err } + s.applyAggregationStatus(ctx, stats) cacheCtx, cancel := s.cacheOperationContext() defer cancel() s.saveDashboardStatsCache(cacheCtx, stats) @@ -158,6 +178,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() { log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return } + s.applyAggregationStatus(ctx, stats) cacheCtx, cancel := s.cacheOperationContext() defer cancel() s.saveDashboardStatsCache(cacheCtx, stats) @@ -203,6 +224,61 @@ func (s *DashboardService) cacheOperationContext() (context.Context, context.Can return context.WithTimeout(context.Background(), s.refreshTimeout) } +func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) { + if stats == nil { + return + } + updatedAt := s.fetchAggregationUpdatedAt(ctx) + stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339) + stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC()) +} + +func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) { + if stats == nil { + return + } + updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt) + stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC()) +} + +func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time { + if s.aggRepo == nil { + return time.Unix(0, 0).UTC() + } + updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx) + if err != nil { + log.Printf("[Dashboard] 读取聚合水位失败: %v", err) + return time.Unix(0, 0).UTC() + } + if updatedAt.IsZero() { + return time.Unix(0, 0).UTC() + } + return 
updatedAt.UTC() +} + +func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool { + if !s.aggEnabled { + return true + } + epoch := time.Unix(0, 0).UTC() + if !updatedAt.After(epoch) { + return true + } + threshold := s.aggInterval + s.aggLookback + return now.Sub(updatedAt) > threshold +} + +func parseStatsUpdatedAt(raw string) time.Time { + if raw == "" { + return time.Unix(0, 0).UTC() + } + parsed, err := time.Parse(time.RFC3339, raw) + if err != nil { + return time.Unix(0, 0).UTC() + } + return parsed.UTC() +} + func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit) if err != nil { diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index 17f46ead..c7b9c6af 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -74,6 +74,38 @@ func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error { return nil } +type dashboardAggregationRepoStub struct { + watermark time.Time + err error +} + +func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + if s.err != nil { + return time.Time{}, s.err + } + return s.watermark, nil +} + +func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + return nil +} + 
+func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + return nil +} + func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry { t.Helper() c.lastSetMu.Lock() @@ -88,7 +120,9 @@ func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntr func TestDashboardService_CacheHitFresh(t *testing.T) { stats := &usagestats.DashboardStats{ - TotalUsers: 10, + TotalUsers: 10, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, } entry := dashboardStatsCacheEntry{ Stats: stats, @@ -105,8 +139,9 @@ func TestDashboardService_CacheHitFresh(t *testing.T) { repo := &usageRepoStub{ stats: &usagestats.DashboardStats{TotalUsers: 99}, } + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) require.NoError(t, err) @@ -118,7 +153,9 @@ func TestDashboardService_CacheHitFresh(t *testing.T) { func TestDashboardService_CacheMiss_StoresCache(t *testing.T) { stats := &usagestats.DashboardStats{ - TotalUsers: 7, + TotalUsers: 7, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, } cache := &dashboardCacheStub{ get: func(ctx context.Context) (string, error) { @@ -126,8 +163,9 @@ func TestDashboardService_CacheMiss_StoresCache(t *testing.T) { }, } repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) require.NoError(t, err) @@ -142,7 +180,9 @@ func 
TestDashboardService_CacheMiss_StoresCache(t *testing.T) { func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { stats := &usagestats.DashboardStats{ - TotalUsers: 3, + TotalUsers: 3, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, } cache := &dashboardCacheStub{ get: func(ctx context.Context) (string, error) { @@ -150,8 +190,9 @@ func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { }, } repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) require.NoError(t, err) @@ -163,7 +204,9 @@ func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { staleStats := &usagestats.DashboardStats{ - TotalUsers: 11, + TotalUsers: 11, + StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339), + StatsStale: true, } entry := dashboardStatsCacheEntry{ Stats: staleStats, @@ -182,8 +225,9 @@ func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { stats: &usagestats.DashboardStats{TotalUsers: 22}, onCall: refreshCh, } + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) require.NoError(t, err) @@ -207,8 +251,9 @@ func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) { } stats := &usagestats.DashboardStats{TotalUsers: 9} repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := 
&config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) require.NoError(t, err) @@ -224,10 +269,45 @@ func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) { }, } repo := &usageRepoStub{err: errors.New("db down")} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} - svc := NewDashboardService(repo, cache, cfg) + svc := NewDashboardService(repo, aggRepo, cache, cfg) _, err := svc.GetDashboardStats(context.Background()) require.Error(t, err) require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls)) } + +func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) { + stats := &usagestats.DashboardStats{} + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} + cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}} + svc := NewDashboardService(repo, aggRepo, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt) + require.True(t, got.StatsStale) +} + +func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) { + aggNow := time.Now().UTC().Truncate(time.Second) + stats := &usagestats.DashboardStats{} + repo := &usageRepoStub{stats: stats} + aggRepo := &dashboardAggregationRepoStub{watermark: aggNow} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + }, + } + svc := NewDashboardService(repo, aggRepo, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, 
aggNow.Format(time.RFC3339), got.StatsUpdatedAt) + require.False(t, got.StatsStale) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 54c37b54..f1074e9d 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -47,6 +47,13 @@ func ProvideTokenRefreshService( return svc } +// ProvideDashboardAggregationService 创建并启动仪表盘聚合服务 +func ProvideDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService { + svc := NewDashboardAggregationService(repo, timingWheel, cfg) + svc.Start() + return svc +} + // ProvideAccountExpiryService creates and starts AccountExpiryService. func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { svc := NewAccountExpiryService(accountRepo, time.Minute) @@ -126,6 +133,7 @@ var ProviderSet = wire.NewSet( ProvideTokenRefreshService, ProvideAccountExpiryService, ProvideTimingWheelService, + ProvideDashboardAggregationService, ProvideDeferredService, NewAntigravityQuotaFetcher, NewUserAttributeService, diff --git a/backend/migrations/034_usage_dashboard_aggregation_tables.sql b/backend/migrations/034_usage_dashboard_aggregation_tables.sql new file mode 100644 index 00000000..64b383d4 --- /dev/null +++ b/backend/migrations/034_usage_dashboard_aggregation_tables.sql @@ -0,0 +1,77 @@ +-- Usage dashboard aggregation tables (hourly/daily) + active-user dedup + watermark. +-- These tables support Admin Dashboard statistics without full-table scans on usage_logs. + +-- Hourly aggregates (UTC buckets). 
+CREATE TABLE IF NOT EXISTS usage_dashboard_hourly ( + bucket_start TIMESTAMPTZ PRIMARY KEY, + total_requests BIGINT NOT NULL DEFAULT 0, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + cache_creation_tokens BIGINT NOT NULL DEFAULT 0, + cache_read_tokens BIGINT NOT NULL DEFAULT 0, + total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + total_duration_ms BIGINT NOT NULL DEFAULT 0, + active_users BIGINT NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_bucket_start + ON usage_dashboard_hourly (bucket_start DESC); + +COMMENT ON TABLE usage_dashboard_hourly IS 'Pre-aggregated hourly usage metrics for admin dashboard (UTC buckets).'; +COMMENT ON COLUMN usage_dashboard_hourly.bucket_start IS 'UTC start timestamp of the hour bucket.'; +COMMENT ON COLUMN usage_dashboard_hourly.computed_at IS 'When the hourly row was last computed/refreshed.'; + +-- Daily aggregates (UTC dates). 
+CREATE TABLE IF NOT EXISTS usage_dashboard_daily ( + bucket_date DATE PRIMARY KEY, + total_requests BIGINT NOT NULL DEFAULT 0, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + cache_creation_tokens BIGINT NOT NULL DEFAULT 0, + cache_read_tokens BIGINT NOT NULL DEFAULT 0, + total_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + actual_cost DECIMAL(20, 10) NOT NULL DEFAULT 0, + total_duration_ms BIGINT NOT NULL DEFAULT 0, + active_users BIGINT NOT NULL DEFAULT 0, + computed_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_bucket_date + ON usage_dashboard_daily (bucket_date DESC); + +COMMENT ON TABLE usage_dashboard_daily IS 'Pre-aggregated daily usage metrics for admin dashboard (UTC dates).'; +COMMENT ON COLUMN usage_dashboard_daily.bucket_date IS 'UTC date of the day bucket.'; +COMMENT ON COLUMN usage_dashboard_daily.computed_at IS 'When the daily row was last computed/refreshed.'; + +-- Hourly active user dedup table. +CREATE TABLE IF NOT EXISTS usage_dashboard_hourly_users ( + bucket_start TIMESTAMPTZ NOT NULL, + user_id BIGINT NOT NULL, + PRIMARY KEY (bucket_start, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_hourly_users_bucket_start + ON usage_dashboard_hourly_users (bucket_start); + +-- Daily active user dedup table. +CREATE TABLE IF NOT EXISTS usage_dashboard_daily_users ( + bucket_date DATE NOT NULL, + user_id BIGINT NOT NULL, + PRIMARY KEY (bucket_date, user_id) +); + +CREATE INDEX IF NOT EXISTS idx_usage_dashboard_daily_users_bucket_date + ON usage_dashboard_daily_users (bucket_date); + +-- Aggregation watermark table (single row). 
+CREATE TABLE IF NOT EXISTS usage_dashboard_aggregation_watermark ( + id INT PRIMARY KEY, + last_aggregated_at TIMESTAMPTZ NOT NULL DEFAULT TIMESTAMPTZ '1970-01-01 00:00:00+00', + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +INSERT INTO usage_dashboard_aggregation_watermark (id) +VALUES (1) +ON CONFLICT (id) DO NOTHING; diff --git a/backend/migrations/035_usage_logs_partitioning.sql b/backend/migrations/035_usage_logs_partitioning.sql new file mode 100644 index 00000000..5919b5c3 --- /dev/null +++ b/backend/migrations/035_usage_logs_partitioning.sql @@ -0,0 +1,54 @@ +-- usage_logs monthly partition bootstrap. +-- Only converts to partitioned table when usage_logs is empty. +-- Existing installations with data require a manual migration plan. + +DO $$ +DECLARE + is_partitioned BOOLEAN := FALSE; + has_data BOOLEAN := FALSE; + month_start DATE; + prev_month DATE; + next_month DATE; +BEGIN + SELECT EXISTS( + SELECT 1 + FROM pg_partitioned_table pt + JOIN pg_class c ON c.oid = pt.partrelid + WHERE c.relname = 'usage_logs' + ) INTO is_partitioned; + + IF NOT is_partitioned THEN + SELECT EXISTS(SELECT 1 FROM usage_logs LIMIT 1) INTO has_data; + IF NOT has_data THEN + EXECUTE 'ALTER TABLE usage_logs PARTITION BY RANGE (created_at)'; + is_partitioned := TRUE; + END IF; + END IF; + + IF is_partitioned THEN + month_start := date_trunc('month', now() AT TIME ZONE 'UTC')::date; + prev_month := (month_start - INTERVAL '1 month')::date; + next_month := (month_start + INTERVAL '1 month')::date; + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)', + to_char(prev_month, 'YYYYMM'), + prev_month, + month_start + ); + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO (%L)', + to_char(month_start, 'YYYYMM'), + month_start, + next_month + ); + + EXECUTE format( + 'CREATE TABLE IF NOT EXISTS usage_logs_%s PARTITION OF usage_logs FOR VALUES FROM (%L) TO 
(%L)', + to_char(next_month, 'YYYYMM'), + next_month, + (next_month + INTERVAL '1 month')::date + ); + END IF; +END $$; diff --git a/config.yaml b/config.yaml index ffc070a0..848421d6 100644 --- a/config.yaml +++ b/config.yaml @@ -215,6 +215,39 @@ dashboard_cache: # 异步刷新超时(秒) stats_refresh_timeout_seconds: 30 +# ============================================================================= +# Dashboard Aggregation Configuration +# 仪表盘预聚合配置(重启生效) +# ============================================================================= +dashboard_aggregation: + # Enable aggregation job + # 启用聚合作业 + enabled: true + # Refresh interval (seconds) + # 刷新间隔(秒) + interval_seconds: 60 + # Lookback window (seconds) for late-arriving data + # 回看窗口(秒),处理迟到数据 + lookback_seconds: 120 + # Allow manual backfill + # 允许手动回填 + backfill_enabled: false + # Recompute recent N days on startup + # 启动时重算最近 N 天 + recompute_days: 2 + # Retention windows (days) + # 保留窗口(天) + retention: + # Raw usage_logs retention + # 原始 usage_logs 保留天数 + usage_logs_days: 90 + # Hourly aggregation retention + # 小时聚合保留天数 + hourly_days: 180 + # Daily aggregation retention + # 日聚合保留天数 + daily_days: 730 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 7083f9e9..460606ab 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -215,6 +215,39 @@ dashboard_cache: # 异步刷新超时(秒) stats_refresh_timeout_seconds: 30 +# ============================================================================= +# Dashboard Aggregation Configuration +# 仪表盘预聚合配置(重启生效) +# ============================================================================= +dashboard_aggregation: + # Enable aggregation job + # 启用聚合作业 + enabled: true + # Refresh interval (seconds) + # 刷新间隔(秒) + interval_seconds: 60 + # Lookback window (seconds) for late-arriving data + # 回看窗口(秒),处理迟到数据 + lookback_seconds: 120 + 
# Allow manual backfill + # 允许手动回填 + backfill_enabled: false + # Recompute recent N days on startup + # 启动时重算最近 N 天 + recompute_days: 2 + # Retention windows (days) + # 保留窗口(天) + retention: + # Raw usage_logs retention + # 原始 usage_logs 保留天数 + usage_logs_days: 90 + # Hourly aggregation retention + # 小时聚合保留天数 + hourly_days: 180 + # Daily aggregation retention + # 日聚合保留天数 + daily_days: 730 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 3d1b17f6..98718b19 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -651,6 +651,9 @@ export interface DashboardStats { total_users: number today_new_users: number // 今日新增用户数 active_users: number // 今日有请求的用户数 + hourly_active_users: number // 当前小时活跃用户数(UTC) + stats_updated_at: string // 统计更新时间(UTC RFC3339) + stats_stale: boolean // 统计是否过期 // API Key 统计 total_api_keys: number From d78f42d2fda854e479e94e3637db2468a7d0fe5e Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 16:02:28 +0800 Subject: [PATCH 14/23] =?UTF-8?q?chore(=E6=B3=A8=E9=87=8A):=20=E8=B0=83?= =?UTF-8?q?=E6=95=B4=E4=BB=AA=E8=A1=A8=E7=9B=98=E6=B3=A8=E9=87=8A=E4=B8=BA?= =?UTF-8?q?=E4=B8=AD=E6=96=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/dashboard_service.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 40ab877d..d0e6e03c 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -34,7 +34,7 @@ type dashboardStatsCacheEntry struct { UpdatedAt int64 `json:"updated_at"` } -// DashboardService provides aggregated statistics for admin dashboard. 
+// DashboardService 提供管理员仪表盘统计服务。 type DashboardService struct { usageRepo UsageLogRepository aggRepo DashboardAggregationRepository From 5364011a5bacff1f1a00c98b2d8f473e33f5baca Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 17:21:17 +0800 Subject: [PATCH 15/23] =?UTF-8?q?fix(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E8=81=9A=E5=90=88=E6=97=B6=E9=97=B4=E6=A1=B6?= =?UTF-8?q?=E4=B8=8E=E6=B8=85=E7=90=86=E8=8A=82=E6=B5=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repository/dashboard_aggregation_repo.go | 11 +-- .../service/dashboard_aggregation_service.go | 22 +++-- .../dashboard_aggregation_service_test.go | 89 +++++++++++++++++++ 3 files changed, 108 insertions(+), 14 deletions(-) create mode 100644 backend/internal/service/dashboard_aggregation_service_test.go diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index dbba5cdb..b02cde0d 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -43,10 +43,11 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta dayEnd = dayEnd.Add(24 * time.Hour) } - if err := r.insertHourlyActiveUsers(ctx, startUTC, endUTC); err != nil { + // 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。 + if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err } - if err := r.insertDailyActiveUsers(ctx, startUTC, endUTC); err != nil { + if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil { return err } if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil { @@ -138,10 +139,10 @@ func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Cont query := ` INSERT INTO usage_dashboard_daily_users (bucket_date, user_id) SELECT DISTINCT - (created_at AT TIME ZONE 'UTC')::date AS bucket_date, + (bucket_start 
AT TIME ZONE 'UTC')::date AS bucket_date, user_id - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + FROM usage_dashboard_hourly_users + WHERE bucket_start >= $1 AND bucket_start < $2 ON CONFLICT DO NOTHING ` _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC()) diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 343c3240..133ab018 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -134,12 +134,12 @@ func (s *DashboardAggregationService) runScheduledAggregation() { } lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second - start := last.Add(-lookback) epoch := time.Unix(0, 0).UTC() + start := last.Add(-lookback) if !last.After(epoch) { - start = now.Add(-lookback) - } - if start.After(now) { + // 首次聚合覆盖当天,避免只统计最后一个窗口。 + start = truncateToDayUTC(now) + } else if start.After(now) { start = now.Add(-lookback) } @@ -204,17 +204,21 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, return } } - s.lastRetentionCleanup.Store(now) hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays) dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays) usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays) - if err := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff); err != nil { - log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", err) + aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) + if aggErr != nil { + log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr) } - if err := s.repo.CleanupUsageLogs(ctx, usageCutoff); err != nil { - log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", err) + usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff) + if usageErr != nil { + log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) + } + if aggErr == nil && usageErr == nil { + 
s.lastRetentionCleanup.Store(now) } } diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go new file mode 100644 index 00000000..501b11d4 --- /dev/null +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type dashboardAggregationRepoTestStub struct { + aggregateCalls int + lastStart time.Time + lastEnd time.Time + watermark time.Time + aggregateErr error + cleanupAggregatesErr error + cleanupUsageErr error +} + +func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error { + s.aggregateCalls++ + s.lastStart = start + s.lastEnd = end + return s.aggregateErr +} + +func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + return s.watermark, nil +} + +func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + return nil +} + +func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + return s.cleanupAggregatesErr +} + +func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + return s.cleanupUsageErr +} + +func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + return nil +} + +func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesDayStart(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Enabled: true, + IntervalSeconds: 60, + LookbackSeconds: 120, + Retention: 
config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.runScheduledAggregation() + + require.Equal(t, 1, repo.aggregateCalls) + require.False(t, repo.lastEnd.IsZero()) + require.Equal(t, truncateToDayUTC(repo.lastEnd), repo.lastStart) +} + +func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 1, + HourlyDays: 1, + DailyDays: 1, + }, + }, + } + + svc.maybeCleanupRetention(context.Background(), time.Now().UTC()) + + require.Nil(t, svc.lastRetentionCleanup.Load()) +} From 6271a33d0898f1e3b43bac0495ff81103f437171 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 18:20:15 +0800 Subject: [PATCH 16/23] =?UTF-8?q?fix(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E5=85=BC=E5=AE=B9=E7=A6=81=E7=94=A8=E8=81=9A=E5=90=88=E4=B8=8E?= =?UTF-8?q?=E5=9B=9E=E5=A1=AB=E9=99=90=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/config/config.go | 12 ++ backend/internal/config/config_test.go | 22 +++ .../handler/admin/dashboard_handler.go | 4 + backend/internal/repository/usage_log_repo.go | 169 +++++++++++++++--- .../usage_log_repo_integration_test.go | 74 ++++++++ .../service/dashboard_aggregation_service.go | 18 +- .../dashboard_aggregation_service_test.go | 21 ++- backend/internal/service/dashboard_service.go | 25 ++- .../service/dashboard_service_test.go | 94 ++++++++-- config.yaml | 3 + deploy/.env.example | 27 +++ deploy/config.example.yaml | 3 + 12 files changed, 434 insertions(+), 38 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index b91a07c1..a2fbbd1d 100644 --- 
a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -398,6 +398,8 @@ type DashboardAggregationConfig struct { LookbackSeconds int `mapstructure:"lookback_seconds"` // BackfillEnabled: 是否允许全量回填 BackfillEnabled bool `mapstructure:"backfill_enabled"` + // BackfillMaxDays: 回填最大跨度(天) + BackfillMaxDays int `mapstructure:"backfill_max_days"` // Retention: 各表保留窗口(天) Retention DashboardAggregationRetentionConfig `mapstructure:"retention"` // RecomputeDays: 启动时重算最近 N 天 @@ -726,6 +728,7 @@ func setDefaults() { viper.SetDefault("dashboard_aggregation.interval_seconds", 60) viper.SetDefault("dashboard_aggregation.lookback_seconds", 120) viper.SetDefault("dashboard_aggregation.backfill_enabled", false) + viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) @@ -920,6 +923,12 @@ func (c *Config) Validate() error { if c.DashboardAgg.LookbackSeconds < 0 { return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") } + if c.DashboardAgg.BackfillMaxDays < 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") + } + if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive") + } if c.DashboardAgg.Retention.UsageLogsDays <= 0 { return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") } @@ -939,6 +948,9 @@ func (c *Config) Validate() error { if c.DashboardAgg.LookbackSeconds < 0 { return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") } + if c.DashboardAgg.BackfillMaxDays < 0 { + return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") + } if c.DashboardAgg.Retention.UsageLogsDays < 0 { return 
fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 7fc34d64..1ba6d053 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -226,6 +226,9 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { if cfg.DashboardAgg.BackfillEnabled { t.Fatalf("DashboardAgg.BackfillEnabled = true, want false") } + if cfg.DashboardAgg.BackfillMaxDays != 31 { + t.Fatalf("DashboardAgg.BackfillMaxDays = %d, want 31", cfg.DashboardAgg.BackfillMaxDays) + } if cfg.DashboardAgg.Retention.UsageLogsDays != 90 { t.Fatalf("DashboardAgg.Retention.UsageLogsDays = %d, want 90", cfg.DashboardAgg.Retention.UsageLogsDays) } @@ -258,3 +261,22 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { t.Fatalf("Validate() expected interval_seconds error, got: %v", err) } } + +func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.DashboardAgg.BackfillEnabled = true + cfg.DashboardAgg.BackfillMaxDays = 0 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for dashboard_aggregation.backfill_max_days, got nil") + } + if !strings.Contains(err.Error(), "dashboard_aggregation.backfill_max_days") { + t.Fatalf("Validate() expected backfill_max_days error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 560d1075..9b675974 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -159,6 +159,10 @@ func (h *DashboardHandler) BackfillAggregation(c *gin.Context) { response.Forbidden(c, "Backfill is disabled") return } + if errors.Is(err, service.ErrDashboardBackfillTooLarge) { + response.BadRequest(c, "Backfill 
range too large") + return + } response.InternalError(c, "Failed to trigger backfill") return } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index be2a6d18..e483f89f 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -269,11 +269,56 @@ func (r *usageLogRepository) GetUserStats(ctx context.Context, userID int64, sta type DashboardStats = usagestats.DashboardStats func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) { - var stats DashboardStats - now := time.Now() + stats := &DashboardStats{} + now := time.Now().UTC() todayUTC := truncateToDayUTC(now) - // 合并用户统计查询 + if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + + rpm, tpm, err := r.getPerformanceStats(ctx, 0) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*DashboardStats, error) { + startUTC := start.UTC() + endUTC := end.UTC() + if !endUTC.After(startUTC) { + return nil, errors.New("统计时间范围无效") + } + + stats := &DashboardStats{} + now := time.Now().UTC() + todayUTC := truncateToDayUTC(now) + + if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil { + return nil, err + } + if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil { + return nil, err + } + + rpm, tpm, err := r.getPerformanceStats(ctx, 0) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + +func (r *usageLogRepository) fillDashboardEntityStats(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error { userStatsQuery := ` SELECT COUNT(*) 
as total_users, @@ -289,10 +334,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalUsers, &stats.TodayNewUsers, ); err != nil { - return nil, err + return err } - // 合并API Key统计查询 apiKeyStatsQuery := ` SELECT COUNT(*) as total_api_keys, @@ -308,10 +352,9 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalAPIKeys, &stats.ActiveAPIKeys, ); err != nil { - return nil, err + return err } - // 合并账户统计查询 accountStatsQuery := ` SELECT COUNT(*) as total_accounts, @@ -333,10 +376,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.RateLimitAccounts, &stats.OverloadAccounts, ); err != nil { - return nil, err + return err } - // 累计 Token 统计 + return nil +} + +func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Context, stats *DashboardStats, todayUTC, now time.Time) error { totalStatsQuery := ` SELECT COALESCE(SUM(total_requests), 0) as total_requests, @@ -364,14 +410,13 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.TotalActualCost, &totalDurationMs, ); err != nil { - return nil, err + return err } stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens if stats.TotalRequests > 0 { stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) } - // 今日 Token 统计 todayStatsQuery := ` SELECT total_requests as today_requests, @@ -400,12 +445,11 @@ func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardS &stats.ActiveUsers, ); err != nil { if err != sql.ErrNoRows { - return nil, err + return err } } stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - // 当前小时活跃用户 hourlyActiveQuery := ` SELECT active_users FROM usage_dashboard_hourly @@ -414,19 +458,100 @@ func (r *usageLogRepository) 
GetDashboardStats(ctx context.Context) (*DashboardS hourStart := now.UTC().Truncate(time.Hour) if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil { if err != sql.ErrNoRows { - return nil, err + return err } } - // 性能指标:RPM 和 TPM(最近1分钟,全局) - rpm, tpm, err := r.getPerformanceStats(ctx, 0) - if err != nil { - return nil, err - } - stats.Rpm = rpm - stats.Tpm = tpm + return nil +} - return &stats, nil +func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + var totalDurationMs int64 + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{startUTC, endUTC}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &totalDurationMs, + ); err != nil { + return err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) + } + + todayEnd := todayUTC.Add(24 * time.Hour) + todayStatsQuery := ` + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(input_tokens), 0) as today_input_tokens, + COALESCE(SUM(output_tokens), 0) as 
today_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as today_cost, + COALESCE(SUM(actual_cost), 0) as today_actual_cost + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{todayUTC, todayEnd}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { + return err + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + activeUsersQuery := ` + SELECT COUNT(DISTINCT user_id) as active_users + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { + return err + } + + hourStart := now.UTC().Truncate(time.Hour) + hourEnd := hourStart.Add(time.Hour) + hourlyActiveQuery := ` + SELECT COUNT(DISTINCT user_id) as active_users + FROM usage_logs + WHERE created_at >= $1 AND created_at < $2 + ` + if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + return err + } + + return nil } func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]service.UsageLog, *pagination.PaginationResult, error) { diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 09341db3..a944ed32 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -308,6 +308,80 @@ func (s *UsageLogRepoSuite) 
TestDashboardStats_TodayTotalsAndPerformance() { s.Require().Equal(wantTpm, stats.Tpm, "Tpm mismatch") } +func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() { + now := time.Now().UTC() + todayStart := truncateToDayUTC(now) + rangeStart := todayStart.Add(-24 * time.Hour) + rangeEnd := now + + user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"}) + user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"}) + apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user1.ID, Key: "sk-range-1", Name: "k1"}) + apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user2.ID, Key: "sk-range-2", Name: "k2"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-range"}) + + d1, d2, d3 := 100, 200, 300 + logOutside := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 7, + OutputTokens: 8, + TotalCost: 0.8, + ActualCost: 0.7, + DurationMs: &d3, + CreatedAt: rangeStart.Add(-1 * time.Hour), + } + _, err := s.repo.Create(s.ctx, logOutside) + s.Require().NoError(err) + + logRange := &service.UsageLog{ + UserID: user1.ID, + APIKeyID: apiKey1.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 1, + CacheReadTokens: 2, + TotalCost: 1.0, + ActualCost: 0.9, + DurationMs: &d1, + CreatedAt: rangeStart.Add(2 * time.Hour), + } + _, err = s.repo.Create(s.ctx, logRange) + s.Require().NoError(err) + + logToday := &service.UsageLog{ + UserID: user2.ID, + APIKeyID: apiKey2.ID, + AccountID: account.ID, + Model: "claude-3", + InputTokens: 5, + OutputTokens: 6, + CacheReadTokens: 1, + TotalCost: 0.5, + ActualCost: 0.5, + DurationMs: &d2, + CreatedAt: now, + } + _, err = s.repo.Create(s.ctx, logToday) + s.Require().NoError(err) + + stats, err := s.repo.GetDashboardStatsWithRange(s.ctx, rangeStart, rangeEnd) + s.Require().NoError(err) + 
s.Require().Equal(int64(2), stats.TotalRequests) + s.Require().Equal(int64(15), stats.TotalInputTokens) + s.Require().Equal(int64(26), stats.TotalOutputTokens) + s.Require().Equal(int64(1), stats.TotalCacheCreationTokens) + s.Require().Equal(int64(3), stats.TotalCacheReadTokens) + s.Require().Equal(int64(45), stats.TotalTokens) + s.Require().Equal(1.5, stats.TotalCost) + s.Require().Equal(1.4, stats.TotalActualCost) + s.Require().InEpsilon(150.0, stats.AverageDurationMs, 0.0001) +} + // --- GetUserDashboardStats --- func (s *UsageLogRepoSuite) TestGetUserDashboardStats() { diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 133ab018..0d1cec57 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -19,6 +19,8 @@ const ( var ( // ErrDashboardBackfillDisabled 当配置禁用回填时返回。 ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") + // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 + ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") ) // DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 @@ -76,6 +78,9 @@ func (s *DashboardAggregationService) Start() { s.runScheduledAggregation() }) log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) + if !s.cfg.BackfillEnabled { + log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") + } } // TriggerBackfill 触发回填(异步)。 @@ -90,6 +95,12 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro if !end.After(start) { return errors.New("回填时间范围无效") } + if s.cfg.BackfillMaxDays > 0 { + maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour + if end.Sub(start) > maxRange { + return ErrDashboardBackfillTooLarge + } + } go func() { ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) @@ -137,8 +148,11 @@ func (s *DashboardAggregationService) 
runScheduledAggregation() { epoch := time.Unix(0, 0).UTC() start := last.Add(-lookback) if !last.After(epoch) { - // 首次聚合覆盖当天,避免只统计最后一个窗口。 - start = truncateToDayUTC(now) + retentionDays := s.cfg.Retention.UsageLogsDays + if retentionDays <= 0 { + retentionDays = 1 + } + start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays)) } else if start.After(now) { start = now.Add(-lookback) } diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index 501b11d4..2fc22105 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -47,7 +47,7 @@ func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context return nil } -func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesDayStart(t *testing.T) { +func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) { repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()} svc := &DashboardAggregationService{ repo: repo, @@ -67,7 +67,7 @@ func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesDayStart(t require.Equal(t, 1, repo.aggregateCalls) require.False(t, repo.lastEnd.IsZero()) - require.Equal(t, truncateToDayUTC(repo.lastEnd), repo.lastStart) + require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart) } func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) { @@ -87,3 +87,20 @@ func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *te require.Nil(t, svc.lastRetentionCleanup.Load()) } + +func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) { + repo := &dashboardAggregationRepoTestStub{} + svc := &DashboardAggregationService{ + repo: repo, + cfg: config.DashboardAggregationConfig{ + BackfillEnabled: true, + BackfillMaxDays: 1, + }, + } + + start 
:= time.Now().AddDate(0, 0, -3) + end := time.Now() + err := svc.TriggerBackfill(start, end) + require.ErrorIs(t, err, ErrDashboardBackfillTooLarge) + require.Equal(t, 0, repo.aggregateCalls) +} diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index d0e6e03c..69d251cb 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -29,6 +29,10 @@ type DashboardStatsCache interface { DeleteDashboardStats(ctx context.Context) error } +type dashboardStatsRangeFetcher interface { + GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) +} + type dashboardStatsCacheEntry struct { Stats *usagestats.DashboardStats `json:"stats"` UpdatedAt int64 `json:"updated_at"` @@ -46,6 +50,7 @@ type DashboardService struct { aggEnabled bool aggInterval time.Duration aggLookback time.Duration + aggUsageDays int } func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService { @@ -55,6 +60,7 @@ func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregat aggEnabled := true aggInterval := time.Minute aggLookback := 2 * time.Minute + aggUsageDays := 90 if cfg != nil { if !cfg.Dashboard.Enabled { cache = nil @@ -75,6 +81,9 @@ func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregat if cfg.DashboardAgg.LookbackSeconds > 0 { aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second } + if cfg.DashboardAgg.Retention.UsageLogsDays > 0 { + aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays + } } return &DashboardService{ usageRepo: usageRepo, @@ -86,6 +95,7 @@ func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregat aggEnabled: aggEnabled, aggInterval: aggInterval, aggLookback: aggLookback, + aggUsageDays: aggUsageDays, } } @@ -148,7 +158,7 @@ func (s 
*DashboardService) getCachedDashboardStats(ctx context.Context) (*usages } func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { - stats, err := s.usageRepo.GetDashboardStats(ctx) + stats, err := s.fetchDashboardStats(ctx) if err != nil { return nil, err } @@ -173,7 +183,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() { ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout) defer cancel() - stats, err := s.usageRepo.GetDashboardStats(ctx) + stats, err := s.fetchDashboardStats(ctx) if err != nil { log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return @@ -185,6 +195,17 @@ func (s *DashboardService) refreshDashboardStatsAsync() { }() } +func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { + if !s.aggEnabled { + if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok { + now := time.Now().UTC() + start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays)) + return fetcher.GetDashboardStatsWithRange(ctx, start, now) + } + } + return s.usageRepo.GetDashboardStats(ctx) +} + func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) { if s.cache == nil || stats == nil { return diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index c7b9c6af..db3c78c3 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -16,10 +16,15 @@ import ( type usageRepoStub struct { UsageLogRepository - stats *usagestats.DashboardStats - err error - calls int32 - onCall chan struct{} + stats *usagestats.DashboardStats + rangeStats *usagestats.DashboardStats + err error + rangeErr error + calls int32 + rangeCalls int32 + rangeStart time.Time + rangeEnd time.Time + onCall chan struct{} } func (s *usageRepoStub) GetDashboardStats(ctx context.Context) 
(*usagestats.DashboardStats, error) { @@ -36,6 +41,19 @@ func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.Dash return s.stats, nil } +func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) { + atomic.AddInt32(&s.rangeCalls, 1) + s.rangeStart = start + s.rangeEnd = end + if s.rangeErr != nil { + return nil, s.rangeErr + } + if s.rangeStats != nil { + return s.rangeStats, nil + } + return s.stats, nil +} + type dashboardCacheStub struct { get func(ctx context.Context) (string, error) set func(ctx context.Context, data string, ttl time.Duration) error @@ -140,7 +158,12 @@ func TestDashboardService_CacheHitFresh(t *testing.T) { stats: &usagestats.DashboardStats{TotalUsers: 99}, } aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) @@ -164,7 +187,12 @@ func TestDashboardService_CacheMiss_StoresCache(t *testing.T) { } repo := &usageRepoStub{stats: stats} aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) @@ -191,7 +219,12 @@ func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) { } repo := &usageRepoStub{stats: stats} aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: 
config.DashboardCacheConfig{Enabled: false}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) @@ -226,7 +259,12 @@ func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) { onCall: refreshCh, } aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) @@ -252,7 +290,12 @@ func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) { stats := &usagestats.DashboardStats{TotalUsers: 9} repo := &usageRepoStub{stats: stats} aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) got, err := svc.GetDashboardStats(context.Background()) @@ -270,7 +313,12 @@ func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) { } repo := &usageRepoStub{err: errors.New("db down")} aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()} - cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: true}} + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: true}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: true, + }, + } svc := NewDashboardService(repo, aggRepo, cache, cfg) _, err := 
svc.GetDashboardStats(context.Background()) @@ -311,3 +359,29 @@ func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) { require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt) require.False(t, got.StatsStale) } + +func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) { + expected := &usagestats.DashboardStats{TotalUsers: 42} + repo := &usageRepoStub{ + rangeStats: expected, + err: errors.New("should not call aggregated stats"), + } + cfg := &config.Config{ + Dashboard: config.DashboardCacheConfig{Enabled: false}, + DashboardAgg: config.DashboardAggregationConfig{ + Enabled: false, + Retention: config.DashboardAggregationRetentionConfig{ + UsageLogsDays: 7, + }, + }, + } + svc := NewDashboardService(repo, nil, nil, cfg) + + got, err := svc.GetDashboardStats(context.Background()) + require.NoError(t, err) + require.Equal(t, int64(42), got.TotalUsers) + require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls)) + require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls)) + require.False(t, repo.rangeEnd.IsZero()) + require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart) +} diff --git a/config.yaml b/config.yaml index 848421d6..b5272aac 100644 --- a/config.yaml +++ b/config.yaml @@ -232,6 +232,9 @@ dashboard_aggregation: # Allow manual backfill # 允许手动回填 backfill_enabled: false + # Backfill max range (days) + # 回填最大跨度(天) + backfill_max_days: 31 # Recompute recent N days on startup # 启动时重算最近 N 天 recompute_days: 2 diff --git a/deploy/.env.example b/deploy/.env.example index bd8abc5c..83e58a50 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -69,6 +69,33 @@ JWT_EXPIRE_HOUR=24 # Leave unset to use default ./config.yaml #CONFIG_FILE=./config.yaml +# ----------------------------------------------------------------------------- +# Dashboard Aggregation (Optional) +# ----------------------------------------------------------------------------- +# Enable aggregation job +# 启用仪表盘预聚合 
+DASHBOARD_AGGREGATION_ENABLED=true +# Refresh interval (seconds) +# 刷新间隔(秒) +DASHBOARD_AGGREGATION_INTERVAL_SECONDS=60 +# Lookback window (seconds) +# 回看窗口(秒) +DASHBOARD_AGGREGATION_LOOKBACK_SECONDS=120 +# Allow manual backfill +# 允许手动回填 +DASHBOARD_AGGREGATION_BACKFILL_ENABLED=false +# Backfill max range (days) +# 回填最大跨度(天) +DASHBOARD_AGGREGATION_BACKFILL_MAX_DAYS=31 +# Recompute recent N days on startup +# 启动时重算最近 N 天 +DASHBOARD_AGGREGATION_RECOMPUTE_DAYS=2 +# Retention windows (days) +# 保留窗口(天) +DASHBOARD_AGGREGATION_RETENTION_USAGE_LOGS_DAYS=90 +DASHBOARD_AGGREGATION_RETENTION_HOURLY_DAYS=180 +DASHBOARD_AGGREGATION_RETENTION_DAILY_DAYS=730 + # ----------------------------------------------------------------------------- # Security Configuration # ----------------------------------------------------------------------------- diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 460606ab..57239f8e 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -232,6 +232,9 @@ dashboard_aggregation: # Allow manual backfill # 允许手动回填 backfill_enabled: false + # Backfill max range (days) + # 回填最大跨度(天) + backfill_max_days: 31 # Recompute recent N days on startup # 启动时重算最近 N 天 recompute_days: 2 From ccb81445576843d84019edc67df64658adce51a5 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 18:39:29 +0800 Subject: [PATCH 17/23] =?UTF-8?q?fix(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E4=BF=AE=E5=A4=8Drows.Close=E9=94=99=E8=AF=AF=E6=A3=80?= =?UTF-8?q?=E6=9F=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/repository/dashboard_aggregation_repo.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index b02cde0d..5241c468 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ 
b/backend/internal/repository/dashboard_aggregation_repo.go @@ -310,7 +310,9 @@ func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Con if err != nil { return err } - defer rows.Close() + defer func() { + _ = rows.Close() + }() cutoffMonth := truncateToMonthUTC(cutoff) for rows.Next() { From abbde130abc40c1fa507ed6ddc107debd3bfe39e Mon Sep 17 00:00:00 2001 From: cyhhao Date: Sun, 11 Jan 2026 18:43:47 +0800 Subject: [PATCH 18/23] Revert Codex OAuth fallback handling --- .../service/openai_codex_transform.go | 124 ------------------ 1 file changed, 124 deletions(-) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 965fb770..94e74f22 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -115,12 +115,6 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { existingInstructions = strings.TrimSpace(existingInstructions) if instructions != "" { - if existingInstructions != "" && existingInstructions != instructions { - if input, ok := reqBody["input"].([]any); ok { - reqBody["input"] = prependSystemInstruction(input, existingInstructions) - result.Modified = true - } - } if existingInstructions != instructions { reqBody["instructions"] = instructions result.Modified = true @@ -129,7 +123,6 @@ func applyCodexOAuthTransform(reqBody map[string]any) codexTransformResult { if input, ok := reqBody["input"].([]any); ok { input = filterCodexInput(input) - input = normalizeOrphanedToolOutputs(input) reqBody["input"] = input result.Modified = true } @@ -266,19 +259,6 @@ func filterCodexInput(input []any) []any { return filtered } -func prependSystemInstruction(input []any, instructions string) []any { - message := map[string]any{ - "role": "system", - "content": []any{ - map[string]any{ - "type": "input_text", - "text": instructions, - }, - }, - } - return append([]any{message}, input...) 
-} - func normalizeCodexTools(reqBody map[string]any) bool { rawTools, ok := reqBody["tools"] if !ok || rawTools == nil { @@ -341,110 +321,6 @@ func normalizeCodexTools(reqBody map[string]any) bool { return modified } -func normalizeOrphanedToolOutputs(input []any) []any { - functionCallIDs := map[string]bool{} - localShellCallIDs := map[string]bool{} - customToolCallIDs := map[string]bool{} - - for _, item := range input { - m, ok := item.(map[string]any) - if !ok { - continue - } - callID := getCallID(m) - if callID == "" { - continue - } - switch m["type"] { - case "function_call": - functionCallIDs[callID] = true - case "local_shell_call": - localShellCallIDs[callID] = true - case "custom_tool_call": - customToolCallIDs[callID] = true - } - } - - output := make([]any, 0, len(input)) - for _, item := range input { - m, ok := item.(map[string]any) - if !ok { - output = append(output, item) - continue - } - switch m["type"] { - case "function_call_output": - callID := getCallID(m) - if callID == "" || (!functionCallIDs[callID] && !localShellCallIDs[callID]) { - output = append(output, convertOrphanedOutputToMessage(m, callID)) - continue - } - case "custom_tool_call_output": - callID := getCallID(m) - if callID == "" || !customToolCallIDs[callID] { - output = append(output, convertOrphanedOutputToMessage(m, callID)) - continue - } - case "local_shell_call_output": - callID := getCallID(m) - if callID == "" || !localShellCallIDs[callID] { - output = append(output, convertOrphanedOutputToMessage(m, callID)) - continue - } - } - output = append(output, m) - } - return output -} - -func getCallID(item map[string]any) string { - raw, ok := item["call_id"] - if !ok { - return "" - } - callID, ok := raw.(string) - if !ok { - return "" - } - callID = strings.TrimSpace(callID) - if callID == "" { - return "" - } - return callID -} - -func convertOrphanedOutputToMessage(item map[string]any, callID string) map[string]any { - toolName := "tool" - if name, ok := 
item["name"].(string); ok && name != "" { - toolName = name - } - labelID := callID - if labelID == "" { - labelID = "unknown" - } - text := stringifyOutput(item["output"]) - if len(text) > 16000 { - text = text[:16000] + "\n...[truncated]" - } - return map[string]any{ - "type": "message", - "role": "assistant", - "content": fmt.Sprintf("[Previous %s result; call_id=%s]: %s", toolName, labelID, text), - } -} - -func stringifyOutput(output any) string { - switch v := output.(type) { - case string: - return v - default: - if data, err := json.Marshal(v); err == nil { - return string(data) - } - return fmt.Sprintf("%v", v) - } -} - func codexCachePath(filename string) string { home, err := os.UserHomeDir() if err != nil { From 4b66ee2f8f1356036494c53c6f843a4cbd61d53f Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 18:49:57 +0800 Subject: [PATCH 19/23] =?UTF-8?q?chore(=E6=B5=8B=E8=AF=95):=20=E6=B8=85?= =?UTF-8?q?=E7=90=86=E6=9C=AA=E4=BD=BF=E7=94=A8=E5=AF=BC=E5=85=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/repository/usage_log_repo_integration_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index a944ed32..e1c8085e 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -11,7 +11,6 @@ import ( dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" - "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/suite" From 48613558d4a3c2a0d6a38a4185433b8069c53191 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 19:01:15 +0800 Subject: [PATCH 20/23] 
=?UTF-8?q?fix(=E4=BB=AA=E8=A1=A8=E7=9B=98):=20?= =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=88=86=E5=8C=BA=E8=BF=81=E7=A7=BB=E4=B8=8E?= =?UTF-8?q?=E8=8C=83=E5=9B=B4=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../repository/usage_log_repo_integration_test.go | 2 +- backend/migrations/035_usage_logs_partitioning.sql | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index e1c8085e..51964782 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -311,7 +311,7 @@ func (s *UsageLogRepoSuite) TestDashboardStatsWithRange_Fallback() { now := time.Now().UTC() todayStart := truncateToDayUTC(now) rangeStart := todayStart.Add(-24 * time.Hour) - rangeEnd := now + rangeEnd := now.Add(1 * time.Second) user1 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u1@test.com"}) user2 := mustCreateUser(s.T(), s.client, &service.User{Email: "range-u2@test.com"}) diff --git a/backend/migrations/035_usage_logs_partitioning.sql b/backend/migrations/035_usage_logs_partitioning.sql index 5919b5c3..e25a105e 100644 --- a/backend/migrations/035_usage_logs_partitioning.sql +++ b/backend/migrations/035_usage_logs_partitioning.sql @@ -1,6 +1,6 @@ -- usage_logs monthly partition bootstrap. --- Only converts to partitioned table when usage_logs is empty. --- Existing installations with data require a manual migration plan. +-- Only creates partitions when usage_logs is already partitioned. +-- Converting usage_logs to a partitioned table requires a manual migration plan. 
DO $$ DECLARE @@ -20,8 +20,8 @@ BEGIN IF NOT is_partitioned THEN SELECT EXISTS(SELECT 1 FROM usage_logs LIMIT 1) INTO has_data; IF NOT has_data THEN - EXECUTE 'ALTER TABLE usage_logs PARTITION BY RANGE (created_at)'; - is_partitioned := TRUE; + -- Automatic conversion is intentionally skipped; see manual migration plan. + RAISE NOTICE 'usage_logs is not partitioned; skip automatic partitioning'; END IF; END IF; From 32953405b1fc9a700d0f31b558400281ab3d5b24 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sun, 11 Jan 2026 20:22:17 +0800 Subject: [PATCH 21/23] =?UTF-8?q?fix(=E8=B4=A6=E5=8F=B7=E7=AE=A1=E7=90=86)?= =?UTF-8?q?:=20=E8=B0=83=E5=BA=A6=E6=89=B9=E9=87=8F=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E6=98=8E=E7=BB=86=E4=B8=8E=E5=88=B7=E6=96=B0=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 补充批量调度返回 success_ids/failed_ids 并增加合约/单测 前端加入降级处理与部分失败提示,表格行使用稳定 key 测试: make test-frontend 测试: go test ./internal/service -run BulkUpdateAccounts -tags=unit 测试: go test ./internal/server -run APIContracts -tags=unit --- backend/internal/server/api_contract_test.go | 259 ++++++++++++++++++ backend/internal/service/admin_service.go | 30 +- .../service/admin_service_bulk_update_test.go | 80 ++++++ frontend/src/api/admin/accounts.ts | 6 +- frontend/src/components/common/DataTable.vue | 15 +- frontend/src/components/common/README.md | 1 + frontend/src/i18n/locales/en.ts | 2 + frontend/src/i18n/locales/zh.ts | 2 + frontend/src/views/admin/AccountsView.vue | 126 ++++++++- 9 files changed, 497 insertions(+), 24 deletions(-) create mode 100644 backend/internal/service/admin_service_bulk_update_test.go diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index abcf0e6c..ebb98a50 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -331,6 +331,30 @@ func TestAPIContracts(t *testing.T) { } }`, }, + { + name: "POST 
/api/v1/admin/accounts/bulk-update", + method: http.MethodPost, + path: "/api/v1/admin/accounts/bulk-update", + body: `{"account_ids":[101,102],"schedulable":false}`, + headers: map[string]string{ + "Content-Type": "application/json", + }, + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": { + "success": 2, + "failed": 0, + "success_ids": [101, 102], + "failed_ids": [], + "results": [ + {"account_id": 101, "success": true}, + {"account_id": 102, "success": true} + ] + } + }`, + }, } for _, tt := range tests { @@ -382,6 +406,9 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyCache := stubApiKeyCache{} groupRepo := stubGroupRepo{} userSubRepo := stubUserSubscriptionRepo{} + accountRepo := stubAccountRepo{} + proxyRepo := stubProxyRepo{} + redeemRepo := stubRedeemCodeRepo{} cfg := &config.Config{ Default: config.DefaultConfig{ @@ -399,10 +426,12 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -442,6 +471,7 @@ func newContractDeps(t *testing.T) *contractDeps { v1Admin := v1.Group("/admin") v1Admin.Use(adminAuth) v1Admin.GET("/settings", adminSettingHandler.GetSettings) + v1Admin.POST("/accounts/bulk-update", adminAccountHandler.BulkUpdate) return &contractDeps{ now: now, @@ -632,6 +662,235 @@ func 
(stubGroupRepo) DeleteAccountGroupsByGroupID(ctx context.Context, groupID i return 0, errors.New("not implemented") } +type stubAccountRepo struct { + bulkUpdateIDs []int64 +} + +func (s *stubAccountRepo) Create(ctx context.Context, account *service.Account) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) GetByID(ctx context.Context, id int64) (*service.Account, error) { + return nil, service.ErrAccountNotFound +} + +func (s *stubAccountRepo) GetByIDs(ctx context.Context, ids []int64) ([]*service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ExistsByID(ctx context.Context, id int64) (bool, error) { + return false, errors.New("not implemented") +} + +func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListByGroup(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListActive(ctx context.Context) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListByPlatform(ctx context.Context, platform string) 
([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateLastUsed(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) { + return 0, errors.New("not implemented") +} + +func (s *stubAccountRepo) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulable(ctx context.Context) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByPlatform(ctx context.Context, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + +func (s *stubAccountRepo) 
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearTempUnschedulable(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearRateLimit(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + return errors.New("not implemented") +} + +func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates service.AccountBulkUpdate) (int64, error) { + s.bulkUpdateIDs = append([]int64{}, ids...) 
+ return int64(len(ids)), nil +} + +type stubProxyRepo struct{} + +func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) GetByID(ctx context.Context, id int64) (*service.Proxy, error) { + return nil, service.ErrProxyNotFound +} + +func (stubProxyRepo) Update(ctx context.Context, proxy *service.Proxy) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (stubProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListActive(ctx context.Context) ([]service.Proxy, error) { + return nil, errors.New("not implemented") +} + +func (stubProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) { + return nil, errors.New("not implemented") +} + +func (stubProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + return false, errors.New("not implemented") +} + +func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + return 0, errors.New("not implemented") +} + +type stubRedeemCodeRepo struct{} + +func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) 
error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) CreateBatch(ctx context.Context, codes []service.RedeemCode) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) { + return nil, service.ErrRedeemCodeNotFound +} + +func (stubRedeemCodeRepo) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) { + return nil, service.ErrRedeemCodeNotFound +} + +func (stubRedeemCodeRepo) Update(ctx context.Context, code *service.RedeemCode) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) Use(ctx context.Context, id, userID int64) error { + return errors.New("not implemented") +} + +func (stubRedeemCodeRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} + +func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { + return nil, errors.New("not implemented") +} + type stubUserSubscriptionRepo struct{} func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 75b57852..944b80ff 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -186,9 +186,11 @@ type BulkUpdateAccountResult struct { // BulkUpdateAccountsResult is the aggregated response for bulk updates. 
type BulkUpdateAccountsResult struct { - Success int `json:"success"` - Failed int `json:"failed"` - Results []BulkUpdateAccountResult `json:"results"` + Success int `json:"success"` + Failed int `json:"failed"` + SuccessIDs []int64 `json:"success_ids"` + FailedIDs []int64 `json:"failed_ids"` + Results []BulkUpdateAccountResult `json:"results"` } type CreateProxyInput struct { @@ -917,7 +919,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U // It merges credentials/extra keys instead of overwriting the whole object. func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { result := &BulkUpdateAccountsResult{ - Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), + SuccessIDs: make([]int64, 0, len(input.AccountIDs)), + FailedIDs: make([]int64, 0, len(input.AccountIDs)), + Results: make([]BulkUpdateAccountResult, 0, len(input.AccountIDs)), } if len(input.AccountIDs) == 0 { @@ -981,24 +985,27 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = false entry.Error = err.Error() result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } platform = account.Platform } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { - entry.Success = false - entry.Error = err.Error() - result.Failed++ - result.Results = append(result.Results, entry) - continue - } + if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + entry.Success = false + entry.Error = err.Error() + result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) + result.Results = append(result.Results, entry) + continue + } } if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ + 
result.FailedIDs = append(result.FailedIDs, accountID) result.Results = append(result.Results, entry) continue } @@ -1006,6 +1013,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp entry.Success = true result.Success++ + result.SuccessIDs = append(result.SuccessIDs, accountID) result.Results = append(result.Results, entry) } diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go new file mode 100644 index 00000000..ef621213 --- /dev/null +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -0,0 +1,80 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +type accountRepoStubForBulkUpdate struct { + accountRepoStub + bulkUpdateErr error + bulkUpdateIDs []int64 + bindGroupErrByID map[int64]error +} + +func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { + s.bulkUpdateIDs = append([]int64{}, ids...) 
+ if s.bulkUpdateErr != nil { + return 0, s.bulkUpdateErr + } + return int64(len(ids)), nil +} + +func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + if err, ok := s.bindGroupErrByID[accountID]; ok { + return err + } + return nil +} + +// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 +func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + schedulable := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2, 3}, + Schedulable: &schedulable, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 3, result.Success) + require.Equal(t, 0, result.Failed) + require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs) + require.Empty(t, result.FailedIDs) + require.Len(t, result.Results, 3) +} + +// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。 +func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + bindGroupErrByID: map[int64]error{ + 2: errors.New("bind failed"), + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + schedulable := false + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2, 3}, + GroupIDs: &groupIDs, + Schedulable: &schedulable, + SkipMixedChannelCheck: true, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 2, result.Success) + require.Equal(t, 1, result.Failed) + require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs) + require.ElementsMatch(t, []int64{2}, result.FailedIDs) + require.Len(t, result.Results, 3) +} diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 4e1f6cd3..54d0ad94 100644 --- 
a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -275,11 +275,15 @@ export async function bulkUpdate( ): Promise<{ success: number failed: number + success_ids?: number[] + failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> -}> { + }> { const { data } = await apiClient.post<{ success: number failed: number + success_ids?: number[] + failed_ids?: number[] results: Array<{ account_id: number; success: boolean; error?: string }> }>('/admin/accounts/bulk-update', { account_ids: accountIds, diff --git a/frontend/src/components/common/DataTable.vue b/frontend/src/components/common/DataTable.vue index 7ad31f7d..dc492d36 100644 --- a/frontend/src/components/common/DataTable.vue +++ b/frontend/src/components/common/DataTable.vue @@ -83,7 +83,7 @@ string | number) } const props = withDefaults(defineProps(), { @@ -222,6 +223,18 @@ const props = withDefaults(defineProps(), { const sortKey = ref('') const sortOrder = ref<'asc' | 'desc'>('asc') const actionsExpanded = ref(false) +const resolveRowKey = (row: any, index: number) => { + if (typeof props.rowKey === 'function') { + const key = props.rowKey(row) + return key ?? index + } + if (typeof props.rowKey === 'string' && props.rowKey) { + const key = row?.[props.rowKey] + return key ?? index + } + const key = row?.id + return key ?? 
index +} // 数据/列变化时重新检查滚动状态 // 注意:不能监听 actionsExpanded,因为 checkActionsColumnWidth 会临时修改它,会导致无限循环 diff --git a/frontend/src/components/common/README.md b/frontend/src/components/common/README.md index 640cdc0e..1733cfad 100644 --- a/frontend/src/components/common/README.md +++ b/frontend/src/components/common/README.md @@ -13,6 +13,7 @@ A generic data table component with sorting, loading states, and custom cell ren - `columns: Column[]` - Array of column definitions with key, label, sortable, and formatter - `data: any[]` - Array of data objects to display - `loading?: boolean` - Show loading skeleton +- `rowKey?: string | (row: any) => string | number` - Row key field or resolver (defaults to `row.id`, falls back to index) **Slots:** diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index babe31e7..3dfe3034 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1105,6 +1105,8 @@ export default { rateLimitCleared: 'Rate limit cleared successfully', bulkSchedulableEnabled: 'Successfully enabled scheduling for {count} account(s)', bulkSchedulableDisabled: 'Successfully disabled scheduling for {count} account(s)', + bulkSchedulablePartial: 'Scheduling updated partially: {success} succeeded, {failed} failed', + bulkSchedulableResultUnknown: 'Bulk scheduling result incomplete. 
Please retry or refresh.', bulkActions: { selected: '{count} account(s) selected', selectCurrentPage: 'Select this page', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 889c2463..22d5eebe 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1241,6 +1241,8 @@ export default { accountDeletedSuccess: '账号删除成功', bulkSchedulableEnabled: '成功启用 {count} 个账号的调度', bulkSchedulableDisabled: '成功停止 {count} 个账号的调度', + bulkSchedulablePartial: '部分调度更新成功:成功 {success} 个,失败 {failed} 个', + bulkSchedulableResultUnknown: '批量调度结果不完整,请稍后重试或刷新列表', bulkActions: { selected: '已选择 {count} 个账号', selectCurrentPage: '本页全选', diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 79c6072c..0480ef39 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -20,7 +20,7 @@