diff --git a/Linux DO Connect.md b/Linux DO Connect.md new file mode 100644 index 00000000..7ca1260f --- /dev/null +++ b/Linux DO Connect.md @@ -0,0 +1,368 @@ +# Linux DO Connect + +OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 + +目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 + +## 基本介绍 + +这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 + +- 可获取字段: + +| 参数 | 说明 | +| ----------------- | ------------------------------- | +| `id` | 用户唯一标识(不可变) | +| `username` | 论坛用户名 | +| `name` | 论坛用户昵称(可变) | +| `avatar_template` | 用户头像模板URL(支持多种尺寸) | +| `active` | 账号活跃状态 | +| `trust_level` | 信任等级(0-4) | +| `silenced` | 禁言状态 | +| `external_ids` | 外部ID关联信息 | +| `api_key` | API访问密钥 | + +通过这些信息,公益网站/接口可以实现: + +1. 基于 `id` 的服务频率限制 +2. 基于 `trust_level` 的服务额度分配 +3. 基于用户信息的滥用举报机制 + +## 相关端点 + +- Authorize 端点: `https://connect.linux.do/oauth2/authorize` +- Token 端点:`https://connect.linux.do/oauth2/token` +- 用户信息 端点:`https://connect.linux.do/api/user` + +## 申请使用 + +- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 + + + +- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 + + + +- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 + + + +## 接入 Linux Do + +JavaScript +```JavaScript +// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios +// npm install axios + +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +const axios = require('axios'); +const readline = require('readline'); + +// 配置信息(建议通过环境变量配置,避免使用硬编码) +const CLIENT_ID = '你的 Client ID'; +const CLIENT_SECRET = '你的 Client Secret'; +const REDIRECT_URI = '你的回调地址'; +const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +const USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 第一步:生成授权 URL +function getAuthUrl() { + const params = new URLSearchParams({ + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + response_type: 'code', + scope: 'user' + }); + + return `${AUTH_URL}?${params.toString()}`; +} + +// 第二步:获取 code 参数 +function getCode() { + return new Promise((resolve) => { + // 本例中使用终端输入来模拟流程,仅供本地测试 + // 请在实际应用中替换为真实的处理逻辑 + const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); + rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { + rl.close(); + resolve(answer.trim()); + }); + }); +} + +// 第三步:使用 code 参数获取访问令牌 +async function getAccessToken(code) { + try { + const form = new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: code, + redirect_uri: REDIRECT_URI, + grant_type: 'authorization_code' + }).toString(); + + const response = await axios.post(TOKEN_URL, form, { + // 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + }); + + return response.data; + } catch (error) { + console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 第四步:使用访问令牌获取用户信息 +async function getUserInfo(accessToken) { + try { + const response = await axios.get(USER_INFO_URL, { + headers: { + Authorization: `Bearer ${accessToken}` + } + }); + + return response.data; + } catch (error) { + console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 主流程 +async function main() { + // 1. 生成授权 URL,前端引导用户访问授权页 + const authUrl = getAuthUrl(); + console.log(`请访问此 URL 授权:${authUrl} +`); + + // 2. 用户授权后,从回调 URL 获取 code 参数 + const code = await getCode(); + + try { + // 3. 使用 code 参数获取访问令牌 + const tokenData = await getAccessToken(code); + const accessToken = tokenData.access_token; + + // 4. 使用访问令牌获取用户信息 + if (accessToken) { + const userInfo = await getUserInfo(accessToken); + console.log(` +获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); + } else { + console.log(` +获取访问令牌失败:${JSON.stringify(tokenData)}`); + } + } catch (error) { + console.error('发生错误:', error); + } +} +``` +Python +```python +# 安装第三方请求库,本例中使用 requests +# pip install requests + +# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +import requests +import json + +# 配置信息(建议通过环境变量配置,避免使用硬编码) +CLIENT_ID = '你的 Client ID' +CLIENT_SECRET = '你的 Client Secret' +REDIRECT_URI = '你的回调地址' +AUTH_URL = 'https://connect.linux.do/oauth2/authorize' +TOKEN_URL = 'https://connect.linux.do/oauth2/token' +USER_INFO_URL = 'https://connect.linux.do/api/user' + +# 第一步:生成授权 URL +def get_auth_url(): + params = { + 'client_id': CLIENT_ID, + 'redirect_uri': REDIRECT_URI, + 'response_type': 'code', + 'scope': 'user' + } + auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" + return auth_url + +# 第二步:获取 code 参数 +def get_code(): + # 本例中使用终端输入来模拟流程,仅供本地测试 + # 请在实际应用中替换为真实的处理逻辑 + return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() + +# 第三步:使用 code 参数获取访问令牌 +def get_access_token(code): + try: + data = { + 'client_id': CLIENT_ID, + 'client_secret': CLIENT_SECRET, + 'code': code, + 'redirect_uri': REDIRECT_URI, + 'grant_type': 'authorization_code' + } + # 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + response = requests.post(TOKEN_URL, data=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取访问令牌失败:{e}") + return None + +# 第四步:使用访问令牌获取用户信息 +def get_user_info(access_token): + try: + headers = { + 'Authorization': f'Bearer {access_token}' + } + response = requests.get(USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取用户信息失败:{e}") + return None + +# 主流程 +if __name__ == '__main__': + # 1. 生成授权 URL,前端引导用户访问授权页 + auth_url = get_auth_url() + print(f'请访问此 URL 授权:{auth_url} +') + + # 2. 用户授权后,从回调 URL 获取 code 参数 + code = get_code() + + # 3. 使用 code 参数获取访问令牌 + token_data = get_access_token(code) + if token_data: + access_token = token_data.get('access_token') + + # 4. 使用访问令牌获取用户信息 + if access_token: + user_info = get_user_info(access_token) + if user_info: + print(f" +获取用户信息成功:{json.dumps(user_info, indent=2)}") + else: + print(" +获取用户信息失败") + else: + print(f" +获取访问令牌失败:{json.dumps(token_data, indent=2)}") + else: + print(" +获取访问令牌失败") +``` +PHP +```php +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 + +// 配置信息 +$CLIENT_ID = '你的 Client ID'; +$CLIENT_SECRET = '你的 Client Secret'; +$REDIRECT_URI = '你的回调地址'; +$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +$USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 生成授权 URL +function getAuthUrl($clientId, $redirectUri) { + global $AUTH_URL; + return $AUTH_URL . '?' . http_build_query([ + 'client_id' => $clientId, + 'redirect_uri' => $redirectUri, + 'response_type' => 'code', + 'scope' => 'user' + ]); +} + +// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) +function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { + global $TOKEN_URL, $USER_INFO_URL; + + // 1. 获取访问令牌 + $ch = curl_init($TOKEN_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_POST, true); + curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ + 'client_id' => $clientId, + 'client_secret' => $clientSecret, + 'code' => $code, + 'redirect_uri' => $redirectUri, + 'grant_type' => 'authorization_code' + ])); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Content-Type: application/x-www-form-urlencoded', + 'Accept: application/json' + ]); + + $tokenResponse = curl_exec($ch); + curl_close($ch); + + $tokenData = json_decode($tokenResponse, true); + if (!isset($tokenData['access_token'])) { + return ['error' => '获取访问令牌失败', 'details' => $tokenData]; + } + + // 2. 获取用户信息 + $ch = curl_init($USER_INFO_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Authorization: Bearer ' . $tokenData['access_token'] + ]); + + $userResponse = curl_exec($ch); + curl_close($ch); + + return json_decode($userResponse, true); +} + +// 主流程 +// 1. 生成授权 URL +$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); +echo "使用 Linux Do 登录"; + +// 2. 处理回调并获取用户信息 +if (isset($_GET['code'])) { + $userInfo = getUserInfoWithCode( + $_GET['code'], + $CLIENT_ID, + $CLIENT_SECRET, + $REDIRECT_URI + ); + + if (isset($userInfo['error'])) { + echo '错误: ' . $userInfo['error']; + } else { + echo '欢迎, ' . $userInfo['name'] . '!'; + // 处理用户登录逻辑... + } +} +``` + +## 使用说明 + +### 授权流程 + +1. 用户点击应用中的’使用 Linux Do 登录’按钮 +2. 系统将用户重定向至 Linux Do 的授权页面 +3. 用户完成授权后,系统自动重定向回应用并携带授权码 +4. 应用使用授权码获取访问令牌 +5. 使用访问令牌获取用户信息 + +### 安全建议 + +- 切勿在前端代码中暴露 Client Secret +- 对所有用户输入数据进行严格验证 +- 确保使用 HTTPS 协议传输数据 +- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 17e51c38..79e0dd8a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.1 +0.1.46 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 31dc3682..85bed3f3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1e15290..84a14ca2 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "net/url" "os" "strings" "time" @@ -35,24 +36,25 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -322,6 +324,30 @@ type TurnstileConfig struct { Required bool `mapstructure:"required"` } +// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。 +// +// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。 +// 这里是用于登录 Sub2API 本身的用户体系。 +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + + // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 + // 为空时,服务端会尝试一组常见字段名。 + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -388,6 +414,18 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -426,6 +464,81 @@ func Load() (*Config, error) { return &cfg, nil } +// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。 +func ValidateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +// ValidateFrontendRedirectURL 校验前端回调地址: +// - 允许同源相对路径(以 / 开头) +// - 或绝对 http(s) URL(禁止 fragment) +func ValidateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + if strings.ContainsAny(raw, "\r\n") { + return fmt.Errorf("contains invalid characters") + } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { + return fmt.Errorf("must not start with //") + } + return nil + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute http(s) url or relative path") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return + } + if strings.EqualFold(u.Scheme, "http") { + log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + } +} + func setDefaults() { viper.SetDefault("run_mode", RunModeStandard) @@ -475,6 +588,22 @@ func setDefaults() { // Turnstile viper.SetDefault("turnstile.required", false) + // LinuxDo Connect OAuth 登录(终端用户 SSO) + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viper.SetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -586,6 +715,60 @@ func (c *Config) Validate() error { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } + if c.LinuxDo.Enabled { + if strings.TrimSpace(c.LinuxDo.ClientID) == "" { + return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { + return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { + return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { + return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") + } + method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") + } + + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { + return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { + return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { + return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) + } + if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) + warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) + warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) + warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f28680c6..a39d41f9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "time" @@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { t.Fatalf("ResponseHeaders.Enabled = true, want false") } } + +func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "test-secret" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = false + + cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for javascript scheme, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") { + t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err) + } +} + +func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "none" + cfg.LinuxDo.UsePKCE = false + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { + t.Fatalf("Validate() expected use_pkce error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index da9f6990..8a7270e5 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` + Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` @@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) if err != nil { @@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.Status != "" || + req.Schedulable != nil || req.GroupIDs != nil || len(req.Credentials) > 0 || len(req.Extra) > 0 @@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, Status: req.Status, + Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, Credentials: req.Credentials, Extra: req.Extra, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index acb9462c..a8bae35e 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -67,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) platform := c.Query("platform") status := c.Query("status") + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } isExclusiveStr := c.Query("is_exclusive") var isExclusive *bool @@ -75,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) { isExclusive = &val } - groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) + groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 4fabd8ec..437e9300 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -51,6 +51,11 @@ func (h *ProxyHandler) List(c *gin.Context) { protocol := c.Query("protocol") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) if err != nil { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 45fae43a..5b3229b6 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) { codeType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) if err != nil { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 743c4268..d95a8980 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -2,8 +2,10 @@ package admin import ( "log" + "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPasswordConfigured: settings.SMTPPasswordConfigured, - SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: settings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, + FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, }) } @@ -88,6 +94,12 @@ type UpdateSettingsRequest struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKey string `json:"turnstile_secret_key"` + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // OEM设置 SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` @@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // LinuxDo Connect 参数验证 + if req.LinuxDoConnectEnabled { + req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID) + req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret) + req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL) + + if req.LinuxDoConnectClientID == "" { + response.BadRequest(c, "LinuxDo Client ID is required when enabled") + return + } + if req.LinuxDoConnectRedirectURL == "" { + response.BadRequest(c, "LinuxDo Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil { + response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL") + return + } + + // 如果未提供 client_secret,则保留现有值(如有)。 + if req.LinuxDoConnectClientSecret == "" { + if previousSettings.LinuxDoConnectClientSecret == "" { + response.BadRequest(c, "LinuxDo Client Secret is required when enabled") + return + } + req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret + } + } + settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, }) } @@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if req.TurnstileSecretKey != "" { changed = append(changed, "turnstile_secret_key") } + if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled { + changed = append(changed, "linuxdo_connect_enabled") + } + if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID { + changed = append(changed, "linuxdo_connect_client_id") + } + if req.LinuxDoConnectClientSecret != "" { + changed = append(changed, "linuxdo_connect_client_secret") + } + if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { + changed = append(changed, "linuxdo_connect_redirect_url") + } if before.SiteName != after.SiteName { changed = append(changed, "site_name") } @@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.FallbackModelAntigravity != after.FallbackModelAntigravity { changed = append(changed, "fallback_model_antigravity") } + if before.EnableIdentityPatch != after.EnableIdentityPatch { + changed = append(changed, "enable_identity_patch") + } + if before.IdentityPatchPrompt != after.IdentityPatchPrompt { + changed = append(changed, "identity_patch_prompt") + } return changed } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index f8cd1d5a..38cc8acd 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { func (h *UserHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } + filters := service.UserListFilters{ Status: c.Query("status"), Role: c.Query("role"), - Search: c.Query("search"), + Search: search, Attributes: parseAttributeFilters(c), } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 8466f131..8463367e 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -15,14 +15,16 @@ type AuthHandler struct { cfg *config.Config authService *service.AuthService userService *service.UserService + settingSvc *service.SettingService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler { return &AuthHandler{ cfg: cfg, authService: authService, userService: userService, + settingSvc: settingService, } } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go new file mode 100644 index 00000000..a16c4cc7 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -0,0 +1,679 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + + linuxDoOAuthMaxRedirectLen = 2048 + linuxDoOAuthMaxFragmentValueLen = 512 + linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") +) + +type linuxDoTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type linuxDoTokenExchangeError struct { + StatusCode int + ProviderError string + ProviderDescription string + Body string +} + +func (e *linuxDoTokenExchangeError) Error() string { + if e == nil { + return "" + } + parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)} + if strings.TrimSpace(e.ProviderError) != "" { + parts = append(parts, "error="+strings.TrimSpace(e.ProviderError)) + } + if strings.TrimSpace(e.ProviderDescription) != "" { + parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription)) + } + return strings.Join(parts, " ") +} + +// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。 +// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard +func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { + cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。 +// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... +func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { + cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context()) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = linuxDoOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + description := "" + var exchangeErr *linuxDoTokenExchangeError + if errors.As(err, &exchangeErr) && exchangeErr != nil { + log.Printf( + "[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s", + exchangeErr.StatusCode, + exchangeErr.ProviderError, + exchangeErr.ProviderDescription, + truncateLogValue(exchangeErr.Body, 2048), + ) + description = exchangeErr.Error() + } else { + log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) + description = err.Error() + } + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description)) + return + } + + email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 + // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 + if subject != "" { + email = linuxDoSyntheticEmail(subject) + } + + jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + if err != nil { + // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", jwtToken) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) + } + if h == nil || h.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !h.cfg.LinuxDo.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return h.cfg.LinuxDo, nil +} + +func linuxDoExchangeCode( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*linuxDoTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + body := strings.TrimSpace(resp.String()) + if !resp.IsSuccessState() { + providerErr, providerDesc := parseOAuthProviderError(body) + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + ProviderError: providerErr, + ProviderDescription: providerDesc, + Body: body, + } + } + + tokenResp, ok := parseLinuxDoTokenResponse(body) + if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + Body: body, + } + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + return tokenResp, nil +} + +func linuxDoFetchUserInfo( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + token *linuxDoTokenResponse, +) (email string, username string, subject string, err error) { + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return "", "", "", fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return linuxDoParseUserInfo(resp.String(), cfg) +} + +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { + email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "username"), + getGJSON(body, "preferred_username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + + subject = strings.TrimSpace(subject) + if subject == "" { + return "", "", "", errors.New("userinfo missing id field") + } + if !isSafeLinuxDoSubject(subject) { + return "", "", "", errors.New("userinfo returned invalid id field") + } + + email = strings.TrimSpace(email) + if email == "" { + // LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。 + email = linuxDoSyntheticEmail(subject) + } + + username = strings.TrimSpace(username) + if username == "" { + username = "linuxdo_" + subject + } + + return email, username, subject, nil +} + +func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { + fragment := url.Values{} + fragment.Set("error", truncateFragmentValue(code)) + if strings.TrimSpace(message) != "" { + fragment.Set("error_message", truncateFragmentValue(message)) + } + if strings.TrimSpace(description) != "" { + fragment.Set("error_description", truncateFragmentValue(description)) + } + redirectWithFragment(c, frontendCallback, fragment) +} + +func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { + u, err := url.Parse(frontendCallback) + if err != nil { + // 兜底:尽力跳转到默认页面,避免卡死在回调页。 + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = fragment.Encode() + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "" +} + +func parseOAuthProviderError(body string) (providerErr string, providerDesc string) { + body = strings.TrimSpace(body) + if body == "" { + return "", "" + } + + providerErr = firstNonEmpty( + getGJSON(body, "error"), + getGJSON(body, "code"), + getGJSON(body, "error.code"), + ) + providerDesc = firstNonEmpty( + getGJSON(body, "error_description"), + getGJSON(body, "error.message"), + getGJSON(body, "message"), + getGJSON(body, "detail"), + ) + + if providerErr != "" || providerDesc != "" { + return providerErr, providerDesc + } + + values, err := url.ParseQuery(body) + if err != nil { + return "", "" + } + providerErr = firstNonEmpty(values.Get("error"), values.Get("code")) + providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message")) + return providerErr, providerDesc +} + +func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) { + body = strings.TrimSpace(body) + if body == "" { + return nil, false + } + + accessToken := strings.TrimSpace(getGJSON(body, "access_token")) + if accessToken != "" { + tokenType := strings.TrimSpace(getGJSON(body, "token_type")) + refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token")) + scope := strings.TrimSpace(getGJSON(body, "scope")) + expiresIn := gjson.Get(body, "expires_in").Int() + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: tokenType, + ExpiresIn: expiresIn, + RefreshToken: refreshToken, + Scope: scope, + }, true + } + + values, err := url.ParseQuery(body) + if err != nil { + return nil, false + } + accessToken = strings.TrimSpace(values.Get("access_token")) + if accessToken == "" { + return nil, false + } + expiresIn := int64(0) + if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" { + if v, err := strconv.ParseInt(raw, 10, 64); err == nil { + expiresIn = v + } + } + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: strings.TrimSpace(values.Get("token_type")), + ExpiresIn: expiresIn, + RefreshToken: strings.TrimSpace(values.Get("refresh_token")), + Scope: strings.TrimSpace(values.Get("scope")), + }, true +} + +func getGJSON(body string, path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + res := gjson.Get(body, path) + if !res.Exists() { + return "" + } + return res.String() +} + +func truncateLogValue(value string, maxLen int) string { + value = strings.TrimSpace(value) + if value == "" || maxLen <= 0 { + return "" + } + if len(value) <= maxLen { + return value + } + value = value[:maxLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + return value +} + +func singleLine(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return strings.Join(strings.Fields(value), " ") +} + +func sanitizeFrontendRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if len(path) > linuxDoOAuthMaxRedirectLen { + return "" + } + // 只允许同源相对路径(避免开放重定向)。 + if !strings.HasPrefix(path, "/") { + return "" + } + if strings.HasPrefix(path, "//") { + return "" + } + if strings.Contains(path, "://") { + return "" + } + if strings.ContainsAny(path, "\r\n") { + return "" + } + return path +} + +func isRequestHTTPS(c *gin.Context) bool { + if c.Request.TLS != nil { + return true + } + proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) + return proto == "https" +} + +func encodeCookieValue(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func decodeCookieValue(value string) (string, error) { + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(raw), nil +} + +func readCookieDecoded(c *gin.Context, name string) (string, error) { + ck, err := c.Request.Cookie(name) + if err != nil { + return "", err + } + return decodeCookieValue(ck.Value) +} + +func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: linuxDoOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: linuxDoOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func truncateFragmentValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) > linuxDoOAuthMaxFragmentValueLen { + value = value[:linuxDoOAuthMaxFragmentValueLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + } + return value +} + +func buildBearerAuthorization(tokenType, accessToken string) (string, error) { + tokenType = strings.TrimSpace(tokenType) + if tokenType == "" { + tokenType = "Bearer" + } + if !strings.EqualFold(tokenType, "Bearer") { + return "", fmt.Errorf("unsupported token_type: %s", tokenType) + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", errors.New("missing access_token") + } + if strings.ContainsAny(accessToken, " \t\r\n") { + return "", errors.New("access_token contains whitespace") + } + return "Bearer " + accessToken, nil +} + +func isSafeLinuxDoSubject(subject string) bool { + subject = strings.TrimSpace(subject) + if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { + return false + } + for _, r := range subject { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r == '_' || r == '-': + default: + return false + } + } + return true +} + +func linuxDoSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go new file mode 100644 index 00000000..ff169c52 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -0,0 +1,108 @@ +package handler + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSanitizeFrontendRedirectPath(t *testing.T) { + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard")) + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard ")) + require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard")) + require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo")) + + long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen) + require.Equal(t, "", sanitizeFrontendRedirectPath(long)) +} + +func TestBuildBearerAuthorization(t *testing.T) { + auth, err := buildBearerAuthorization("", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + auth, err = buildBearerAuthorization("bearer", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + _, err = buildBearerAuthorization("MAC", "token123") + require.Error(t, err) + + _, err = buildBearerAuthorization("Bearer", "token 123") + require.Error(t, err) +} + +func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "alice", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "linuxdo_123", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + require.Error(t, err) + + tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) + _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + require.Error(t, err) +} + +func TestParseOAuthProviderErrorJSON(t *testing.T) { + code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`) + require.Equal(t, "invalid_client", code) + require.Equal(t, "bad secret", desc) +} + +func TestParseOAuthProviderErrorForm(t *testing.T) { + code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier") + require.Equal(t, "invalid_request", code) + require.Equal(t, "Missing code_verifier", desc) +} + +func TestParseLinuxDoTokenResponseJSON(t *testing.T) { + token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`) + require.True(t, ok) + require.Equal(t, "t1", token.AccessToken) + require.Equal(t, "Bearer", token.TokenType) + require.Equal(t, int64(3600), token.ExpiresIn) + require.Equal(t, "user", token.Scope) +} + +func TestParseLinuxDoTokenResponseForm(t *testing.T) { + token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60") + require.True(t, ok) + require.Equal(t, "t2", token.AccessToken) + require.Equal(t, "bearer", token.TokenType) + require.Equal(t, int64(60), token.ExpiresIn) +} + +func TestSingleLineStripsWhitespace(t *testing.T) { + require.Equal(t, "hello world", singleLine("hello\r\nworld")) + require.Equal(t, "", singleLine("\n\t\r")) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..dab5eb75 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -17,6 +17,11 @@ type SystemSettings struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` SiteSubtitle string `json:"site_subtitle"` @@ -50,5 +55,6 @@ type PublicSettings struct { APIBaseURL string `json:"api_base_url"` ContactInfo string `json:"contact_info"` DocURL string `json:"doc_url"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 3cae7a7f..e1b20c8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { APIBaseURL: settings.APIBaseURL, ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) } diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 8ff75f57..1248be95 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -5,8 +5,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "log" + "net" "net/http" "net/url" "strings" @@ -22,10 +25,10 @@ func resolveHost(urlStr string) string { return parsed.Host } -// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) -func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { +// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) +func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 - apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) + apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action) isStream := action == "streamGenerateContent" if isStream { apiURL += "?alt=sse" @@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) req.Host = host } - // 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header - return req, nil } +// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点) +// 向后兼容:仅使用默认 BaseURL +func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { + return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body) +} + // TokenResponse Google OAuth token 响应 type TokenResponse struct { AccessToken string `json:"access_token"` @@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client { } } +// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // 检查 URL 错误 + var urlErr *url.Error + return errors.As(err, &urlErr) +} + +// shouldFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldFallbackToNextURL(err error, statusCode int) bool { + if isConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { params := url.Values{} @@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo } // LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" @@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - url := BaseURL + "/v1internal:loadCodeAssist" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:loadCodeAssist" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var loadResp LoadCodeAssistResponse + if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &loadResp, rawResp, nil } - var loadResp LoadCodeAssistResponse - if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &loadResp, rawResp, nil + return nil, nil, lastErr } // ModelQuotaInfo 模型配额信息 @@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct { } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) @@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - apiURL := BaseURL + "/v1internal:fetchAvailableModels" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:fetchAvailableModels" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var modelsResp FetchAvailableModelsResponse + if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &modelsResp, rawResp, nil } - var modelsResp FetchAvailableModelsResponse - if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &modelsResp, rawResp, nil + return nil, nil, lastErr } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index e88c203b..736c45df 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,17 +32,79 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // API 端点 - // 优先使用 sandbox daily URL,配额更宽松 - BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" - // User-Agent(模拟官方客户端) UserAgent = "antigravity/1.104.0 darwin/arm64" // Session 过期时间 SessionTTL = 30 * time.Minute + + // URL 可用性 TTL(不可用 URL 的恢复时间) + URLAvailabilityTTL = 5 * time.Minute ) +// BaseURLs 定义 Antigravity API 端点,按优先级排序 +// fallback 顺序: sandbox → daily → prod +var BaseURLs = []string{ + "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox + "https://daily-cloudcode-pa.googleapis.com", // daily + "https://cloudcode-pa.googleapis.com", // prod +} + +// BaseURL 默认 URL(保持向后兼容) +var BaseURL = BaseURLs[0] + +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +type URLAvailability struct { + mu sync.RWMutex + unavailable map[string]time.Time // URL -> 恢复时间 + ttl time.Duration +} + +// DefaultURLAvailability 全局 URL 可用性管理器 +var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL) + +// NewURLAvailability 创建 URL 可用性管理器 +func NewURLAvailability(ttl time.Duration) *URLAvailability { + return &URLAvailability{ + unavailable: make(map[string]time.Time), + ttl: ttl, + } +} + +// MarkUnavailable 标记 URL 临时不可用 +func (u *URLAvailability) MarkUnavailable(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.unavailable[url] = time.Now().Add(u.ttl) +} + +// IsAvailable 检查 URL 是否可用 +func (u *URLAvailability) IsAvailable(url string) bool { + u.mu.RLock() + defer u.mu.RUnlock() + expiry, exists := u.unavailable[url] + if !exists { + return true + } + return time.Now().After(expiry) +} + +// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +func (u *URLAvailability) GetAvailableURLs() []string { + u.mu.RLock() + defer u.mu.RUnlock() + + now := time.Now() + result := make([]string, 0, len(BaseURLs)) + for _, url := range BaseURLs { + expiry, exists := u.unavailable[url] + if !exists || now.After(expiry) { + result = append(result, url) + } + } + return result +} + // OAuthSession 保存 OAuth 授权流程的临时状态 type OAuthSession struct { State string `json:"state"` diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index 6d7e5a5d..d4d52116 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -27,10 +27,9 @@ const ( // https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform). DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever" - // DefaultScopes for Google One (personal Google accounts with Gemini access) - // Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client, - // Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client - // cannot request restricted scopes like generative-language.retriever or drive.readonly. + // DefaultGoogleOneScopes (DEPRECATED, no longer used) + // Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes. + // This constant is kept for backward compatibility but is not actively used. DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" // GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth. diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index 473017a2..c71e8aad 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error effective.Scopes = DefaultAIStudioScopes } case "google_one": - // Google One uses built-in Gemini CLI client (same as code_assist) - // Built-in client can't request restricted scopes like generative-language.retriever - if isBuiltinClient { - effective.Scopes = DefaultCodeAssistScopes - } else { - effective.Scopes = DefaultGoogleOneScopes - } + // Google One always uses built-in Gemini CLI client (same as code_assist) + // Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly + effective.Scopes = DefaultCodeAssistScopes default: // Default to Code Assist scopes effective.Scopes = DefaultCodeAssistScopes diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0520f0f2..0770730a 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with custom client", + name: "Google One always uses built-in client (even if custom credentials passed)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultGoogleOneScopes, + wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client wantErr: false, }, { diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 83f02608..02f7b9a5 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -831,6 +831,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.Status) idx++ } + if updates.Schedulable != nil { + setClauses = append(setClauses, "schedulable = $"+itoa(idx)) + args = append(args, *updates.Schedulable) + idx++ + } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 14ecfc89..8b7fe625 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client + // - google_one: always use built-in Gemini CLI OAuth client (public) // - ai_studio: requires a user-provided OAuth client oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } @@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 1fb4ae90..a54f3116 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -112,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { } func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", nil) + return r.ListWithFilters(ctx, params, "", "", "", nil) } -func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { q := r.client.Group.Query() if platform != "" { @@ -124,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination if status != "" { q = q.Where(group.StatusEQ(status)) } + if search != "" { + q = q.Where(group.Or( + group.NameContainsFold(search), + group.DescriptionContainsFold(search), + )) + } if isExclusive != nil { q = q.Where(group.IsExclusiveEQ(*isExclusive)) } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index b9079d7a..660618a6 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", + "", nil, ) s.Require().NoError(err, "ListWithFilters base") @@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil) s.Require().NoError(err) s.Require().Len(groups, len(baseGroups)+1) // Verify all groups are OpenAI platform @@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().Equal(service.StatusDisabled, groups[0].Status) @@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { })) isExclusive := true - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().True(groups[0].IsExclusive) } +func (s *GroupRepoSuite) TestListWithFilters_Search() { + newRepo := func() (*groupRepository, context.Context) { + tx := testEntTx(s.T()) + return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background() + } + + containsID := func(groups []service.Group, id int64) bool { + for i := range groups { + if groups[i].ID == id { + return true + } + } + return false + } + + mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group { + s.Require().NoError(repo.Create(ctx, g)) + s.Require().NotZero(g.ID) + return g + } + + newGroup := func(name string) *service.Group { + return &service.Group{ + Name: name, + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + } + + s.Run("search_name_should_match", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("it-group-search-name-target")) + other := mustCreate(repo, ctx, newGroup("it-group-search-name-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by name") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_description_should_match", func() { + repo, ctx := newRepo() + + target := newGroup("it-group-search-desc-target") + target.Description = "something about desc-needle in here" + target = mustCreate(repo, ctx, target) + + other := newGroup("it-group-search-desc-other") + other.Description = "nothing to see here" + other = mustCreate(repo, ctx, other) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by description") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_nonexistent_should_return_empty", func() { + repo, ctx := newRepo() + + _ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline")) + + search := s.T().Name() + "__no_such_group__" + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil) + s.Require().NoError(err) + s.Require().Empty(groups) + }) + + s.Run("search_should_be_case_insensitive", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle")) + other := mustCreate(repo, ctx, newGroup("it-group-search-case-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected case-insensitive match") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_should_escape_like_wildcards", func() { + repo, ctx := newRepo() + + percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target")) + percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match") + s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard") + + underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target")) + underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other")) + + groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match") + s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard") + }) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", @@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { s.Require().NoError(err) isExclusive := true - groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) + groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(groups, 1) diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 502d74b3..20e82be8 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -304,6 +304,10 @@ func TestAPIContracts(t *testing.T) { "turnstile_enabled": true, "turnstile_site_key": "site-key", "turnstile_secret_key_configured": true, + "linuxdo_connect_enabled": false, + "linuxdo_connect_client_id": "", + "linuxdo_connect_client_secret_configured": false, + "linuxdo_connect_redirect_url": "", "site_name": "Sub2API", "site_logo": "", "site_subtitle": "Subtitle", @@ -390,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - authHandler := handler.NewAuthHandler(cfg, nil, userService) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) @@ -583,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam return nil, nil, errors.New("not implemented") } -func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 196d8bdb..e61d3939 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -19,6 +19,8 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } // 公开设置(无需认证) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index e1b93fcb..3407f33a 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -66,6 +66,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int Status *string + Schedulable *bool Credentials map[string]any Extra map[string]any } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 7121a13d..8419c2b4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidate, ok := candidates[0].(map[string]any); ok { - // Check for completion - if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - // Extract content + // Extract content first (before checking completion) if content, ok := candidate["content"].(map[string]any); ok { if parts, ok := content["parts"].([]any); ok { for _, part := range parts { @@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } } } + + // Check for completion after extracting content + if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index e29bbdb4..4288381c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -24,7 +24,7 @@ type AdminService interface { GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) + ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) @@ -168,6 +168,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int Status string + Schedulable *bool GroupIDs *[]int64 Credentials map[string]any Extra map[string]any @@ -478,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, } // Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) if err != nil { return nil, 0, err } @@ -910,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.Status != "" { repoUpdates.Status = &input.Status } + if input.Schedulable != nil { + repoUpdates.Schedulable = input.Schedulable + } // Run bulk update for column/jsonb fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c1d2e4c9..351f64e8 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa panic("unexpected List call") } -func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 3171de11..26d6eedf 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { updated *Group // 记录 Update 调用的参数 getByID *Group // GetByID 返回值 getErr error // GetByID 返回的错误 + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersIsExclusive *bool + listWithFiltersGroups []Group + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error } func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { @@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP panic("unexpected List call") } -func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { - panic("unexpected ListWithFilters call") +func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + s.listWithFiltersIsExclusive = isExclusive + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersGroups)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersGroups, result, nil } func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { @@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.Nil(t, repo.updated.ImagePrice4K) } + +func TestAdminService_ListGroups_WithSearch(t *testing.T) { + // 测试: + // 1. search 参数正常传递到 repository 层 + // 2. search 为空字符串时的行为 + // 3. search 与其他过滤条件组合使用 + + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 1}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, "alpha", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 为空字符串时传递空字符串", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{}, + listWithFiltersResult: &pagination.PaginationResult{Total: 0}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) + require.NoError(t, err) + require.Empty(t, groups) + require.Equal(t, int64(0), total) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams) + require.Equal(t, "", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 与其他过滤条件组合使用", func(t *testing.T) { + isExclusive := true + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 42}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) + require.NoError(t, err) + require.Equal(t, int64(42), total) + require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "beta", repo.listWithFiltersSearch) + require.NotNil(t, repo.listWithFiltersIsExclusive) + require.True(t, *repo.listWithFiltersIsExclusive) + }) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go new file mode 100644 index 00000000..7506c6db --- /dev/null +++ b/backend/internal/service/admin_service_search_test.go @@ -0,0 +1,238 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type accountRepoStubForAdminList struct { + accountRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersAccounts []Account + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersType = accountType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAccounts)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAccounts, result, nil +} + +type proxyRepoStubForAdminList struct { + proxyRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersProtocol string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersProxies []Proxy + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error + + listWithFiltersAndAccountCountCalls int + listWithFiltersAndAccountCountParams pagination.PaginationParams + listWithFiltersAndAccountCountProtocol string + listWithFiltersAndAccountCountStatus string + listWithFiltersAndAccountCountSearch string + listWithFiltersAndAccountCountProxies []ProxyWithAccountCount + listWithFiltersAndAccountCountResult *pagination.PaginationResult + listWithFiltersAndAccountCountErr error +} + +func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersProtocol = protocol + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersProxies, result, nil +} + +func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + s.listWithFiltersAndAccountCountCalls++ + s.listWithFiltersAndAccountCountParams = params + s.listWithFiltersAndAccountCountProtocol = protocol + s.listWithFiltersAndAccountCountStatus = status + s.listWithFiltersAndAccountCountSearch = search + + if s.listWithFiltersAndAccountCountErr != nil { + return nil, nil, s.listWithFiltersAndAccountCountErr + } + + result := s.listWithFiltersAndAccountCountResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAndAccountCountProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAndAccountCountProxies, result, nil +} + +type redeemRepoStubForAdminList struct { + redeemRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersCodes []RedeemCode + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersType = codeType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersCodes)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersCodes, result, nil +} + +func TestAdminService_ListAccounts_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &accountRepoStubForAdminList{ + listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 10}, + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + require.NoError(t, err) + require.Equal(t, int64(10), total) + require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) + require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "acc", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxies_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 7}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") + require.NoError(t, err) + require.Equal(t, int64(7), total) + require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, "http", repo.listWithFiltersProtocol) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "p1", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, + listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") + require.NoError(t, err) + require.Equal(t, int64(9), total) + require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) + require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) + require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) + require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) + }) +} + +func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &redeemRepoStubForAdminList{ + listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 3}, + } + svc := &adminServiceImpl{redeemCodeRepo: repo} + + codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") + require.NoError(t, err) + require.Equal(t, int64(3), total) + require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) + require.Equal(t, StatusUnused, repo.listWithFiltersStatus) + require.Equal(t, "ABC", repo.listWithFiltersSearch) + }) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2fe77b2d..573017cd 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -10,6 +10,7 @@ import ( "io" "log" mathrand "math/rand" + "net" "net/http" "strings" "sync/atomic" @@ -27,6 +28,32 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) +// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isAntigravityConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + return errors.As(err, &opErr) +} + +// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool { + if isAntigravityConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // getSessionID 从 gin.Context 获取 session_id(用于日志追踪) func getSessionID(c *gin.Context) string { if c == nil { @@ -181,45 +208,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("构建请求失败: %w", err) } - // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) - req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody) - if err != nil { - return nil, err - } - - // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) - // 代理 URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - // 发送请求 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // 读取响应 - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) + req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody) + if err != nil { + lastErr = err + continue + } + + // 调试日志:Test 请求信息 + log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = fmt.Errorf("请求失败: %w", err) + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, lastErr + } + + // 读取响应 + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 解析流式响应,提取文本 + text := extractTextFromSSEResponse(respBody) + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil } - // 解析流式响应,提取文本 - text := extractTextFromSSEResponse(respBody) - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil + return nil, lastErr } // buildGeminiTestRequest 构建 Gemini 格式测试请求 @@ -484,62 +536,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + // 最后一次尝试也失败 + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { _ = resp.Body.Close() }() @@ -1003,61 +1079,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody) + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { if resp != nil && resp.Body != nil { diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 85772e75..e232deb3 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log" + "net/mail" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,6 +22,7 @@ var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") @@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } + // 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。 + if isReservedEmail(email) { + return "", nil, ErrEmailReserved + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { + // 如果邮件验证已开启但邮件服务未配置,拒绝注册 + // 这是一个配置错误,不应该允许绕过验证 + if s.emailService == nil { + log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + return "", nil, ErrServiceUnavailable + } if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 - if s.emailService != nil { - if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { - return "", nil, fmt.Errorf("verify code: %w", err) - } + if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { + return "", nil, fmt.Errorf("verify code: %w", err) } } @@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } if err := s.userRepo.Create(ctx, user); err != nil { + // 优先检查邮箱冲突错误(竞态条件下可能发生) + if errors.Is(err, ErrEmailExists) { + return "", nil, ErrEmailExists + } log.Printf("[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } @@ -148,11 +166,15 @@ type SendVerifyCodeResult struct { // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } + if isReservedEmail(email) { + return ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { - return true + return false // 安全默认:settingService 未配置时关闭注册 } return s.settingService.IsRegistrationEnabled(ctx) } @@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return token, user, nil } +// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录: +// - 如果邮箱已存在:直接登录(不需要本地密码) +// - 如果邮箱不存在:创建新用户并登录 +// +// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。 +// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。 +func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册。 + if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + return "", nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return "", nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return "", nil, fmt.Errorf("hash password: %w", err) + } + + // 新用户默认值。 + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + // 并发场景:GetByEmail 与 Create 之间用户被创建。 + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return "", nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return "", nil, ErrUserNotActive + } + + // 尽力补全:当用户名为空时,使用第三方返回的用户名回填。 + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + token, err := s.GenerateToken(user) + if err != nil { + return "", nil, fmt.Errorf("generate token: %w", err) + } + return token, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { + // token 过期但仍返回 claims(用于 RefreshToken 等场景) + // jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 + if claims, ok := token.Claims.(*JWTClaims); ok { + return claims, ErrTokenExpired + } return nil, ErrTokenExpired } return nil, ErrInvalidToken @@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { return nil, ErrInvalidToken } +func randomHexString(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 16 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func isReservedEmail(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) +} + // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index cd6e2808..ab1f20a0 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } -func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { +func TestAuthService_Register_DisabledByDefault(t *testing.T) { + // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) +} + +func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { + repo := &userRepoStub{} + // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", }, nil) + // 应返回服务不可用错误,而不是允许绕过验证 + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") + require.ErrorIs(t, err, ErrServiceUnavailable) +} + +func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { + repo := &userRepoStub{} + cache := &emailCacheStub{} // 配置 emailService + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, cache) + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_ReservedEmail(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") + require.ErrorIs(t, err, ErrEmailReserved) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { + // 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败 + repo := &userRepoStub{createErr: ErrEmailExists} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrEmailExists) +} + func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) @@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) { require.Len(t, repo.created, 1) require.True(t, user.CheckPassword("password")) } + +func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + // 创建用户并生成 token + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + token, err := service.GenerateToken(user) + require.NoError(t, err) + + // 验证有效 token + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, int64(1), claims.UserID) + + // 模拟过期 token(通过创建一个过期很久的 token) + service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 // 恢复 + + // 验证过期 token 应返回 claims 和 ErrTokenExpired + claims, err = service.ValidateToken(expiredToken) + require.ErrorIs(t, err, ErrTokenExpired) + require.NotNil(t, claims, "claims should not be nil when token is expired") + require.Equal(t, int64(1), claims.UserID) + require.Equal(t, "test@test.com", claims.Email) +} + +func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + repo := &userRepoStub{user: user} + service := newAuthService(repo, nil, nil) + + // 创建过期 token + service.cfg.JWT.ExpireHour = -1 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 + + // RefreshToken 使用过期 token 不应 panic + require.NotPanics(t, func() { + newToken, err := service.RefreshToken(context.Background(), expiredToken) + require.NoError(t, err) + require.NotEmpty(t, newToken) + }) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 9c61ea2e..df34e167 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -105,7 +105,17 @@ const ( // Request identity patch (Claude -> Gemini systemInstruction injection) SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyIdentityPatchPrompt = "identity_patch_prompt" + + // LinuxDo Connect OAuth 登录(终端用户 SSO) + SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled" + SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id" + SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" + SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" ) +// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 +// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。 +const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" + // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). const AdminAPIKeyPrefix = "admin-" diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index afd8907c..55e137d6 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "log" "math/big" "net/smtp" "strconv" @@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 if data.Code != code { data.Attempts++ - _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) + if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + log.Printf("[Email] Failed to update verification attempt count: %v", err) + } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } @@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error } // 验证成功,删除验证码 - _ = s.cache.DeleteVerificationCode(ctx, email) + if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { + log.Printf("[Email] Failed to delete verification code after success: %v", err) + } return nil } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 6007bce8..1a8b005f 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -166,7 +166,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([ func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 48d31da9..bc84baeb 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 } // OAuth client selection: - // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client. - // - ai_studio: requires a user-provided OAuth client. + // - code_assist: always use built-in Gemini CLI OAuth client (public) + // - google_one: always use built-in Gemini CLI OAuth client (public) + // - ai_studio: requires a user-provided OAuth client oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfg.ClientID = "" oauthCfg.ClientSecret = "" } @@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch case "google_one": log.Printf("[GeminiOAuth] Processing google_one OAuth type") + + // Google One accounts use cloudaicompanion API, which requires a project_id. + // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. + if projectID == "" { + log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + var err error + projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) + } + log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + } + log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index eb3d86e6..5591eb39 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { wantProjectID: "", }, { - name: "google_one uses custom client when configured and redirects to localhost", + name: "google_one always forces built-in client even when custom client configured", cfg: &config.Config{ Gemini: config.GeminiConfig{ OAuth: config.GeminiOAuthConfig{ @@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }, }, oauthType: "google_one", - wantClientID: "custom-client-id", - wantRedirect: geminicli.AIStudioOAuthRedirectURI, - wantScope: geminicli.DefaultGoogleOneScopes, + wantClientID: geminicli.GeminiCLIOAuthClientID, + wantRedirect: geminicli.GeminiCLIRedirectURI, + wantScope: geminicli.DefaultCodeAssistScopes, wantProjectID: "", }, { diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 403636e8..a444556f 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -21,7 +21,7 @@ type GroupRepository interface { DeleteCascade(ctx context.Context, id int64) ([]int64, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 32d308f7..5bb7574a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // For OAuth accounts using ChatGPT internal API, add store: false + // For OAuth accounts using ChatGPT internal API: + // 1. Add store: false + // 2. Normalize input format for Codex API compatibility if account.Type == AccountTypeOAuth { reqBody["store"] = false bodyModified = true + + // Normalize input format: convert AI SDK multi-part content format to simplified format + // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} + // Codex API expects: {"content": "..."} + if normalizeInputForCodexAPI(reqBody) { + bodyModified = true + } } // Re-serialize body only if modified @@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } +// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format +// that the ChatGPT internal Codex API expects. +// +// AI SDK sends content as an array of typed objects: +// +// {"content": [{"type": "input_text", "text": "hello"}]} +// +// ChatGPT Codex API expects content as a simple string: +// +// {"content": "hello"} +// +// This function modifies reqBody in-place and returns true if any modification was made. +func normalizeInputForCodexAPI(reqBody map[string]any) bool { + input, ok := reqBody["input"] + if !ok { + return false + } + + // Handle case where input is a simple string (already compatible) + if _, isString := input.(string); isString { + return false + } + + // Handle case where input is an array of messages + inputArray, ok := input.([]any) + if !ok { + return false + } + + modified := false + for _, item := range inputArray { + message, ok := item.(map[string]any) + if !ok { + continue + } + + content, ok := message["content"] + if !ok { + continue + } + + // If content is already a string, no conversion needed + if _, isString := content.(string); isString { + continue + } + + // If content is an array (AI SDK format), convert to string + contentArray, ok := content.([]any) + if !ok { + continue + } + + // Extract text from content array + var textParts []string + for _, part := range contentArray { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + + // Handle different content types + partType, _ := partMap["type"].(string) + switch partType { + case "input_text", "text": + // Extract text from input_text or text type + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + case "input_image", "image": + // For images, we need to preserve the original format + // as ChatGPT Codex API may support images in a different way + // For now, skip image parts (they will be lost in conversion) + // TODO: Consider preserving image data or handling it separately + continue + case "input_file", "file": + // Similar to images, file inputs may need special handling + continue + default: + // For unknown types, try to extract text if available + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + } + } + + // Convert content array to string + if len(textParts) > 0 { + message["content"] = strings.Join(textParts, "\n") + modified = true + } + } + + return modified +} + // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 6ce8ba2b..d25698de 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyAPIBaseURL, SettingKeyContactInfo, SettingKeyDocURL, + SettingKeyLinuxDoConnectEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return nil, fmt.Errorf("get public settings: %w", err) } + linuxDoEnabled := false + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + linuxDoEnabled = raw == "true" + } else { + linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled + } + return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", @@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings APIBaseURL: settings[SettingKeyAPIBaseURL], ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], + LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey } + // LinuxDo Connect OAuth 登录(终端用户 SSO) + updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) + updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID + updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL + if settings.LinuxDoConnectClientSecret != "" { + updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err != nil { - // 默认开放注册 - return true + // 安全默认:如果设置不存在或查询出错,默认关闭注册 + return false } return value == "true" } @@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.SMTPPassword = settings[SettingKeySMTPPassword] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] + // LinuxDo Connect 设置: + // - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭) + // - 支持在后台“系统设置”中覆盖并持久化(存储于 DB) + linuxDoBase := config.LinuxDoConnectConfig{} + if s.cfg != nil { + linuxDoBase = s.cfg.LinuxDo + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + result.LinuxDoConnectEnabled = raw == "true" + } else { + result.LinuxDoConnectEnabled = linuxDoBase.Enabled + } + + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectClientID = strings.TrimSpace(v) + } else { + result.LinuxDoConnectClientID = linuxDoBase.ClientID + } + + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectRedirectURL = strings.TrimSpace(v) + } else { + result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL + } + + result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret]) + if result.LinuxDoConnectClientSecret == "" { + result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret) + } + result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin return result } +// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。 +// +// 优先级: +// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值 +// - 否则回退到 config.yaml/env 的值 +func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if s == nil || s.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + + effective := s.cfg.LinuxDo + + keys := []string{ + SettingKeyLinuxDoConnectEnabled, + SettingKeyLinuxDoConnectClientID, + SettingKeyLinuxDoConnectClientSecret, + SettingKeyLinuxDoConnectRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err) + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + effective.Enabled = raw == "true" + } + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + effective.ClientID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" { + effective.ClientSecret = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.RedirectURL = strings.TrimSpace(v) + } + + if !effective.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + + // 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。 + if strings.TrimSpace(effective.ClientID) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured") + } + if strings.TrimSpace(effective.AuthorizeURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured") + } + if strings.TrimSpace(effective.TokenURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured") + } + if strings.TrimSpace(effective.UserInfoURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured") + } + if strings.TrimSpace(effective.RedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured") + } + if strings.TrimSpace(effective.FrontendRedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured") + } + + if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid") + } + + method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic": + if strings.TrimSpace(effective.ClientSecret) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") + } + case "none": + if !effective.UsePKCE { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") + } + default: + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") + } + + return effective, nil +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index de0331f7..26051418 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -18,6 +18,13 @@ type SystemSettings struct { TurnstileSecretKey string TurnstileSecretKeyConfigured bool + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool + LinuxDoConnectClientID string + LinuxDoConnectClientSecret string + LinuxDoConnectClientSecretConfigured bool + LinuxDoConnectRedirectURL string + SiteName string SiteLogo string SiteSubtitle string @@ -51,5 +58,6 @@ type PublicSettings struct { APIBaseURL string ContactInfo string DocURL string + LinuxDoOAuthEnabled bool Version string } diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 49bf0afa..936f0ea4 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -234,6 +234,31 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# LinuxDo Connect OAuth Login (SSO) +# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +linuxdo_connect: + enabled: false + client_id: "" + client_secret: "" + authorize_url: "https://connect.linux.do/oauth2/authorize" + token_url: "https://connect.linux.do/oauth2/token" + userinfo_url: "https://connect.linux.do/api/user" + scopes: "user" + # 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/linuxdo/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 6a370e9a..484df3a8 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -173,11 +173,12 @@ services: volumes: - redis_data:/data command: > - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' environment: - TZ=${TZ:-Asia/Shanghai} # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 23db9104..44eebc99 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -16,7 +16,7 @@ import type { * List all groups with pagination * @param page - Page number (default: 1) * @param pageSize - Items per page (default: 20) - * @param filters - Optional filters (platform, status, is_exclusive) + * @param filters - Optional filters (platform, status, is_exclusive, search) * @returns Paginated list of groups */ export async function list( @@ -26,6 +26,7 @@ export async function list( platform?: GroupPlatform status?: 'active' | 'inactive' is_exclusive?: boolean + search?: string }, options?: { signal?: AbortSignal diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 6b46de7d..2f6991e7 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -34,6 +34,11 @@ export interface SystemSettings { turnstile_enabled: boolean turnstile_site_key: string turnstile_secret_key_configured: boolean + // LinuxDo Connect OAuth 登录(终端用户 SSO) + linuxdo_connect_enabled: boolean + linuxdo_connect_client_id: string + linuxdo_connect_client_secret_configured: boolean + linuxdo_connect_redirect_url: string // Identity patch configuration (Claude -> Gemini) enable_identity_patch: boolean identity_patch_prompt: string @@ -60,6 +65,10 @@ export interface UpdateSettingsRequest { turnstile_enabled?: boolean turnstile_site_key?: string turnstile_secret_key?: string + linuxdo_connect_enabled?: boolean + linuxdo_connect_client_id?: string + linuxdo_connect_client_secret?: string + linuxdo_connect_redirect_url?: string enable_identity_patch?: boolean identity_patch_prompt?: string } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e90bec6c..5833632b 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -166,7 +166,7 @@ >
+ {{ t('admin.settings.linuxdo.description') }} +
++ {{ t('admin.settings.linuxdo.enableHint') }} +
++ {{ t('admin.settings.linuxdo.clientIdHint') }} +
++ {{ + form.linuxdo_connect_client_secret_configured + ? t('admin.settings.linuxdo.clientSecretConfiguredHint') + : t('admin.settings.linuxdo.clientSecretHint') + }} +
+
+ {{ linuxdoRedirectUrlSuggestion }}
+
+ + {{ t('admin.settings.linuxdo.redirectUrlHint') }} +
++ {{ isProcessing ? t('auth.linuxdo.callbackProcessing') : t('auth.linuxdo.callbackHint') }} +
++ {{ errorMessage }} +
+