From 5b8d4fb0479eeb14356066e17576de928c466771 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Fri, 9 Jan 2026 00:34:49 +0800 Subject: [PATCH 1/4] feat(openai): add AI SDK content format compatibility for OAuth accounts - Add normalizeInputForCodexAPI function to convert AI SDK multi-part content format to simplified format expected by ChatGPT Codex API - AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} - Codex API expects: {"content": "..."} - Only applies to OAuth accounts (ChatGPT internal API) - API Key accounts remain unchanged (OpenAI Platform API supports both) --- .../service/openai_gateway_service.go | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d744bfab..f26404c8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // For OAuth accounts using ChatGPT internal API, add store: false + // For OAuth accounts using ChatGPT internal API: + // 1. Add store: false + // 2. Normalize input format for Codex API compatibility if account.Type == AccountTypeOAuth { reqBody["store"] = false bodyModified = true + + // Normalize input format: convert AI SDK multi-part content format to simplified format + // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} + // Codex API expects: {"content": "..."} + if normalizeInputForCodexAPI(reqBody) { + bodyModified = true + } } // Re-serialize body only if modified @@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } +// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format +// that the ChatGPT internal Codex API expects. +// +// AI SDK sends content as an array of typed objects: +// +// {"content": [{"type": "input_text", "text": "hello"}]} +// +// ChatGPT Codex API expects content as a simple string: +// +// {"content": "hello"} +// +// This function modifies reqBody in-place and returns true if any modification was made. +func normalizeInputForCodexAPI(reqBody map[string]any) bool { + input, ok := reqBody["input"] + if !ok { + return false + } + + // Handle case where input is a simple string (already compatible) + if _, isString := input.(string); isString { + return false + } + + // Handle case where input is an array of messages + inputArray, ok := input.([]any) + if !ok { + return false + } + + modified := false + for _, item := range inputArray { + message, ok := item.(map[string]any) + if !ok { + continue + } + + content, ok := message["content"] + if !ok { + continue + } + + // If content is already a string, no conversion needed + if _, isString := content.(string); isString { + continue + } + + // If content is an array (AI SDK format), convert to string + contentArray, ok := content.([]any) + if !ok { + continue + } + + // Extract text from content array + var textParts []string + for _, part := range contentArray { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + + // Handle different content types + partType, _ := partMap["type"].(string) + switch partType { + case "input_text", "text": + // Extract text from input_text or text type + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + case "input_image", "image": + // For images, we need to preserve the original format + // as ChatGPT Codex API may support images in a different way + // For now, skip image parts (they will be lost in conversion) + // TODO: Consider preserving image data or handling it separately + continue + case "input_file", "file": + // Similar to images, file inputs may need special handling + continue + default: + // For unknown types, try to extract text if available + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + } + } + + // Convert content array to string + if len(textParts) > 0 { + message["content"] = strings.Join(textParts, "\n") + modified = true + } + } + + return modified +} + // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult From 2d83941aaa5dd397d1119352c72a01c326c8c460 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 9 Jan 2026 10:36:56 +0800 Subject: [PATCH 2/4] =?UTF-8?q?feat(antigravity):=20=E6=B7=BB=E5=8A=A0=20U?= =?UTF-8?q?RL=20fallback=20=E6=9C=BA=E5=88=B6=20(sandbox=20=E2=86=92=20dai?= =?UTF-8?q?ly=20=E2=86=92=20prod)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/pkg/antigravity/client.go | 221 +++++++---- backend/internal/pkg/antigravity/oauth.go | 70 +++- .../service/antigravity_gateway_service.go | 348 +++++++++++------- 3 files changed, 446 insertions(+), 193 deletions(-) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 8ff75f57..1248be95 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -5,8 +5,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "log" + "net" "net/http" "net/url" "strings" @@ -22,10 +25,10 @@ func resolveHost(urlStr string) string { return parsed.Host } -// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) -func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { +// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) +func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 - apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) + apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action) isStream := action == "streamGenerateContent" if isStream { apiURL += "?alt=sse" @@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) req.Host = host } - // 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header - return req, nil } +// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点) +// 向后兼容:仅使用默认 BaseURL +func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { + return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body) +} + // TokenResponse Google OAuth token 响应 type TokenResponse struct { AccessToken string `json:"access_token"` @@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client { } } +// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // 检查 URL 错误 + var urlErr *url.Error + return errors.As(err, &urlErr) +} + +// shouldFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldFallbackToNextURL(err error, statusCode int) bool { + if isConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { params := url.Values{} @@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo } // LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" @@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - url := BaseURL + "/v1internal:loadCodeAssist" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:loadCodeAssist" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var loadResp LoadCodeAssistResponse + if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &loadResp, rawResp, nil } - var loadResp LoadCodeAssistResponse - if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &loadResp, rawResp, nil + return nil, nil, lastErr } // ModelQuotaInfo 模型配额信息 @@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct { } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) @@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - apiURL := BaseURL + "/v1internal:fetchAvailableModels" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:fetchAvailableModels" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var modelsResp FetchAvailableModelsResponse + if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &modelsResp, rawResp, nil } - var modelsResp FetchAvailableModelsResponse - if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &modelsResp, rawResp, nil + return nil, nil, lastErr } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index e88c203b..736c45df 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,17 +32,79 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // API 端点 - // 优先使用 sandbox daily URL,配额更宽松 - BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" - // User-Agent(模拟官方客户端) UserAgent = "antigravity/1.104.0 darwin/arm64" // Session 过期时间 SessionTTL = 30 * time.Minute + + // URL 可用性 TTL(不可用 URL 的恢复时间) + URLAvailabilityTTL = 5 * time.Minute ) +// BaseURLs 定义 Antigravity API 端点,按优先级排序 +// fallback 顺序: sandbox → daily → prod +var BaseURLs = []string{ + "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox + "https://daily-cloudcode-pa.googleapis.com", // daily + "https://cloudcode-pa.googleapis.com", // prod +} + +// BaseURL 默认 URL(保持向后兼容) +var BaseURL = BaseURLs[0] + +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +type URLAvailability struct { + mu sync.RWMutex + unavailable map[string]time.Time // URL -> 恢复时间 + ttl time.Duration +} + +// DefaultURLAvailability 全局 URL 可用性管理器 +var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL) + +// NewURLAvailability 创建 URL 可用性管理器 +func NewURLAvailability(ttl time.Duration) *URLAvailability { + return &URLAvailability{ + unavailable: make(map[string]time.Time), + ttl: ttl, + } +} + +// MarkUnavailable 标记 URL 临时不可用 +func (u *URLAvailability) MarkUnavailable(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.unavailable[url] = time.Now().Add(u.ttl) +} + +// IsAvailable 检查 URL 是否可用 +func (u *URLAvailability) IsAvailable(url string) bool { + u.mu.RLock() + defer u.mu.RUnlock() + expiry, exists := u.unavailable[url] + if !exists { + return true + } + return time.Now().After(expiry) +} + +// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +func (u *URLAvailability) GetAvailableURLs() []string { + u.mu.RLock() + defer u.mu.RUnlock() + + now := time.Now() + result := make([]string, 0, len(BaseURLs)) + for _, url := range BaseURLs { + expiry, exists := u.unavailable[url] + if !exists || now.After(expiry) { + result = append(result, url) + } + } + return result +} + // OAuthSession 保存 OAuth 授权流程的临时状态 type OAuthSession struct { State string `json:"state"` diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2fe77b2d..573017cd 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -10,6 +10,7 @@ import ( "io" "log" mathrand "math/rand" + "net" "net/http" "strings" "sync/atomic" @@ -27,6 +28,32 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) +// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isAntigravityConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + return errors.As(err, &opErr) +} + +// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool { + if isAntigravityConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // getSessionID 从 gin.Context 获取 session_id(用于日志追踪) func getSessionID(c *gin.Context) string { if c == nil { @@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("构建请求失败: %w", err) } - // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) - req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody) - if err != nil { - return nil, err - } - - // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) - // 代理 URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - // 发送请求 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // 读取响应 - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) + req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody) + if err != nil { + lastErr = err + continue + } + + // 调试日志:Test 请求信息 + log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = fmt.Errorf("请求失败: %w", err) + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, lastErr + } + + // 读取响应 + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 解析流式响应,提取文本 + text := extractTextFromSSEResponse(respBody) + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil } - // 解析流式响应,提取文本 - text := extractTextFromSSEResponse(respBody) - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil + return nil, lastErr } // buildGeminiTestRequest 构建 Gemini 格式测试请求 @@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + // 最后一次尝试也失败 + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { _ = resp.Body.Close() }() @@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { if resp != nil && resp.Body != nil { From 799b01063119db53aec5e0a283df176ff093cf69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E7=8C=BFMT?= <32916545+mt21625457@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:37:15 +0800 Subject: [PATCH 3/4] =?UTF-8?q?fix(auth):=20=E4=BF=AE=E5=A4=8D=20RefreshTo?= =?UTF-8?q?ken=20=E4=BD=BF=E7=94=A8=E8=BF=87=E6=9C=9F=20token=20=E6=97=B6?= =?UTF-8?q?=E7=9A=84=20nil=20pointer=20panic=20(#214)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(auth): 修复 RefreshToken 使用过期 token 时的 nil pointer panic 问题分析: - RefreshToken 允许过期 token 继续流程(用于无感刷新) - 但 ValidateToken 在 token 过期时返回 nil claims - 导致后续访问 claims.UserID 时触发 panic 修复方案: - 修改 ValidateToken,在检测到 ErrTokenExpired 时仍然返回 claims - jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 - 这样 RefreshToken 可以正常获取用户信息并生成新 token 新增测试: - TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError - TestAuthService_RefreshToken_ExpiredTokenNoPanic Co-Authored-By: Claude Opus 4.5 * fix(auth): 修复邮件验证服务未配置时可绕过验证的安全漏洞 当邮件验证开启但 emailService 未配置时,原逻辑允许用户绕过验证直接注册。 现在会返回 ErrServiceUnavailable 拒绝注册,确保配置错误不会导致安全问题。 - 在验证码检查前先检查 emailService 是否配置 - 添加日志记录帮助发现配置问题 - 新增单元测试覆盖该场景 Co-Authored-By: Claude Opus 4.5 --------- Co-authored-by: yangjianbo Co-authored-by: Claude Opus 4.5 --- backend/internal/service/auth_service.go | 17 ++++- .../service/auth_service_register_test.go | 76 ++++++++++++++++++- 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 85772e75..5a5ca03d 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -82,14 +82,18 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { + // 如果邮件验证已开启但邮件服务未配置,拒绝注册 + // 这是一个配置错误,不应该允许绕过验证 + if s.emailService == nil { + log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + return "", nil, ErrServiceUnavailable + } if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 - if s.emailService != nil { - if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { - return "", nil, fmt.Errorf("verify code: %w", err) - } + if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { + return "", nil, fmt.Errorf("verify code: %w", err) } } @@ -336,6 +340,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { + // token 过期但仍返回 claims(用于 RefreshToken 等场景) + // jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 + if claims, ok := token.Claims.(*JWTClaims); ok { + return claims, ErrTokenExpired + } return nil, ErrTokenExpired } return nil, ErrInvalidToken diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index cd6e2808..a31267ab 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,13 +113,27 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } -func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { +func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} + // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", }, nil) + // 应返回服务不可用错误,而不是允许绕过验证 + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") + require.ErrorIs(t, err, ErrServiceUnavailable) +} + +func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { + repo := &userRepoStub{} + cache := &emailCacheStub{} // 配置 emailService + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, cache) + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -180,3 +194,63 @@ func TestAuthService_Register_Success(t *testing.T) { require.Len(t, repo.created, 1) require.True(t, user.CheckPassword("password")) } + +func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + // 创建用户并生成 token + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + token, err := service.GenerateToken(user) + require.NoError(t, err) + + // 验证有效 token + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, int64(1), claims.UserID) + + // 模拟过期 token(通过创建一个过期很久的 token) + service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 // 恢复 + + // 验证过期 token 应返回 claims 和 ErrTokenExpired + claims, err = service.ValidateToken(expiredToken) + require.ErrorIs(t, err, ErrTokenExpired) + require.NotNil(t, claims, "claims should not be nil when token is expired") + require.Equal(t, int64(1), claims.UserID) + require.Equal(t, "test@test.com", claims.Email) +} + +func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + repo := &userRepoStub{user: user} + service := newAuthService(repo, nil, nil) + + // 创建过期 token + service.cfg.JWT.ExpireHour = -1 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 + + // RefreshToken 使用过期 token 不应 panic + require.NotPanics(t, func() { + newToken, err := service.RefreshToken(context.Background(), expiredToken) + require.NoError(t, err) + require.NotEmpty(t, newToken) + }) +} From 43f104bdf728e23a7babbb48642f20d63c95afc1 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 9 Jan 2026 14:49:20 +0800 Subject: [PATCH 4/4] =?UTF-8?q?fix(auth):=20=E6=B3=A8=E5=86=8C=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=AE=89=E5=85=A8=E5=8A=A0=E5=9B=BA=20-=20=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=85=B3=E9=97=AD=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/auth_service.go | 18 ++++++---- .../service/auth_service_register_test.go | 36 ++++++++++++++++--- backend/internal/service/email_service.go | 9 +++-- backend/internal/service/setting_service.go | 4 +-- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 5a5ca03d..6e685869 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -75,8 +75,8 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } @@ -132,6 +132,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } if err := s.userRepo.Create(ctx, user); err != nil { + // 优先检查邮箱冲突错误(竞态条件下可能发生) + if errors.Is(err, ErrEmailExists) { + return "", nil, ErrEmailExists + } log.Printf("[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } @@ -152,8 +156,8 @@ type SendVerifyCodeResult struct { // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } @@ -185,8 +189,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } @@ -270,7 +274,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { - return true + return false // 安全默认:settingService 未配置时关闭注册 } return s.settingService.IsRegistrationEnabled(ctx) } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index a31267ab..bfd504a3 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,6 +113,15 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } +func TestAuthService_Register_DisabledByDefault(t *testing.T) { + // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) +} + func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} // 邮件验证开启但 emailCache 为 nil(emailService 未配置) @@ -155,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -163,7 +174,9 @@ func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -171,15 +184,30 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { + // 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败 + repo := &userRepoStub{createErr: ErrEmailExists} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrEmailExists) +} + func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index afd8907c..55e137d6 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "log" "math/big" "net/smtp" "strconv" @@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 if data.Code != code { data.Attempts++ - _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) + if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + log.Printf("[Email] Failed to update verification attempt count: %v", err) + } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } @@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error } // 验证成功,删除验证码 - _ = s.cache.DeleteVerificationCode(ctx, email) + if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { + log.Printf("[Email] Failed to delete verification code after success: %v", err) + } return nil } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 6ce8ba2b..965253cf 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -141,8 +141,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err != nil { - // 默认开放注册 - return true + // 安全默认:如果设置不存在或查询出错,默认关闭注册 + return false } return value == "true" }