From afcfbb458d2e2ce9585289ecb36b9790cd882afb Mon Sep 17 00:00:00 2001 From: IanShaw <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 8 Jan 2026 23:47:29 +0800 Subject: [PATCH 01/17] =?UTF-8?q?fix(gemini):=20Google=20One=20=E5=BC=BA?= =?UTF-8?q?=E5=88=B6=E4=BD=BF=E7=94=A8=E5=86=85=E7=BD=AE=20OAuth=20client?= =?UTF-8?q?=20+=20=E8=87=AA=E5=8A=A8=E8=8E=B7=E5=8F=96=20project=5Fid=20+?= =?UTF-8?q?=20UI=20=E4=BC=98=E5=8C=96=20(#212)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(gemini): Google One 强制使用内置 OAuth client + 自动获取 project_id + UI 优化 ## 后端改动 ### 1. Google One 强制使用内置 Gemini CLI OAuth Client **问题**: - Google One 之前允许使用自定义 OAuth client,导致认证流程不稳定 - 与 Code Assist 的行为不一致 **解决方案**: - 修改 `gemini_oauth_service.go`: Google One 现在与 Code Assist 一样强制使用内置 client (L122-135) - 更新 `gemini_oauth_client.go`: ExchangeCode 和 RefreshToken 方法支持强制内置 client (L31-44, L77-86) - 简化 `geminicli/oauth.go`: Google One scope 选择逻辑 (L187-190) - 标记 `geminicli/constants.go`: DefaultGoogleOneScopes 为 DEPRECATED (L30-33) - 更新测试用例以反映新行为 **OAuth 类型对比**: | OAuth类型 | Client来源 | Scopes | Redirect URI | |-----------|-----------|--------|-----------------| | code_assist | 内置 Gemini CLI | DefaultCodeAssistScopes | https://codeassist.google.com/authcode | | google_one | 内置 Gemini CLI (新) | DefaultCodeAssistScopes | https://codeassist.google.com/authcode | | ai_studio | 必须自定义 | DefaultAIStudioScopes | http://localhost:1455/auth/callback | ### 2. Google One 自动获取 project_id **问题**: - Google One 个人账号测试模型时返回 403/404 错误 - 原因:cloudaicompanion API 需要 project_id,但个人账号无需手动创建 GCP 项目 **解决方案**: - 修改 `gemini_oauth_service.go`: OAuth 流程中自动调用 fetchProjectID - Google 通过 LoadCodeAssist API 自动分配 project_id - 与 Gemini CLI 行为保持一致 - 后端根据 project_id 自动选择正确的 API 端点 **影响**: - Google One 账号现在可以正常使用(需要重新授权) - Code Assist 和 AI Studio 账号不受影响 ### 3. 修复 Gemini 测试账号无内容输出问题 **问题**: - 测试 Gemini 账号时只显示"测试成功",没有显示 AI 响应内容 - 原因:processGeminiStream 在检查到 finishReason 时立即返回,跳过了内容提取 **解决方案**: - 修改 `account_test_service.go`: 调整逻辑顺序,先提取内容再检查是否完成 - 确保最后一个 chunk 的内容也能被正确显示 **影响**: - 所有 Gemini 账号类型(API Key、OAuth)的测试现在都会显示完整响应内容 - 用户可以看到流式输出效果 ## 前端改动 ### 1. 修复图标宽度压缩问题 **问题**: - 账户类型选择按钮中的图标在某些情况下会被压缩变形 **解决方案**: - 修改 `CreateAccountModal.vue`: 为所有平台图标容器添加 `shrink-0` 类 - 确保 Anthropic、OpenAI、Gemini、Antigravity 图标保持固定 8×8 尺寸 (32px × 32px) ### 2. 优化重新授权界面 **问题**: - 重新授权时显示三个可点击的授权类型选择按钮,可能导致用户误切换到不兼容的授权方式 **解决方案**: - 修改 `ReAuthAccountModal.vue` (admin 和普通用户版本): - 将可点击的授权类型选择按钮改为只读信息展示框 - 根据账号的 `credentials.oauth_type` 动态显示对应图标和文本 - 删除 `geminiAIStudioOAuthEnabled` 状态和 `handleSelectGeminiOAuthType` 方法 - 防止用户误操作 ## 测试验证 - ✅ 所有后端单元测试通过 - ✅ OAuth client 选择逻辑正确 - ✅ Google One 和 Code Assist 行为一致 - ✅ 测试账号显示完整响应内容 - ✅ UI 图标显示正常 - ✅ 重新授权界面只读展示正确 * fix(lint): 修复 golangci-lint 错误信息格式问题 - 将错误信息改为小写开头以符合 Go 代码规范 - 修复 ST1005: error strings should not be capitalized --- backend/internal/pkg/geminicli/constants.go | 7 +- backend/internal/pkg/geminicli/oauth.go | 10 +- backend/internal/pkg/geminicli/oauth_test.go | 4 +- .../repository/gemini_oauth_client.go | 8 +- .../internal/service/account_test_service.go | 14 +- .../internal/service/gemini_oauth_service.go | 23 ++- .../service/gemini_oauth_service_test.go | 8 +- .../components/account/CreateAccountModal.vue | 20 +-- .../components/account/ReAuthAccountModal.vue | 165 +++++------------- .../admin/account/ReAuthAccountModal.vue | 163 +++++------------ 10 files changed, 135 insertions(+), 287 deletions(-) diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index 6d7e5a5d..d4d52116 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -27,10 +27,9 @@ const ( // https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform). DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever" - // DefaultScopes for Google One (personal Google accounts with Gemini access) - // Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client, - // Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client - // cannot request restricted scopes like generative-language.retriever or drive.readonly. + // DefaultGoogleOneScopes (DEPRECATED, no longer used) + // Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes. + // This constant is kept for backward compatibility but is not actively used. DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" // GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth. diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index 473017a2..c71e8aad 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error effective.Scopes = DefaultAIStudioScopes } case "google_one": - // Google One uses built-in Gemini CLI client (same as code_assist) - // Built-in client can't request restricted scopes like generative-language.retriever - if isBuiltinClient { - effective.Scopes = DefaultCodeAssistScopes - } else { - effective.Scopes = DefaultGoogleOneScopes - } + // Google One always uses built-in Gemini CLI client (same as code_assist) + // Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly + effective.Scopes = DefaultCodeAssistScopes default: // Default to Code Assist scopes effective.Scopes = DefaultCodeAssistScopes diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0520f0f2..0770730a 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with custom client", + name: "Google One always uses built-in client (even if custom credentials passed)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultGoogleOneScopes, + wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client wantErr: false, }, { diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 14ecfc89..8b7fe625 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client + // - google_one: always use built-in Gemini CLI OAuth client (public) // - ai_studio: requires a user-provided OAuth client oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } @@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 7121a13d..8419c2b4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidate, ok := candidates[0].(map[string]any); ok { - // Check for completion - if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - // Extract content + // Extract content first (before checking completion) if content, ok := candidate["content"].(map[string]any); ok { if parts, ok := content["parts"].([]any); ok { for _, part := range parts { @@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } } } + + // Check for completion after extracting content + if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } } } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 48d31da9..bc84baeb 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 } // OAuth client selection: - // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client. - // - ai_studio: requires a user-provided OAuth client. + // - code_assist: always use built-in Gemini CLI OAuth client (public) + // - google_one: always use built-in Gemini CLI OAuth client (public) + // - ai_studio: requires a user-provided OAuth client oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfg.ClientID = "" oauthCfg.ClientSecret = "" } @@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch case "google_one": log.Printf("[GeminiOAuth] Processing google_one OAuth type") + + // Google One accounts use cloudaicompanion API, which requires a project_id. + // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. + if projectID == "" { + log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + var err error + projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) + } + log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + } + log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index eb3d86e6..5591eb39 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { wantProjectID: "", }, { - name: "google_one uses custom client when configured and redirects to localhost", + name: "google_one always forces built-in client even when custom client configured", cfg: &config.Config{ Gemini: config.GeminiConfig{ OAuth: config.GeminiOAuthConfig{ @@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }, }, oauthType: "google_one", - wantClientID: "custom-client-id", - wantRedirect: geminicli.AIStudioOAuthRedirectURI, - wantScope: geminicli.DefaultGoogleOneScopes, + wantClientID: geminicli.GeminiCLIOAuthClientID, + wantRedirect: geminicli.GeminiCLIRedirectURI, + wantScope: geminicli.DefaultCodeAssistScopes, wantProjectID: "", }, { diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e90bec6c..5833632b 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -166,7 +166,7 @@ >
-
+
diff --git a/frontend/src/components/account/ReAuthAccountModal.vue b/frontend/src/components/account/ReAuthAccountModal.vue index 43d1198f..b2734b4f 100644 --- a/frontend/src/components/account/ReAuthAccountModal.vue +++ b/frontend/src/components/account/ReAuthAccountModal.vue @@ -73,113 +73,48 @@
- -
- {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }} -
- - - - - + +
+
+ {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
-
+
+
+ + + +
+
+ + {{ + geminiOAuthType === 'google_one' + ? 'Google One' + : geminiOAuthType === 'code_assist' + ? t('admin.accounts.gemini.oauthType.builtInTitle') + : t('admin.accounts.gemini.oauthType.customTitle') + }} + + + {{ + geminiOAuthType === 'google_one' + ? '个人账号' + : geminiOAuthType === 'code_assist' + ? t('admin.accounts.gemini.oauthType.builtInDesc') + : t('admin.accounts.gemini.oauthType.customDesc') + }} + +
+
+
(null) // State const addMethod = ref('oauth') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist') -const geminiAIStudioOAuthEnabled = ref(false) // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') @@ -367,14 +301,6 @@ watch( ? 'ai_studio' : 'code_assist' } - if (isGemini.value) { - geminiOAuth.getCapabilities().then((caps) => { - geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled - if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') { - geminiOAuthType.value = 'code_assist' - } - }) - } } else { resetState() } @@ -385,7 +311,6 @@ watch( const resetState = () => { addMethod.value = 'oauth' geminiOAuthType.value = 'code_assist' - geminiAIStudioOAuthEnabled.value = false claudeOAuth.resetState() openaiOAuth.resetState() geminiOAuth.resetState() @@ -393,14 +318,6 @@ const resetState = () => { oauthFlowRef.value?.reset() } -const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => { - if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) { - appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured')) - return - } - geminiOAuthType.value = oauthType -} - const handleClose = () => { emit('close') } diff --git a/frontend/src/components/admin/account/ReAuthAccountModal.vue b/frontend/src/components/admin/account/ReAuthAccountModal.vue index d9838a2e..8133e029 100644 --- a/frontend/src/components/admin/account/ReAuthAccountModal.vue +++ b/frontend/src/components/admin/account/ReAuthAccountModal.vue @@ -73,111 +73,48 @@
- -
- {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }} -
- - - - - + +
+
+ {{ t('admin.accounts.oauth.gemini.oauthTypeLabel') }}
-
+
+
+ + + +
+
+ + {{ + geminiOAuthType === 'google_one' + ? 'Google One' + : geminiOAuthType === 'code_assist' + ? t('admin.accounts.gemini.oauthType.builtInTitle') + : t('admin.accounts.gemini.oauthType.customTitle') + }} + + + {{ + geminiOAuthType === 'google_one' + ? '个人账号' + : geminiOAuthType === 'code_assist' + ? t('admin.accounts.gemini.oauthType.builtInDesc') + : t('admin.accounts.gemini.oauthType.customDesc') + }} + +
+
+
(null) // State const addMethod = ref('oauth') const geminiOAuthType = ref<'code_assist' | 'google_one' | 'ai_studio'>('code_assist') -const geminiAIStudioOAuthEnabled = ref(false) // Computed - check platform const isOpenAI = computed(() => props.account?.platform === 'openai') @@ -365,14 +301,6 @@ watch( ? 'ai_studio' : 'code_assist' } - if (isGemini.value) { - geminiOAuth.getCapabilities().then((caps) => { - geminiAIStudioOAuthEnabled.value = !!caps?.ai_studio_oauth_enabled - if (!geminiAIStudioOAuthEnabled.value && geminiOAuthType.value === 'ai_studio') { - geminiOAuthType.value = 'code_assist' - } - }) - } } else { resetState() } @@ -383,7 +311,6 @@ watch( const resetState = () => { addMethod.value = 'oauth' geminiOAuthType.value = 'code_assist' - geminiAIStudioOAuthEnabled.value = false claudeOAuth.resetState() openaiOAuth.resetState() geminiOAuth.resetState() @@ -391,14 +318,6 @@ const resetState = () => { oauthFlowRef.value?.reset() } -const handleSelectGeminiOAuthType = (oauthType: 'code_assist' | 'google_one' | 'ai_studio') => { - if (oauthType === 'ai_studio' && !geminiAIStudioOAuthEnabled.value) { - appStore.showError(t('admin.accounts.oauth.gemini.aiStudioNotConfigured')) - return - } - geminiOAuthType.value = oauthType -} - const handleClose = () => { emit('close') } From 5b8d4fb0479eeb14356066e17576de928c466771 Mon Sep 17 00:00:00 2001 From: cyhhao Date: Fri, 9 Jan 2026 00:34:49 +0800 Subject: [PATCH 02/17] feat(openai): add AI SDK content format compatibility for OAuth accounts - Add normalizeInputForCodexAPI function to convert AI SDK multi-part content format to simplified format expected by ChatGPT Codex API - AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} - Codex API expects: {"content": "..."} - Only applies to OAuth accounts (ChatGPT internal API) - API Key accounts remain unchanged (OpenAI Platform API supports both) --- .../service/openai_gateway_service.go | 106 +++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d744bfab..f26404c8 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // For OAuth accounts using ChatGPT internal API, add store: false + // For OAuth accounts using ChatGPT internal API: + // 1. Add store: false + // 2. Normalize input format for Codex API compatibility if account.Type == AccountTypeOAuth { reqBody["store"] = false bodyModified = true + + // Normalize input format: convert AI SDK multi-part content format to simplified format + // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} + // Codex API expects: {"content": "..."} + if normalizeInputForCodexAPI(reqBody) { + bodyModified = true + } } // Re-serialize body only if modified @@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } +// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format +// that the ChatGPT internal Codex API expects. +// +// AI SDK sends content as an array of typed objects: +// +// {"content": [{"type": "input_text", "text": "hello"}]} +// +// ChatGPT Codex API expects content as a simple string: +// +// {"content": "hello"} +// +// This function modifies reqBody in-place and returns true if any modification was made. +func normalizeInputForCodexAPI(reqBody map[string]any) bool { + input, ok := reqBody["input"] + if !ok { + return false + } + + // Handle case where input is a simple string (already compatible) + if _, isString := input.(string); isString { + return false + } + + // Handle case where input is an array of messages + inputArray, ok := input.([]any) + if !ok { + return false + } + + modified := false + for _, item := range inputArray { + message, ok := item.(map[string]any) + if !ok { + continue + } + + content, ok := message["content"] + if !ok { + continue + } + + // If content is already a string, no conversion needed + if _, isString := content.(string); isString { + continue + } + + // If content is an array (AI SDK format), convert to string + contentArray, ok := content.([]any) + if !ok { + continue + } + + // Extract text from content array + var textParts []string + for _, part := range contentArray { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + + // Handle different content types + partType, _ := partMap["type"].(string) + switch partType { + case "input_text", "text": + // Extract text from input_text or text type + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + case "input_image", "image": + // For images, we need to preserve the original format + // as ChatGPT Codex API may support images in a different way + // For now, skip image parts (they will be lost in conversion) + // TODO: Consider preserving image data or handling it separately + continue + case "input_file", "file": + // Similar to images, file inputs may need special handling + continue + default: + // For unknown types, try to extract text if available + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + } + } + + // Convert content array to string + if len(textParts) > 0 { + message["content"] = strings.Join(textParts, "\n") + modified = true + } + } + + return modified +} + // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult From 2d83941aaa5dd397d1119352c72a01c326c8c460 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 9 Jan 2026 10:36:56 +0800 Subject: [PATCH 03/17] =?UTF-8?q?feat(antigravity):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=20URL=20fallback=20=E6=9C=BA=E5=88=B6=20(sandbox=20=E2=86=92?= =?UTF-8?q?=20daily=20=E2=86=92=20prod)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/pkg/antigravity/client.go | 221 +++++++---- backend/internal/pkg/antigravity/oauth.go | 70 +++- .../service/antigravity_gateway_service.go | 348 +++++++++++------- 3 files changed, 446 insertions(+), 193 deletions(-) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 8ff75f57..1248be95 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -5,8 +5,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "log" + "net" "net/http" "net/url" "strings" @@ -22,10 +25,10 @@ func resolveHost(urlStr string) string { return parsed.Host } -// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) -func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { +// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) +func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 - apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) + apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action) isStream := action == "streamGenerateContent" if isStream { apiURL += "?alt=sse" @@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) req.Host = host } - // 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header - return req, nil } +// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点) +// 向后兼容:仅使用默认 BaseURL +func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { + return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body) +} + // TokenResponse Google OAuth token 响应 type TokenResponse struct { AccessToken string `json:"access_token"` @@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client { } } +// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // 检查 URL 错误 + var urlErr *url.Error + return errors.As(err, &urlErr) +} + +// shouldFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldFallbackToNextURL(err error, statusCode int) bool { + if isConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { params := url.Values{} @@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo } // LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" @@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - url := BaseURL + "/v1internal:loadCodeAssist" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:loadCodeAssist" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var loadResp LoadCodeAssistResponse + if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &loadResp, rawResp, nil } - var loadResp LoadCodeAssistResponse - if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &loadResp, rawResp, nil + return nil, nil, lastErr } // ModelQuotaInfo 模型配额信息 @@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct { } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) @@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - apiURL := BaseURL + "/v1internal:fetchAvailableModels" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:fetchAvailableModels" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var modelsResp FetchAvailableModelsResponse + if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &modelsResp, rawResp, nil } - var modelsResp FetchAvailableModelsResponse - if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &modelsResp, rawResp, nil + return nil, nil, lastErr } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index e88c203b..736c45df 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,17 +32,79 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // API 端点 - // 优先使用 sandbox daily URL,配额更宽松 - BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" - // User-Agent(模拟官方客户端) UserAgent = "antigravity/1.104.0 darwin/arm64" // Session 过期时间 SessionTTL = 30 * time.Minute + + // URL 可用性 TTL(不可用 URL 的恢复时间) + URLAvailabilityTTL = 5 * time.Minute ) +// BaseURLs 定义 Antigravity API 端点,按优先级排序 +// fallback 顺序: sandbox → daily → prod +var BaseURLs = []string{ + "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox + "https://daily-cloudcode-pa.googleapis.com", // daily + "https://cloudcode-pa.googleapis.com", // prod +} + +// BaseURL 默认 URL(保持向后兼容) +var BaseURL = BaseURLs[0] + +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +type URLAvailability struct { + mu sync.RWMutex + unavailable map[string]time.Time // URL -> 恢复时间 + ttl time.Duration +} + +// DefaultURLAvailability 全局 URL 可用性管理器 +var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL) + +// NewURLAvailability 创建 URL 可用性管理器 +func NewURLAvailability(ttl time.Duration) *URLAvailability { + return &URLAvailability{ + unavailable: make(map[string]time.Time), + ttl: ttl, + } +} + +// MarkUnavailable 标记 URL 临时不可用 +func (u *URLAvailability) MarkUnavailable(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.unavailable[url] = time.Now().Add(u.ttl) +} + +// IsAvailable 检查 URL 是否可用 +func (u *URLAvailability) IsAvailable(url string) bool { + u.mu.RLock() + defer u.mu.RUnlock() + expiry, exists := u.unavailable[url] + if !exists { + return true + } + return time.Now().After(expiry) +} + +// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +func (u *URLAvailability) GetAvailableURLs() []string { + u.mu.RLock() + defer u.mu.RUnlock() + + now := time.Now() + result := make([]string, 0, len(BaseURLs)) + for _, url := range BaseURLs { + expiry, exists := u.unavailable[url] + if !exists || now.After(expiry) { + result = append(result, url) + } + } + return result +} + // OAuthSession 保存 OAuth 授权流程的临时状态 type OAuthSession struct { State string `json:"state"` diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2fe77b2d..573017cd 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -10,6 +10,7 @@ import ( "io" "log" mathrand "math/rand" + "net" "net/http" "strings" "sync/atomic" @@ -27,6 +28,32 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) +// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isAntigravityConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + return errors.As(err, &opErr) +} + +// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool { + if isAntigravityConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // getSessionID 从 gin.Context 获取 session_id(用于日志追踪) func getSessionID(c *gin.Context) string { if c == nil { @@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("构建请求失败: %w", err) } - // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) - req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody) - if err != nil { - return nil, err - } - - // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) - // 代理 URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - // 发送请求 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // 读取响应 - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) + req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody) + if err != nil { + lastErr = err + continue + } + + // 调试日志:Test 请求信息 + log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = fmt.Errorf("请求失败: %w", err) + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, lastErr + } + + // 读取响应 + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 解析流式响应,提取文本 + text := extractTextFromSSEResponse(respBody) + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil } - // 解析流式响应,提取文本 - text := extractTextFromSSEResponse(respBody) - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil + return nil, lastErr } // buildGeminiTestRequest 构建 Gemini 格式测试请求 @@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + // 最后一次尝试也失败 + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { _ = resp.Body.Close() }() @@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { if resp != nil && resp.Body != nil { From 799b01063119db53aec5e0a283df176ff093cf69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E7=8C=BFMT?= <32916545+mt21625457@users.noreply.github.com> Date: Fri, 9 Jan 2026 10:37:15 +0800 Subject: [PATCH 04/17] =?UTF-8?q?fix(auth):=20=E4=BF=AE=E5=A4=8D=20Refresh?= =?UTF-8?q?Token=20=E4=BD=BF=E7=94=A8=E8=BF=87=E6=9C=9F=20token=20?= =?UTF-8?q?=E6=97=B6=E7=9A=84=20nil=20pointer=20panic=20(#214)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(auth): 修复 RefreshToken 使用过期 token 时的 nil pointer panic 问题分析: - RefreshToken 允许过期 token 继续流程(用于无感刷新) - 但 ValidateToken 在 token 过期时返回 nil claims - 导致后续访问 claims.UserID 时触发 panic 修复方案: - 修改 ValidateToken,在检测到 ErrTokenExpired 时仍然返回 claims - jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 - 这样 RefreshToken 可以正常获取用户信息并生成新 token 新增测试: - TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError - TestAuthService_RefreshToken_ExpiredTokenNoPanic Co-Authored-By: Claude Opus 4.5 * fix(auth): 修复邮件验证服务未配置时可绕过验证的安全漏洞 当邮件验证开启但 emailService 未配置时,原逻辑允许用户绕过验证直接注册。 现在会返回 ErrServiceUnavailable 拒绝注册,确保配置错误不会导致安全问题。 - 在验证码检查前先检查 emailService 是否配置 - 添加日志记录帮助发现配置问题 - 新增单元测试覆盖该场景 Co-Authored-By: Claude Opus 4.5 --------- Co-authored-by: yangjianbo Co-authored-by: Claude Opus 4.5 --- backend/internal/service/auth_service.go | 17 ++++- .../service/auth_service_register_test.go | 76 ++++++++++++++++++- 2 files changed, 88 insertions(+), 5 deletions(-) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 85772e75..5a5ca03d 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -82,14 +82,18 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { + // 如果邮件验证已开启但邮件服务未配置,拒绝注册 + // 这是一个配置错误,不应该允许绕过验证 + if s.emailService == nil { + log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + return "", nil, ErrServiceUnavailable + } if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 - if s.emailService != nil { - if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { - return "", nil, fmt.Errorf("verify code: %w", err) - } + if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { + return "", nil, fmt.Errorf("verify code: %w", err) } } @@ -336,6 +340,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { + // token 过期但仍返回 claims(用于 RefreshToken 等场景) + // jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 + if claims, ok := token.Claims.(*JWTClaims); ok { + return claims, ErrTokenExpired + } return nil, ErrTokenExpired } return nil, ErrInvalidToken diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index cd6e2808..a31267ab 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,13 +113,27 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } -func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { +func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} + // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", }, nil) + // 应返回服务不可用错误,而不是允许绕过验证 + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") + require.ErrorIs(t, err, ErrServiceUnavailable) +} + +func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { + repo := &userRepoStub{} + cache := &emailCacheStub{} // 配置 emailService + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, cache) + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -180,3 +194,63 @@ func TestAuthService_Register_Success(t *testing.T) { require.Len(t, repo.created, 1) require.True(t, user.CheckPassword("password")) } + +func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + // 创建用户并生成 token + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + token, err := service.GenerateToken(user) + require.NoError(t, err) + + // 验证有效 token + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, int64(1), claims.UserID) + + // 模拟过期 token(通过创建一个过期很久的 token) + service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 // 恢复 + + // 验证过期 token 应返回 claims 和 ErrTokenExpired + claims, err = service.ValidateToken(expiredToken) + require.ErrorIs(t, err, ErrTokenExpired) + require.NotNil(t, claims, "claims should not be nil when token is expired") + require.Equal(t, int64(1), claims.UserID) + require.Equal(t, "test@test.com", claims.Email) +} + +func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + repo := &userRepoStub{user: user} + service := newAuthService(repo, nil, nil) + + // 创建过期 token + service.cfg.JWT.ExpireHour = -1 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 + + // RefreshToken 使用过期 token 不应 panic + require.NotPanics(t, func() { + newToken, err := service.RefreshToken(context.Background(), expiredToken) + require.NoError(t, err) + require.NotEmpty(t, newToken) + }) +} From 43f104bdf728e23a7babbb48642f20d63c95afc1 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 9 Jan 2026 14:49:20 +0800 Subject: [PATCH 05/17] =?UTF-8?q?fix(auth):=20=E6=B3=A8=E5=86=8C=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E5=AE=89=E5=85=A8=E5=8A=A0=E5=9B=BA=20-=20=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=85=B3=E9=97=AD=E6=B3=A8=E5=86=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/internal/service/auth_service.go | 18 ++++++---- .../service/auth_service_register_test.go | 36 ++++++++++++++++--- backend/internal/service/email_service.go | 9 +++-- backend/internal/service/setting_service.go | 4 +-- 4 files changed, 52 insertions(+), 15 deletions(-) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 5a5ca03d..6e685869 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -75,8 +75,8 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } @@ -132,6 +132,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } if err := s.userRepo.Create(ctx, user); err != nil { + // 优先检查邮箱冲突错误(竞态条件下可能发生) + if errors.Is(err, ErrEmailExists) { + return "", nil, ErrEmailExists + } log.Printf("[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } @@ -152,8 +156,8 @@ type SendVerifyCodeResult struct { // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } @@ -185,8 +189,8 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } @@ -270,7 +274,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { - return true + return false // 安全默认:settingService 未配置时关闭注册 } return s.settingService.IsRegistrationEnabled(ctx) } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index a31267ab..bfd504a3 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,6 +113,15 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } +func TestAuthService_Register_DisabledByDefault(t *testing.T) { + // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) +} + func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { repo := &userRepoStub{} // 邮件验证开启但 emailCache 为 nil(emailService 未配置) @@ -155,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -163,7 +174,9 @@ func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) @@ -171,15 +184,30 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { + // 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败 + repo := &userRepoStub{createErr: ErrEmailExists} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrEmailExists) +} + func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index afd8907c..55e137d6 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "log" "math/big" "net/smtp" "strconv" @@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 if data.Code != code { data.Attempts++ - _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) + if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + log.Printf("[Email] Failed to update verification attempt count: %v", err) + } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } @@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error } // 验证成功,删除验证码 - _ = s.cache.DeleteVerificationCode(ctx, email) + if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { + log.Printf("[Email] Failed to delete verification code after success: %v", err) + } return nil } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 6ce8ba2b..965253cf 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -141,8 +141,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err != nil { - // 默认开放注册 - return true + // 安全默认:如果设置不存在或查询出错,默认关闭注册 + return false } return value == "true" } From 5d1badfe67908443314f66608152ce42fac5a2d6 Mon Sep 17 00:00:00 2001 From: Xu Kang <7836246@qq.com> Date: Fri, 9 Jan 2026 15:28:55 +0800 Subject: [PATCH 06/17] fix: add missing i18n key admin.accounts.outputCopied (#218) --- frontend/src/i18n/locales/en.ts | 1 + frontend/src/i18n/locales/zh.ts | 1 + 2 files changed, 2 insertions(+) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index c4cf6cc6..2732d84d 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1486,6 +1486,7 @@ export default { testing: 'Testing...', retry: 'Retry', copyOutput: 'Copy output', + outputCopied: 'Output copied', startingTestForAccount: 'Starting test for account: {name}', testAccountTypeLabel: 'Account type: {type}', selectTestModel: 'Select Test Model', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 79ddf6cc..40aa39ab 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1601,6 +1601,7 @@ export default { startTest: '开始测试', retry: '重试', copyOutput: '复制输出', + outputCopied: '输出已复制', startingTestForAccount: '开始测试账号:{name}', testAccountTypeLabel: '账号类型:{type}', selectTestModel: '选择测试模型', From 27291f2e5f69a7d3775c8a890f5f4ca7ce2c2272 Mon Sep 17 00:00:00 2001 From: shaw Date: Fri, 9 Jan 2026 17:28:31 +0800 Subject: [PATCH 07/17] =?UTF-8?q?fix(docker):=20=E4=BF=AE=E6=94=B9=20Redis?= =?UTF-8?q?=20=E9=85=8D=E7=BD=AE=E4=BB=A5=E6=94=AF=E6=8C=81=E5=8F=AF?= =?UTF-8?q?=E9=80=89=E7=9A=84=E5=AF=86=E7=A0=81=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- deploy/docker-compose.yml | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 6a370e9a..484df3a8 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -173,11 +173,12 @@ services: volumes: - redis_data:/data command: > - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' environment: - TZ=${TZ:-Asia/Shanghai} # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) From ee9b9b397136dd27c11f289bfa55ac71e80ec3a4 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:38:45 +0800 Subject: [PATCH 08/17] =?UTF-8?q?fix(fe):=20=E4=BF=AE=E5=A4=8D=E8=A1=A8?= =?UTF-8?q?=E6=A0=BC=E5=88=86=E9=A1=B5=E5=92=8C=E5=9F=BA=E7=A1=80=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复的主要问题: 1. **分页切换失效**(AccountsView.vue) - 修复 useTableLoader 未解构 handlePageSizeChange 函数 - 添加 @update:pageSize 事件绑定到 Pagination 组件 2. **内存泄漏修复**(多个文件) - UsersView.vue: 添加 searchTimeout 清理和 abortController.abort() - ProxiesView.vue: 添加 onUnmounted 钩子清理定时器 - RedeemView.vue: 添加 onUnmounted 钩子清理定时器 3. **分页重置问题**(UsersView.vue) - toggleBuiltInFilter: 切换筛选器时重置 pagination.page = 1 - toggleAttributeFilter: 切换属性筛选时重置 pagination.page = 1 影响范围: - frontend/src/views/admin/AccountsView.vue - frontend/src/views/admin/ProxiesView.vue - frontend/src/views/admin/RedeemView.vue - frontend/src/views/admin/UsersView.vue 测试:✓ 前端构建测试通过 --- frontend/src/views/admin/AccountsView.vue | 4 ++-- frontend/src/views/admin/ProxiesView.vue | 7 ++++++- frontend/src/views/admin/RedeemView.vue | 7 ++++++- frontend/src/views/admin/UsersView.vue | 4 ++++ 4 files changed, 18 insertions(+), 4 deletions(-) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 0ca22a76..560ae12b 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -107,7 +107,7 @@ - + @@ -175,7 +175,7 @@ const statsAcc = ref(null) const togglingSchedulable = ref(null) const menu = reactive<{show:boolean, acc:Account|null, pos:{top:number, left:number}|null}>({ show: false, acc: null, pos: null }) -const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange } = useTableLoader({ +const { items: accounts, loading, params, pagination, load, reload, debouncedReload, handlePageChange, handlePageSizeChange } = useTableLoader({ fetchFn: adminAPI.accounts.list, initialParams: { platform: '', type: '', status: '', search: '' } }) diff --git a/frontend/src/views/admin/ProxiesView.vue b/frontend/src/views/admin/ProxiesView.vue index 00b43ba2..22d778d9 100644 --- a/frontend/src/views/admin/ProxiesView.vue +++ b/frontend/src/views/admin/ProxiesView.vue @@ -519,7 +519,7 @@ diff --git a/frontend/src/views/admin/RedeemView.vue b/frontend/src/views/admin/RedeemView.vue index 3503cd16..50c55ba3 100644 --- a/frontend/src/views/admin/RedeemView.vue +++ b/frontend/src/views/admin/RedeemView.vue @@ -364,7 +364,7 @@ diff --git a/frontend/src/views/admin/UsersView.vue b/frontend/src/views/admin/UsersView.vue index d4bf555c..f1fc9e1f 100644 --- a/frontend/src/views/admin/UsersView.vue +++ b/frontend/src/views/admin/UsersView.vue @@ -943,6 +943,7 @@ const toggleBuiltInFilter = (key: string) => { visibleFilters.add(key) } saveFiltersToStorage() + pagination.page = 1 loadUsers() } @@ -957,6 +958,7 @@ const toggleAttributeFilter = (attr: UserAttributeDefinition) => { activeAttributeFilters[attr.id] = '' } saveFiltersToStorage() + pagination.page = 1 loadUsers() } @@ -1059,5 +1061,7 @@ onMounted(async () => { onUnmounted(() => { document.removeEventListener('click', handleClickOutside) + clearTimeout(searchTimeout) + abortController?.abort() }) From 514f5802b54b18aa0211d5d3ba713bd5a3febb6e Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Fri, 9 Jan 2026 17:58:21 +0800 Subject: [PATCH 09/17] =?UTF-8?q?fix(fe):=20=E4=BF=AE=E5=A4=8D=E4=B8=AD?= =?UTF-8?q?=E4=BC=98=E5=85=88=E7=BA=A7=E8=A1=A8=E6=A0=BC=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复的问题: 1. **搜索和筛选防抖不同步**(AccountsView.vue) - 问题:筛选器使用 reload(立即),搜索使用 debouncedReload(300ms延迟) - 修复:统一使用 debouncedReload,避免多余的API调用 2. **useTableLoader 竞态条件**(useTableLoader.ts) - 问题:finally 块检查 signal.aborted 而不是 controller 实例 - 修复:检查 abortController === currentController 3. **改进错误处理**(UsersView.vue) - 添加详细错误消息:error.response?.data?.detail || error.message - 用户可以看到具体的错误原因而不是通用消息 4. **分页边界检查**(useTableLoader.ts, UsersView.vue) - 添加页码有效性检查:Math.max(1, Math.min(page, pagination.pages || 1)) - 防止分页越界导致显示空表 影响范围: - frontend/src/composables/useTableLoader.ts - frontend/src/views/admin/AccountsView.vue - frontend/src/views/admin/UsersView.vue 测试:✓ 前端构建测试通过 --- frontend/src/composables/useTableLoader.ts | 13 ++++++++----- frontend/src/views/admin/AccountsView.vue | 2 +- frontend/src/views/admin/UsersView.vue | 9 ++++++--- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/frontend/src/composables/useTableLoader.ts b/frontend/src/composables/useTableLoader.ts index 01703ee1..5fb6c5e0 100644 --- a/frontend/src/composables/useTableLoader.ts +++ b/frontend/src/composables/useTableLoader.ts @@ -43,7 +43,8 @@ export function useTableLoader>(options: TableL if (abortController) { abortController.abort() } - abortController = new AbortController() + const currentController = new AbortController() + abortController = currentController loading.value = true try { @@ -51,9 +52,9 @@ export function useTableLoader>(options: TableL pagination.page, pagination.page_size, toRaw(params) as P, - { signal: abortController.signal } + { signal: currentController.signal } ) - + items.value = response.items || [] pagination.total = response.total || 0 pagination.pages = response.pages || 0 @@ -63,7 +64,7 @@ export function useTableLoader>(options: TableL throw error } } finally { - if (abortController && !abortController.signal.aborted) { + if (abortController === currentController) { loading.value = false } } @@ -77,7 +78,9 @@ export function useTableLoader>(options: TableL const debouncedReload = useDebounceFn(reload, debounceMs) const handlePageChange = (page: number) => { - pagination.page = page + // 确保页码在有效范围内 + const validPage = Math.max(1, Math.min(page, pagination.pages || 1)) + pagination.page = validPage load() } diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index 560ae12b..27cd9c19 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -7,7 +7,7 @@ v-model:searchQuery="params.search" :filters="params" @update:filters="(newFilters) => Object.assign(params, newFilters)" - @change="reload" + @change="debouncedReload" @update:searchQuery="debouncedReload" /> { } } } - } catch (error) { + } catch (error: any) { const errorInfo = error as { name?: string; code?: string } if (errorInfo?.name === 'AbortError' || errorInfo?.name === 'CanceledError' || errorInfo?.code === 'ERR_CANCELED') { return } - appStore.showError(t('admin.users.failedToLoad')) + const message = error.response?.data?.detail || error.message || t('admin.users.failedToLoad') + appStore.showError(message) console.error('Error loading users:', error) } finally { if (abortController === currentAbortController) { @@ -917,7 +918,9 @@ const handleSearch = () => { } const handlePageChange = (page: number) => { - pagination.page = page + // 确保页码在有效范围内 + const validPage = Math.max(1, Math.min(page, pagination.pages || 1)) + pagination.page = validPage loadUsers() } From 152d0cdec65d81cf9fcf00e060e77b468044de9d Mon Sep 17 00:00:00 2001 From: admin Date: Fri, 9 Jan 2026 12:05:25 +0800 Subject: [PATCH 10/17] =?UTF-8?q?feat(auth):=20=E6=B7=BB=E5=8A=A0=20Linux?= =?UTF-8?q?=20DO=20Connect=20OAuth=20=E7=99=BB=E5=BD=95=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Linux DO OAuth 配置项和环境变量支持 - 实现 OAuth 授权流程和回调处理 - 前端添加 Linux DO 登录按钮和回调页面 - 支持通过 Linux DO 账号注册/登录 - 添加相关国际化文本 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- Linux DO Connect.md | 368 +++++++++++++ backend/cmd/server/VERSION | 2 +- backend/internal/config/config.go | 215 +++++++- backend/internal/config/config_test.go | 51 ++ .../internal/handler/auth_linuxdo_oauth.go | 517 ++++++++++++++++++ .../handler/auth_linuxdo_oauth_test.go | 74 +++ backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/setting_handler.go | 1 + backend/internal/server/routes/auth.go | 2 + backend/internal/service/auth_service.go | 131 +++++ .../service/auth_service_register_test.go | 8 + backend/internal/service/setting_service.go | 1 + backend/internal/service/settings_view.go | 1 + deploy/config.example.yaml | 25 + frontend/src/i18n/locales/en.ts | 9 + frontend/src/i18n/locales/zh.ts | 9 + frontend/src/router/index.ts | 9 + frontend/src/stores/app.ts | 35 +- frontend/src/stores/auth.ts | 22 + frontend/src/types/index.ts | 1 + .../src/views/auth/LinuxDoCallbackView.vue | 119 ++++ frontend/src/views/auth/LoginView.vue | 55 ++ frontend/src/views/auth/RegisterView.vue | 55 ++ 23 files changed, 1675 insertions(+), 36 deletions(-) create mode 100644 Linux DO Connect.md create mode 100644 backend/internal/handler/auth_linuxdo_oauth.go create mode 100644 backend/internal/handler/auth_linuxdo_oauth_test.go create mode 100644 frontend/src/views/auth/LinuxDoCallbackView.vue diff --git a/Linux DO Connect.md b/Linux DO Connect.md new file mode 100644 index 00000000..7ca1260f --- /dev/null +++ b/Linux DO Connect.md @@ -0,0 +1,368 @@ +# Linux DO Connect + +OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 + +目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 + +## 基本介绍 + +这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 + +- 可获取字段: + +| 参数 | 说明 | +| ----------------- | ------------------------------- | +| `id` | 用户唯一标识(不可变) | +| `username` | 论坛用户名 | +| `name` | 论坛用户昵称(可变) | +| `avatar_template` | 用户头像模板URL(支持多种尺寸) | +| `active` | 账号活跃状态 | +| `trust_level` | 信任等级(0-4) | +| `silenced` | 禁言状态 | +| `external_ids` | 外部ID关联信息 | +| `api_key` | API访问密钥 | + +通过这些信息,公益网站/接口可以实现: + +1. 基于 `id` 的服务频率限制 +2. 基于 `trust_level` 的服务额度分配 +3. 基于用户信息的滥用举报机制 + +## 相关端点 + +- Authorize 端点: `https://connect.linux.do/oauth2/authorize` +- Token 端点:`https://connect.linux.do/oauth2/token` +- 用户信息 端点:`https://connect.linux.do/api/user` + +## 申请使用 + +- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 + +![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75) + +- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 + +![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75) + +- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 + +![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75) + +## 接入 Linux Do + +JavaScript +```JavaScript +// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios +// npm install axios + +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +const axios = require('axios'); +const readline = require('readline'); + +// 配置信息(建议通过环境变量配置,避免使用硬编码) +const CLIENT_ID = '你的 Client ID'; +const CLIENT_SECRET = '你的 Client Secret'; +const REDIRECT_URI = '你的回调地址'; +const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +const USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 第一步:生成授权 URL +function getAuthUrl() { + const params = new URLSearchParams({ + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + response_type: 'code', + scope: 'user' + }); + + return `${AUTH_URL}?${params.toString()}`; +} + +// 第二步:获取 code 参数 +function getCode() { + return new Promise((resolve) => { + // 本例中使用终端输入来模拟流程,仅供本地测试 + // 请在实际应用中替换为真实的处理逻辑 + const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); + rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { + rl.close(); + resolve(answer.trim()); + }); + }); +} + +// 第三步:使用 code 参数获取访问令牌 +async function getAccessToken(code) { + try { + const form = new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: code, + redirect_uri: REDIRECT_URI, + grant_type: 'authorization_code' + }).toString(); + + const response = await axios.post(TOKEN_URL, form, { + // 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + }); + + return response.data; + } catch (error) { + console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 第四步:使用访问令牌获取用户信息 +async function getUserInfo(accessToken) { + try { + const response = await axios.get(USER_INFO_URL, { + headers: { + Authorization: `Bearer ${accessToken}` + } + }); + + return response.data; + } catch (error) { + console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 主流程 +async function main() { + // 1. 生成授权 URL,前端引导用户访问授权页 + const authUrl = getAuthUrl(); + console.log(`请访问此 URL 授权:${authUrl} +`); + + // 2. 用户授权后,从回调 URL 获取 code 参数 + const code = await getCode(); + + try { + // 3. 使用 code 参数获取访问令牌 + const tokenData = await getAccessToken(code); + const accessToken = tokenData.access_token; + + // 4. 使用访问令牌获取用户信息 + if (accessToken) { + const userInfo = await getUserInfo(accessToken); + console.log(` +获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); + } else { + console.log(` +获取访问令牌失败:${JSON.stringify(tokenData)}`); + } + } catch (error) { + console.error('发生错误:', error); + } +} +``` +Python +```python +# 安装第三方请求库,本例中使用 requests +# pip install requests + +# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +import requests +import json + +# 配置信息(建议通过环境变量配置,避免使用硬编码) +CLIENT_ID = '你的 Client ID' +CLIENT_SECRET = '你的 Client Secret' +REDIRECT_URI = '你的回调地址' +AUTH_URL = 'https://connect.linux.do/oauth2/authorize' +TOKEN_URL = 'https://connect.linux.do/oauth2/token' +USER_INFO_URL = 'https://connect.linux.do/api/user' + +# 第一步:生成授权 URL +def get_auth_url(): + params = { + 'client_id': CLIENT_ID, + 'redirect_uri': REDIRECT_URI, + 'response_type': 'code', + 'scope': 'user' + } + auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" + return auth_url + +# 第二步:获取 code 参数 +def get_code(): + # 本例中使用终端输入来模拟流程,仅供本地测试 + # 请在实际应用中替换为真实的处理逻辑 + return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() + +# 第三步:使用 code 参数获取访问令牌 +def get_access_token(code): + try: + data = { + 'client_id': CLIENT_ID, + 'client_secret': CLIENT_SECRET, + 'code': code, + 'redirect_uri': REDIRECT_URI, + 'grant_type': 'authorization_code' + } + # 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + response = requests.post(TOKEN_URL, data=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取访问令牌失败:{e}") + return None + +# 第四步:使用访问令牌获取用户信息 +def get_user_info(access_token): + try: + headers = { + 'Authorization': f'Bearer {access_token}' + } + response = requests.get(USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取用户信息失败:{e}") + return None + +# 主流程 +if __name__ == '__main__': + # 1. 生成授权 URL,前端引导用户访问授权页 + auth_url = get_auth_url() + print(f'请访问此 URL 授权:{auth_url} +') + + # 2. 用户授权后,从回调 URL 获取 code 参数 + code = get_code() + + # 3. 使用 code 参数获取访问令牌 + token_data = get_access_token(code) + if token_data: + access_token = token_data.get('access_token') + + # 4. 使用访问令牌获取用户信息 + if access_token: + user_info = get_user_info(access_token) + if user_info: + print(f" +获取用户信息成功:{json.dumps(user_info, indent=2)}") + else: + print(" +获取用户信息失败") + else: + print(f" +获取访问令牌失败:{json.dumps(token_data, indent=2)}") + else: + print(" +获取访问令牌失败") +``` +PHP +```php +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 + +// 配置信息 +$CLIENT_ID = '你的 Client ID'; +$CLIENT_SECRET = '你的 Client Secret'; +$REDIRECT_URI = '你的回调地址'; +$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +$USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 生成授权 URL +function getAuthUrl($clientId, $redirectUri) { + global $AUTH_URL; + return $AUTH_URL . '?' . http_build_query([ + 'client_id' => $clientId, + 'redirect_uri' => $redirectUri, + 'response_type' => 'code', + 'scope' => 'user' + ]); +} + +// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) +function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { + global $TOKEN_URL, $USER_INFO_URL; + + // 1. 获取访问令牌 + $ch = curl_init($TOKEN_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_POST, true); + curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ + 'client_id' => $clientId, + 'client_secret' => $clientSecret, + 'code' => $code, + 'redirect_uri' => $redirectUri, + 'grant_type' => 'authorization_code' + ])); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Content-Type: application/x-www-form-urlencoded', + 'Accept: application/json' + ]); + + $tokenResponse = curl_exec($ch); + curl_close($ch); + + $tokenData = json_decode($tokenResponse, true); + if (!isset($tokenData['access_token'])) { + return ['error' => '获取访问令牌失败', 'details' => $tokenData]; + } + + // 2. 获取用户信息 + $ch = curl_init($USER_INFO_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Authorization: Bearer ' . $tokenData['access_token'] + ]); + + $userResponse = curl_exec($ch); + curl_close($ch); + + return json_decode($userResponse, true); +} + +// 主流程 +// 1. 生成授权 URL +$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); +echo "使用 Linux Do 登录"; + +// 2. 处理回调并获取用户信息 +if (isset($_GET['code'])) { + $userInfo = getUserInfoWithCode( + $_GET['code'], + $CLIENT_ID, + $CLIENT_SECRET, + $REDIRECT_URI + ); + + if (isset($userInfo['error'])) { + echo '错误: ' . $userInfo['error']; + } else { + echo '欢迎, ' . $userInfo['name'] . '!'; + // 处理用户登录逻辑... + } +} +``` + +## 使用说明 + +### 授权流程 + +1. 用户点击应用中的’使用 Linux Do 登录’按钮 +2. 系统将用户重定向至 Linux Do 的授权页面 +3. 用户完成授权后,系统自动重定向回应用并携带授权码 +4. 应用使用授权码获取访问令牌 +5. 使用访问令牌获取用户信息 + +### 安全建议 + +- 切勿在前端代码中暴露 Client Secret +- 对所有用户输入数据进行严格验证 +- 确保使用 HTTPS 协议传输数据 +- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 17e51c38..79e0dd8a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.1 +0.1.46 diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1e15290..af51c8ed 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "net/url" "os" "strings" "time" @@ -35,24 +36,25 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -322,6 +324,30 @@ type TurnstileConfig struct { Required bool `mapstructure:"required"` } +// LinuxDoConnectConfig controls LinuxDo Connect OAuth login (end-user SSO). +// +// Note: This is NOT the same as upstream account OAuth (e.g. OpenAI/Gemini). +// It is used for logging in to Sub2API itself. +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` // backend callback URL registered at the provider + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // frontend route to receive token (default: /auth/linuxdo/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + + // Optional: gjson paths to extract fields from userinfo JSON. + // When empty, the server tries a set of common keys. + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -388,6 +414,18 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -426,6 +464,77 @@ func Load() (*Config, error) { return &cfg, nil } +func validateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func validateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + if strings.ContainsAny(raw, "\r\n") { + return fmt.Errorf("contains invalid characters") + } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { + return fmt.Errorf("must not start with //") + } + return nil + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute http(s) url or relative path") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return + } + if strings.EqualFold(u.Scheme, "http") { + log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + } +} + func setDefaults() { viper.SetDefault("run_mode", RunModeStandard) @@ -475,6 +584,22 @@ func setDefaults() { // Turnstile viper.SetDefault("turnstile.required", false) + // LinuxDo Connect OAuth login (end-user SSO) + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viper.SetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -586,6 +711,60 @@ func (c *Config) Validate() error { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } + if c.LinuxDo.Enabled { + if strings.TrimSpace(c.LinuxDo.ClientID) == "" { + return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { + return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { + return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { + return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") + } + method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") + } + + if err := validateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { + return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { + return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { + return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) + } + if err := validateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) + } + if err := validateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) + warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) + warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) + warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f28680c6..a39d41f9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "time" @@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { t.Fatalf("ResponseHeaders.Enabled = true, want false") } } + +func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "test-secret" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = false + + cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for javascript scheme, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") { + t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err) + } +} + +func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "none" + cfg.LinuxDo.UsePKCE = false + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { + t.Fatalf("Validate() expected use_pkce error, got: %v", err) + } +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go new file mode 100644 index 00000000..07310213 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -0,0 +1,517 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strings" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + + linuxDoOAuthMaxRedirectLen = 2048 + linuxDoOAuthMaxFragmentValueLen = 512 + linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") +) + +type linuxDoTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +// LinuxDoOAuthStart starts the LinuxDo Connect OAuth login flow. +// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard +func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { + cfg, err := linuxDoOAuthConfig(h.cfg) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// LinuxDoOAuthCallback handles the OAuth callback, creates/logins the user, then redirects to frontend. +// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... +func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { + cfg, cfgErr := linuxDoOAuthConfig(h.cfg) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = linuxDoOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", "") + return + } + + email, username, _, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + if err != nil { + // Avoid leaking internal details to the client; keep structured reason for frontend. + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", jwtToken) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func linuxDoOAuthConfig(cfg *config.Config) (config.LinuxDoConnectConfig, error) { + if cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !cfg.LinuxDo.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return cfg.LinuxDo, nil +} + +func linuxDoExchangeCode( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*linuxDoTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + var tokenResp linuxDoTokenResponse + resp, err := r.SetFormDataFromValues(form).SetSuccessResult(&tokenResp).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + if !resp.IsSuccessState() { + return nil, fmt.Errorf("token exchange status=%d", resp.StatusCode) + } + if strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, errors.New("token response missing access_token") + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + return &tokenResp, nil +} + +func linuxDoFetchUserInfo( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + token *linuxDoTokenResponse, +) (email string, username string, subject string, err error) { + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return "", "", "", fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return linuxDoParseUserInfo(resp.String(), cfg) +} + +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { + email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "username"), + getGJSON(body, "preferred_username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + + subject = strings.TrimSpace(subject) + if subject == "" { + return "", "", "", errors.New("userinfo missing id field") + } + if !isSafeLinuxDoSubject(subject) { + return "", "", "", errors.New("userinfo returned invalid id field") + } + + email = strings.TrimSpace(email) + if email == "" { + // LinuxDo Connect userinfo does not necessarily provide email. To keep compatibility with the + // existing user schema (email is required/unique), use a stable synthetic email. + email = fmt.Sprintf("linuxdo-%s@linuxdo-connect.invalid", subject) + } + + username = strings.TrimSpace(username) + if username == "" { + username = "linuxdo_" + subject + } + + return email, username, subject, nil +} + +func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { + fragment := url.Values{} + fragment.Set("error", truncateFragmentValue(code)) + if strings.TrimSpace(message) != "" { + fragment.Set("error_message", truncateFragmentValue(message)) + } + if strings.TrimSpace(description) != "" { + fragment.Set("error_description", truncateFragmentValue(description)) + } + redirectWithFragment(c, frontendCallback, fragment) +} + +func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { + u, err := url.Parse(frontendCallback) + if err != nil { + // Fallback: best-effort redirect. + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = fragment.Encode() + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "" +} + +func getGJSON(body string, path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + res := gjson.Get(body, path) + if !res.Exists() { + return "" + } + return res.String() +} + +func sanitizeFrontendRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if len(path) > linuxDoOAuthMaxRedirectLen { + return "" + } + // Only allow same-origin relative paths (avoid open redirect). + if !strings.HasPrefix(path, "/") { + return "" + } + if strings.HasPrefix(path, "//") { + return "" + } + if strings.Contains(path, "://") { + return "" + } + if strings.ContainsAny(path, "\r\n") { + return "" + } + return path +} + +func isRequestHTTPS(c *gin.Context) bool { + if c.Request.TLS != nil { + return true + } + proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) + return proto == "https" +} + +func encodeCookieValue(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func decodeCookieValue(value string) (string, error) { + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(raw), nil +} + +func readCookieDecoded(c *gin.Context, name string) (string, error) { + ck, err := c.Request.Cookie(name) + if err != nil { + return "", err + } + return decodeCookieValue(ck.Value) +} + +func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: linuxDoOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: linuxDoOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func truncateFragmentValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) > linuxDoOAuthMaxFragmentValueLen { + value = value[:linuxDoOAuthMaxFragmentValueLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + } + return value +} + +func buildBearerAuthorization(tokenType, accessToken string) (string, error) { + tokenType = strings.TrimSpace(tokenType) + if tokenType == "" { + tokenType = "Bearer" + } + if !strings.EqualFold(tokenType, "Bearer") { + return "", fmt.Errorf("unsupported token_type: %s", tokenType) + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", errors.New("missing access_token") + } + if strings.ContainsAny(accessToken, " \t\r\n") { + return "", errors.New("access_token contains whitespace") + } + return "Bearer " + accessToken, nil +} + +func isSafeLinuxDoSubject(subject string) bool { + subject = strings.TrimSpace(subject) + if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { + return false + } + for _, r := range subject { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r == '_' || r == '-': + default: + return false + } + } + return true +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go new file mode 100644 index 00000000..03db69a8 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -0,0 +1,74 @@ +package handler + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSanitizeFrontendRedirectPath(t *testing.T) { + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard")) + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard ")) + require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard")) + require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo")) + + long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen) + require.Equal(t, "", sanitizeFrontendRedirectPath(long)) +} + +func TestBuildBearerAuthorization(t *testing.T) { + auth, err := buildBearerAuthorization("", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + auth, err = buildBearerAuthorization("bearer", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + _, err = buildBearerAuthorization("MAC", "token123") + require.Error(t, err) + + _, err = buildBearerAuthorization("Bearer", "token 123") + require.Error(t, err) +} + +func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "alice", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "linuxdo_123", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + require.Error(t, err) + + tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) + _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + require.Error(t, err) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..7382a577 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -50,5 +50,6 @@ type PublicSettings struct { APIBaseURL string `json:"api_base_url"` ContactInfo string `json:"contact_info"` DocURL string `json:"doc_url"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 3cae7a7f..e1b20c8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { APIBaseURL: settings.APIBaseURL, ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 196d8bdb..e61d3939 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -19,6 +19,8 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } // 公开设置(无需认证) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 6e685869..e3532b25 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log" + "net/mail" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,6 +22,7 @@ var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") @@ -27,6 +32,8 @@ var ( ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") ) +const linuxDoSyntheticEmailDomain = "@linuxdo-connect.invalid" + // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 @@ -80,6 +87,11 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw return "", nil, ErrRegDisabled } + // Prevent users from registering emails reserved for synthetic OAuth accounts. + if isReservedEmail(email) { + return "", nil, ErrEmailReserved + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { // 如果邮件验证已开启但邮件服务未配置,拒绝注册 @@ -161,6 +173,10 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { return ErrRegDisabled } + if isReservedEmail(email) { + return ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -195,6 +211,10 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S return nil, ErrRegDisabled } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -319,6 +339,101 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return token, user, nil } +// LoginOrRegisterOAuth logs a user in by email (trusted from an OAuth provider) or creates a new user. +// +// This is used by end-user OAuth/SSO login flows (e.g. LinuxDo Connect), and intentionally does +// NOT require the local password. A random password hash is generated for new users to satisfy +// the existing database constraint. +func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // Treat OAuth-first login as registration. + if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + return "", nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return "", nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return "", nil, fmt.Errorf("hash password: %w", err) + } + + // Defaults for new users. + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + // Race: user created between GetByEmail and Create. + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return "", nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return "", nil, ErrUserNotActive + } + + // Best-effort: fill username when empty. + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + token, err := s.GenerateToken(user) + if err != nil { + return "", nil, fmt.Errorf("generate token: %w", err) + } + return token, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -361,6 +476,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { return nil, ErrInvalidToken } +func randomHexString(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 16 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func isReservedEmail(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return strings.HasSuffix(normalized, linuxDoSyntheticEmailDomain) +} + // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index bfd504a3..8e99ea29 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -182,6 +182,14 @@ func TestAuthService_Register_CheckEmailError(t *testing.T) { require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_ReservedEmail(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") + require.ErrorIs(t, err, ErrEmailReserved) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} service := newAuthService(repo, map[string]string{ diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 965253cf..b3a3bf21 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -82,6 +82,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings APIBaseURL: settings[SettingKeyAPIBaseURL], ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], + LinuxDoOAuthEnabled: s.cfg != nil && s.cfg.LinuxDo.Enabled, }, nil } diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index de0331f7..a06723f8 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -51,5 +51,6 @@ type PublicSettings struct { APIBaseURL string ContactInfo string DocURL string + LinuxDoOAuthEnabled bool Version string } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 49bf0afa..936f0ea4 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -234,6 +234,31 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# LinuxDo Connect OAuth Login (SSO) +# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +linuxdo_connect: + enabled: false + client_id: "" + client_secret: "" + authorize_url: "https://connect.linux.do/oauth2/authorize" + token_url: "https://connect.linux.do/oauth2/token" + userinfo_url: "https://connect.linux.do/api/user" + scopes: "user" + # 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/linuxdo/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 2732d84d..745445bf 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -229,6 +229,15 @@ export default { sendingCode: 'Sending...', clickToResend: 'Click to resend code', resendCode: 'Resend verification code', + linuxdo: { + signIn: 'Continue with Linux.do', + orContinue: 'or continue with email', + callbackTitle: 'Signing you in', + callbackProcessing: 'Completing login, please wait...', + callbackHint: 'If you are not redirected automatically, go back to the login page and try again.', + callbackMissingToken: 'Missing login token, please try again.', + backToLogin: 'Back to Login' + }, oauth: { code: 'Code', state: 'State', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 40aa39ab..83df3ddc 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -227,6 +227,15 @@ export default { sendingCode: '发送中...', clickToResend: '点击重新发送验证码', resendCode: '重新发送验证码', + linuxdo: { + signIn: '使用 Linux.do 登录', + orContinue: '或使用邮箱密码继续', + callbackTitle: '正在完成登录', + callbackProcessing: '正在验证登录信息,请稍候...', + callbackHint: '如果页面未自动跳转,请返回登录页重试。', + callbackMissingToken: '登录信息缺失,请返回重试。', + backToLogin: '返回登录' + }, oauth: { code: '授权码', state: '状态', diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 48a6f0fd..238982ef 100644 --- a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -67,6 +67,15 @@ const routes: RouteRecordRaw[] = [ title: 'OAuth Callback' } }, + { + path: '/auth/linuxdo/callback', + name: 'LinuxDoOAuthCallback', + component: () => import('@/views/auth/LinuxDoCallbackView.vue'), + meta: { + requiresAuth: false, + title: 'LinuxDo OAuth Callback' + } + }, // ==================== User Routes ==================== { diff --git a/frontend/src/stores/app.ts b/frontend/src/stores/app.ts index cfc9d677..d91a9b7e 100644 --- a/frontend/src/stores/app.ts +++ b/frontend/src/stores/app.ts @@ -282,23 +282,24 @@ export const useAppStore = defineStore('app', () => { * Fetch public settings (uses cache unless force=true) * @param force - Force refresh from API */ - async function fetchPublicSettings(force = false): Promise { - // Return cached data if available and not forcing refresh - if (publicSettingsLoaded.value && !force) { - return { - registration_enabled: false, - email_verify_enabled: false, - turnstile_enabled: false, - turnstile_site_key: '', - site_name: siteName.value, - site_logo: siteLogo.value, - site_subtitle: '', - api_base_url: apiBaseUrl.value, - contact_info: contactInfo.value, - doc_url: docUrl.value, - version: siteVersion.value - } - } + async function fetchPublicSettings(force = false): Promise { + // Return cached data if available and not forcing refresh + if (publicSettingsLoaded.value && !force) { + return { + registration_enabled: false, + email_verify_enabled: false, + turnstile_enabled: false, + turnstile_site_key: '', + site_name: siteName.value, + site_logo: siteLogo.value, + site_subtitle: '', + api_base_url: apiBaseUrl.value, + contact_info: contactInfo.value, + doc_url: docUrl.value, + linuxdo_oauth_enabled: false, + version: siteVersion.value + } + } // Prevent duplicate requests if (publicSettingsLoading.value) { diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index 27faaf4b..3f22a9d3 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -159,6 +159,27 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * Set token directly (OAuth/SSO callback) and load current user profile. + * @param newToken - JWT access token issued by backend + */ + async function setToken(newToken: string): Promise { + // Clear any previous state first (avoid mixing sessions) + clearAuth() + + token.value = newToken + localStorage.setItem(AUTH_TOKEN_KEY, newToken) + + try { + const userData = await refreshUser() + startAutoRefresh() + return userData + } catch (error) { + clearAuth() + throw error + } + } + /** * User logout * Clears all authentication state and persisted data @@ -233,6 +254,7 @@ export const useAuthStore = defineStore('auth', () => { // Actions login, register, + setToken, logout, checkAuth, refreshUser diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eaea24be..4b8fff09 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -73,6 +73,7 @@ export interface PublicSettings { api_base_url: string contact_info: string doc_url: string + linuxdo_oauth_enabled: boolean version: string } diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue new file mode 100644 index 00000000..c6f93e6b --- /dev/null +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -0,0 +1,119 @@ + + + + + + diff --git a/frontend/src/views/auth/LoginView.vue b/frontend/src/views/auth/LoginView.vue index 903db100..a6b5d2b2 100644 --- a/frontend/src/views/auth/LoginView.vue +++ b/frontend/src/views/auth/LoginView.vue @@ -11,6 +11,51 @@

+ +
+ + +
+
+ + {{ t('auth.linuxdo.orContinue') }} + +
+
+
+
@@ -179,6 +224,7 @@ const showPassword = ref(false) // Public settings const turnstileEnabled = ref(false) const turnstileSiteKey = ref('') +const linuxdoOAuthEnabled = ref(false) // Turnstile const turnstileRef = ref | null>(null) @@ -210,6 +256,7 @@ onMounted(async () => { const settings = await getPublicSettings() turnstileEnabled.value = settings.turnstile_enabled turnstileSiteKey.value = settings.turnstile_site_key || '' + linuxdoOAuthEnabled.value = settings.linuxdo_oauth_enabled } catch (error) { console.error('Failed to load public settings:', error) } @@ -320,6 +367,14 @@ async function handleLogin(): Promise { isLoading.value = false } } + +function handleLinuxDoLogin(): void { + const redirectTo = (router.currentRoute.value.query.redirect as string) || '/dashboard' + const apiBase = (import.meta.env.VITE_API_BASE_URL as string | undefined) || '/api/v1' + const normalized = apiBase.replace(/\/$/, '') + const startURL = `${normalized}/auth/oauth/linuxdo/start?redirect=${encodeURIComponent(redirectTo)}` + window.location.href = startURL +}