diff --git a/Linux DO Connect.md b/Linux DO Connect.md new file mode 100644 index 00000000..7ca1260f --- /dev/null +++ b/Linux DO Connect.md @@ -0,0 +1,368 @@ +# Linux DO Connect + +OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 + +目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 + +## 基本介绍 + +这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 + +- 可获取字段: + +| 参数 | 说明 | +| ----------------- | ------------------------------- | +| `id` | 用户唯一标识(不可变) | +| `username` | 论坛用户名 | +| `name` | 论坛用户昵称(可变) | +| `avatar_template` | 用户头像模板URL(支持多种尺寸) | +| `active` | 账号活跃状态 | +| `trust_level` | 信任等级(0-4) | +| `silenced` | 禁言状态 | +| `external_ids` | 外部ID关联信息 | +| `api_key` | API访问密钥 | + +通过这些信息,公益网站/接口可以实现: + +1. 基于 `id` 的服务频率限制 +2. 基于 `trust_level` 的服务额度分配 +3. 基于用户信息的滥用举报机制 + +## 相关端点 + +- Authorize 端点: `https://connect.linux.do/oauth2/authorize` +- Token 端点:`https://connect.linux.do/oauth2/token` +- 用户信息 端点:`https://connect.linux.do/api/user` + +## 申请使用 + +- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 + + + +- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 + + + +- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 + + + +## 接入 Linux Do + +JavaScript +```JavaScript +// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios +// npm install axios + +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +const axios = require('axios'); +const readline = require('readline'); + +// 配置信息(建议通过环境变量配置,避免使用硬编码) +const CLIENT_ID = '你的 Client ID'; +const CLIENT_SECRET = '你的 Client Secret'; +const REDIRECT_URI = '你的回调地址'; +const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +const USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 第一步:生成授权 URL +function getAuthUrl() { + const params = new URLSearchParams({ + client_id: CLIENT_ID, + redirect_uri: REDIRECT_URI, + response_type: 'code', + scope: 'user' + }); + + return `${AUTH_URL}?${params.toString()}`; +} + +// 第二步:获取 code 参数 +function getCode() { + return new Promise((resolve) => { + // 本例中使用终端输入来模拟流程,仅供本地测试 + // 请在实际应用中替换为真实的处理逻辑 + const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); + rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { + rl.close(); + resolve(answer.trim()); + }); + }); +} + +// 第三步:使用 code 参数获取访问令牌 +async function getAccessToken(code) { + try { + const form = new URLSearchParams({ + client_id: CLIENT_ID, + client_secret: CLIENT_SECRET, + code: code, + redirect_uri: REDIRECT_URI, + grant_type: 'authorization_code' + }).toString(); + + const response = await axios.post(TOKEN_URL, form, { + // 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers: { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + }); + + return response.data; + } catch (error) { + console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 第四步:使用访问令牌获取用户信息 +async function getUserInfo(accessToken) { + try { + const response = await axios.get(USER_INFO_URL, { + headers: { + Authorization: `Bearer ${accessToken}` + } + }); + + return response.data; + } catch (error) { + console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); + throw error; + } +} + +// 主流程 +async function main() { + // 1. 生成授权 URL,前端引导用户访问授权页 + const authUrl = getAuthUrl(); + console.log(`请访问此 URL 授权:${authUrl} +`); + + // 2. 用户授权后,从回调 URL 获取 code 参数 + const code = await getCode(); + + try { + // 3. 使用 code 参数获取访问令牌 + const tokenData = await getAccessToken(code); + const accessToken = tokenData.access_token; + + // 4. 使用访问令牌获取用户信息 + if (accessToken) { + const userInfo = await getUserInfo(accessToken); + console.log(` +获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); + } else { + console.log(` +获取访问令牌失败:${JSON.stringify(tokenData)}`); + } + } catch (error) { + console.error('发生错误:', error); + } +} +``` +Python +```python +# 安装第三方请求库,本例中使用 requests +# pip install requests + +# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 +import requests +import json + +# 配置信息(建议通过环境变量配置,避免使用硬编码) +CLIENT_ID = '你的 Client ID' +CLIENT_SECRET = '你的 Client Secret' +REDIRECT_URI = '你的回调地址' +AUTH_URL = 'https://connect.linux.do/oauth2/authorize' +TOKEN_URL = 'https://connect.linux.do/oauth2/token' +USER_INFO_URL = 'https://connect.linux.do/api/user' + +# 第一步:生成授权 URL +def get_auth_url(): + params = { + 'client_id': CLIENT_ID, + 'redirect_uri': REDIRECT_URI, + 'response_type': 'code', + 'scope': 'user' + } + auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" + return auth_url + +# 第二步:获取 code 参数 +def get_code(): + # 本例中使用终端输入来模拟流程,仅供本地测试 + # 请在实际应用中替换为真实的处理逻辑 + return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() + +# 第三步:使用 code 参数获取访问令牌 +def get_access_token(code): + try: + data = { + 'client_id': CLIENT_ID, + 'client_secret': CLIENT_SECRET, + 'code': code, + 'redirect_uri': REDIRECT_URI, + 'grant_type': 'authorization_code' + } + # 提醒:需正确配置请求头,否则无法正常获取访问令牌 + headers = { + 'Content-Type': 'application/x-www-form-urlencoded', + 'Accept': 'application/json' + } + response = requests.post(TOKEN_URL, data=data, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取访问令牌失败:{e}") + return None + +# 第四步:使用访问令牌获取用户信息 +def get_user_info(access_token): + try: + headers = { + 'Authorization': f'Bearer {access_token}' + } + response = requests.get(USER_INFO_URL, headers=headers) + response.raise_for_status() + return response.json() + except requests.exceptions.RequestException as e: + print(f"获取用户信息失败:{e}") + return None + +# 主流程 +if __name__ == '__main__': + # 1. 生成授权 URL,前端引导用户访问授权页 + auth_url = get_auth_url() + print(f'请访问此 URL 授权:{auth_url} +') + + # 2. 用户授权后,从回调 URL 获取 code 参数 + code = get_code() + + # 3. 使用 code 参数获取访问令牌 + token_data = get_access_token(code) + if token_data: + access_token = token_data.get('access_token') + + # 4. 使用访问令牌获取用户信息 + if access_token: + user_info = get_user_info(access_token) + if user_info: + print(f" +获取用户信息成功:{json.dumps(user_info, indent=2)}") + else: + print(" +获取用户信息失败") + else: + print(f" +获取访问令牌失败:{json.dumps(token_data, indent=2)}") + else: + print(" +获取访问令牌失败") +``` +PHP +```php +// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 + +// 配置信息 +$CLIENT_ID = '你的 Client ID'; +$CLIENT_SECRET = '你的 Client Secret'; +$REDIRECT_URI = '你的回调地址'; +$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; +$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; +$USER_INFO_URL = 'https://connect.linux.do/api/user'; + +// 生成授权 URL +function getAuthUrl($clientId, $redirectUri) { + global $AUTH_URL; + return $AUTH_URL . '?' . http_build_query([ + 'client_id' => $clientId, + 'redirect_uri' => $redirectUri, + 'response_type' => 'code', + 'scope' => 'user' + ]); +} + +// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) +function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { + global $TOKEN_URL, $USER_INFO_URL; + + // 1. 获取访问令牌 + $ch = curl_init($TOKEN_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_POST, true); + curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ + 'client_id' => $clientId, + 'client_secret' => $clientSecret, + 'code' => $code, + 'redirect_uri' => $redirectUri, + 'grant_type' => 'authorization_code' + ])); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Content-Type: application/x-www-form-urlencoded', + 'Accept: application/json' + ]); + + $tokenResponse = curl_exec($ch); + curl_close($ch); + + $tokenData = json_decode($tokenResponse, true); + if (!isset($tokenData['access_token'])) { + return ['error' => '获取访问令牌失败', 'details' => $tokenData]; + } + + // 2. 获取用户信息 + $ch = curl_init($USER_INFO_URL); + curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); + curl_setopt($ch, CURLOPT_HTTPHEADER, [ + 'Authorization: Bearer ' . $tokenData['access_token'] + ]); + + $userResponse = curl_exec($ch); + curl_close($ch); + + return json_decode($userResponse, true); +} + +// 主流程 +// 1. 生成授权 URL +$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); +echo "使用 Linux Do 登录"; + +// 2. 处理回调并获取用户信息 +if (isset($_GET['code'])) { + $userInfo = getUserInfoWithCode( + $_GET['code'], + $CLIENT_ID, + $CLIENT_SECRET, + $REDIRECT_URI + ); + + if (isset($userInfo['error'])) { + echo '错误: ' . $userInfo['error']; + } else { + echo '欢迎, ' . $userInfo['name'] . '!'; + // 处理用户登录逻辑... + } +} +``` + +## 使用说明 + +### 授权流程 + +1. 用户点击应用中的’使用 Linux Do 登录’按钮 +2. 系统将用户重定向至 Linux Do 的授权页面 +3. 用户完成授权后,系统自动重定向回应用并携带授权码 +4. 应用使用授权码获取访问令牌 +5. 使用访问令牌获取用户信息 + +### 安全建议 + +- 切勿在前端代码中暴露 Client Secret +- 对所有用户输入数据进行严格验证 +- 确保使用 HTTPS 协议传输数据 +- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 17e51c38..79e0dd8a 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.1 +0.1.46 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 31dc3682..85bed3f3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -53,7 +53,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { emailQueueService := service.ProvideEmailQueueService(emailService) authService := service.NewAuthService(userRepository, configConfig, settingService, emailService, turnstileService, emailQueueService) userService := service.NewUserService(userRepository) - authHandler := handler.NewAuthHandler(configConfig, authService, userService) + authHandler := handler.NewAuthHandler(configConfig, authService, userService, settingService) userHandler := handler.NewUserHandler(userService) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) diff --git a/backend/ent/group.go b/backend/ent/group.go index dca64cec..4a31442a 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -51,6 +51,10 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // 是否仅允许 Claude Code 客户端 + ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` + // 非 Claude Code 请求降级使用的分组 ID + FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -157,11 +161,11 @@ func (*Group) scanValues(columns []string) ([]any, error) { values := make([]any, len(columns)) for i := range columns { switch columns[i] { - case group.FieldIsExclusive: + case group.FieldIsExclusive, group.FieldClaudeCodeOnly: values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -298,6 +302,19 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldClaudeCodeOnly: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) + } else if value.Valid { + _m.ClaudeCodeOnly = value.Bool + } + case group.FieldFallbackGroupID: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field fallback_group_id", values[i]) + } else if value.Valid { + _m.FallbackGroupID = new(int64) + *_m.FallbackGroupID = value.Int64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -440,6 +457,14 @@ func (_m *Group) String() string { builder.WriteString("image_price_4k=") builder.WriteString(fmt.Sprintf("%v", *v)) } + builder.WriteString(", ") + builder.WriteString("claude_code_only=") + builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) + builder.WriteString(", ") + if v := _m.FallbackGroupID; v != nil { + builder.WriteString("fallback_group_id=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 1c5ed343..c4317f00 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,10 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. FieldImagePrice4k = "image_price_4k" + // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. + FieldClaudeCodeOnly = "claude_code_only" + // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. + FieldFallbackGroupID = "fallback_group_id" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -141,6 +145,8 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldClaudeCodeOnly, + FieldFallbackGroupID, } var ( @@ -196,6 +202,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. + DefaultClaudeCodeOnly bool ) // OrderOption defines the ordering options for the Group queries. @@ -291,6 +299,16 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// ByClaudeCodeOnly orders the results by the claude_code_only field. +func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() +} + +// ByFallbackGroupID orders the results by the fallback_group_id field. +func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 7bce1fe6..fb2f942f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,16 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. +func ClaudeCodeOnly(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) +} + +// FallbackGroupID applies equality check predicate on the "fallback_group_id" field. It's identical to FallbackGroupIDEQ. +func FallbackGroupID(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -995,6 +1005,66 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. +func ClaudeCodeOnlyEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) +} + +// ClaudeCodeOnlyNEQ applies the NEQ predicate on the "claude_code_only" field. +func ClaudeCodeOnlyNEQ(v bool) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldClaudeCodeOnly, v)) +} + +// FallbackGroupIDEQ applies the EQ predicate on the "fallback_group_id" field. +func FallbackGroupIDEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDNEQ applies the NEQ predicate on the "fallback_group_id" field. +func FallbackGroupIDNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDIn applies the In predicate on the "fallback_group_id" field. +func FallbackGroupIDIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldFallbackGroupID, vs...)) +} + +// FallbackGroupIDNotIn applies the NotIn predicate on the "fallback_group_id" field. +func FallbackGroupIDNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldFallbackGroupID, vs...)) +} + +// FallbackGroupIDGT applies the GT predicate on the "fallback_group_id" field. +func FallbackGroupIDGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDGTE applies the GTE predicate on the "fallback_group_id" field. +func FallbackGroupIDGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDLT applies the LT predicate on the "fallback_group_id" field. +func FallbackGroupIDLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDLTE applies the LTE predicate on the "fallback_group_id" field. +func FallbackGroupIDLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldFallbackGroupID, v)) +} + +// FallbackGroupIDIsNil applies the IsNil predicate on the "fallback_group_id" field. +func FallbackGroupIDIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldFallbackGroupID)) +} + +// FallbackGroupIDNotNil applies the NotNil predicate on the "fallback_group_id" field. +func FallbackGroupIDNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 6a928af6..59229402 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,34 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { + _c.mutation.SetClaudeCodeOnly(v) + return _c +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_c *GroupCreate) SetNillableClaudeCodeOnly(v *bool) *GroupCreate { + if v != nil { + _c.SetClaudeCodeOnly(*v) + } + return _c +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_c *GroupCreate) SetFallbackGroupID(v int64) *GroupCreate { + _c.mutation.SetFallbackGroupID(v) + return _c +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate { + if v != nil { + _c.SetFallbackGroupID(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -423,6 +451,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { + v := group.DefaultClaudeCodeOnly + _c.mutation.SetClaudeCodeOnly(v) + } return nil } @@ -475,6 +507,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { + return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)} + } return nil } @@ -570,6 +605,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + _node.ClaudeCodeOnly = value + } + if value, ok := _c.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + _node.FallbackGroupID = &value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1014,6 +1057,42 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { + u.Set(group.FieldClaudeCodeOnly, v) + return u +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsert) UpdateClaudeCodeOnly() *GroupUpsert { + u.SetExcluded(group.FieldClaudeCodeOnly) + return u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsert) SetFallbackGroupID(v int64) *GroupUpsert { + u.Set(group.FieldFallbackGroupID, v) + return u +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsert) UpdateFallbackGroupID() *GroupUpsert { + u.SetExcluded(group.FieldFallbackGroupID) + return u +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsert) AddFallbackGroupID(v int64) *GroupUpsert { + u.Add(group.FieldFallbackGroupID, v) + return u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert { + u.SetNull(group.FieldFallbackGroupID) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1395,6 +1474,48 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetClaudeCodeOnly(v) + }) +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateClaudeCodeOnly() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateClaudeCodeOnly() + }) +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsertOne) SetFallbackGroupID(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupID(v) + }) +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsertOne) AddFallbackGroupID(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupID(v) + }) +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateFallbackGroupID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupID() + }) +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupID() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1942,6 +2063,48 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetClaudeCodeOnly(v) + }) +} + +// UpdateClaudeCodeOnly sets the "claude_code_only" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateClaudeCodeOnly() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateClaudeCodeOnly() + }) +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (u *GroupUpsertBulk) SetFallbackGroupID(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetFallbackGroupID(v) + }) +} + +// AddFallbackGroupID adds v to the "fallback_group_id" field. +func (u *GroupUpsertBulk) AddFallbackGroupID(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddFallbackGroupID(v) + }) +} + +// UpdateFallbackGroupID sets the "fallback_group_id" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateFallbackGroupID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateFallbackGroupID() + }) +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearFallbackGroupID() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 43555ce2..1a6f15ec 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -354,6 +354,47 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { + _u.mutation.SetClaudeCodeOnly(v) + return _u +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableClaudeCodeOnly(v *bool) *GroupUpdate { + if v != nil { + _u.SetClaudeCodeOnly(*v) + } + return _u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_u *GroupUpdate) SetFallbackGroupID(v int64) *GroupUpdate { + _u.mutation.ResetFallbackGroupID() + _u.mutation.SetFallbackGroupID(v) + return _u +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableFallbackGroupID(v *int64) *GroupUpdate { + if v != nil { + _u.SetFallbackGroupID(*v) + } + return _u +} + +// AddFallbackGroupID adds value to the "fallback_group_id" field. +func (_u *GroupUpdate) AddFallbackGroupID(v int64) *GroupUpdate { + _u.mutation.AddFallbackGroupID(v) + return _u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate { + _u.mutation.ClearFallbackGroupID() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -750,6 +791,18 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupID(); ok { + _spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDCleared() { + _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1384,6 +1437,47 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { + _u.mutation.SetClaudeCodeOnly(v) + return _u +} + +// SetNillableClaudeCodeOnly sets the "claude_code_only" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableClaudeCodeOnly(v *bool) *GroupUpdateOne { + if v != nil { + _u.SetClaudeCodeOnly(*v) + } + return _u +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (_u *GroupUpdateOne) SetFallbackGroupID(v int64) *GroupUpdateOne { + _u.mutation.ResetFallbackGroupID() + _u.mutation.SetFallbackGroupID(v) + return _u +} + +// SetNillableFallbackGroupID sets the "fallback_group_id" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableFallbackGroupID(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetFallbackGroupID(*v) + } + return _u +} + +// AddFallbackGroupID adds value to the "fallback_group_id" field. +func (_u *GroupUpdateOne) AddFallbackGroupID(v int64) *GroupUpdateOne { + _u.mutation.AddFallbackGroupID(v) + return _u +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne { + _u.mutation.ClearFallbackGroupID() + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1810,6 +1904,18 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.ClaudeCodeOnly(); ok { + _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) + } + if value, ok := _u.mutation.FallbackGroupID(); ok { + _spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedFallbackGroupID(); ok { + _spec.AddField(group.FieldFallbackGroupID, field.TypeInt64, value) + } + if _u.mutation.FallbackGroupIDCleared() { + _spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index e48201f3..13081e31 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -221,6 +221,8 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "claude_code_only", Type: field.TypeBool, Default: false}, + {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index a809e858..4e01e12b 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -3590,6 +3590,9 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 + claude_code_only *bool + fallback_group_id *int64 + addfallback_group_id *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -4594,6 +4597,112 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetClaudeCodeOnly sets the "claude_code_only" field. +func (m *GroupMutation) SetClaudeCodeOnly(b bool) { + m.claude_code_only = &b +} + +// ClaudeCodeOnly returns the value of the "claude_code_only" field in the mutation. +func (m *GroupMutation) ClaudeCodeOnly() (r bool, exists bool) { + v := m.claude_code_only + if v == nil { + return + } + return *v, true +} + +// OldClaudeCodeOnly returns the old "claude_code_only" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldClaudeCodeOnly(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldClaudeCodeOnly is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldClaudeCodeOnly requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldClaudeCodeOnly: %w", err) + } + return oldValue.ClaudeCodeOnly, nil +} + +// ResetClaudeCodeOnly resets all changes to the "claude_code_only" field. +func (m *GroupMutation) ResetClaudeCodeOnly() { + m.claude_code_only = nil +} + +// SetFallbackGroupID sets the "fallback_group_id" field. +func (m *GroupMutation) SetFallbackGroupID(i int64) { + m.fallback_group_id = &i + m.addfallback_group_id = nil +} + +// FallbackGroupID returns the value of the "fallback_group_id" field in the mutation. +func (m *GroupMutation) FallbackGroupID() (r int64, exists bool) { + v := m.fallback_group_id + if v == nil { + return + } + return *v, true +} + +// OldFallbackGroupID returns the old "fallback_group_id" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldFallbackGroupID(ctx context.Context) (v *int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldFallbackGroupID is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldFallbackGroupID requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldFallbackGroupID: %w", err) + } + return oldValue.FallbackGroupID, nil +} + +// AddFallbackGroupID adds i to the "fallback_group_id" field. +func (m *GroupMutation) AddFallbackGroupID(i int64) { + if m.addfallback_group_id != nil { + *m.addfallback_group_id += i + } else { + m.addfallback_group_id = &i + } +} + +// AddedFallbackGroupID returns the value that was added to the "fallback_group_id" field in this mutation. +func (m *GroupMutation) AddedFallbackGroupID() (r int64, exists bool) { + v := m.addfallback_group_id + if v == nil { + return + } + return *v, true +} + +// ClearFallbackGroupID clears the value of the "fallback_group_id" field. +func (m *GroupMutation) ClearFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + m.clearedFields[group.FieldFallbackGroupID] = struct{}{} +} + +// FallbackGroupIDCleared returns if the "fallback_group_id" field was cleared in this mutation. +func (m *GroupMutation) FallbackGroupIDCleared() bool { + _, ok := m.clearedFields[group.FieldFallbackGroupID] + return ok +} + +// ResetFallbackGroupID resets all changes to the "fallback_group_id" field. +func (m *GroupMutation) ResetFallbackGroupID() { + m.fallback_group_id = nil + m.addfallback_group_id = nil + delete(m.clearedFields, group.FieldFallbackGroupID) +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -4952,7 +5061,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 17) + fields := make([]string, 0, 19) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -5004,6 +5113,12 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.claude_code_only != nil { + fields = append(fields, group.FieldClaudeCodeOnly) + } + if m.fallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5046,6 +5161,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldClaudeCodeOnly: + return m.ClaudeCodeOnly() + case group.FieldFallbackGroupID: + return m.FallbackGroupID() } return nil, false } @@ -5089,6 +5208,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldClaudeCodeOnly: + return m.OldClaudeCodeOnly(ctx) + case group.FieldFallbackGroupID: + return m.OldFallbackGroupID(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -5217,6 +5340,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldClaudeCodeOnly: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetClaudeCodeOnly(v) + return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetFallbackGroupID(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -5249,6 +5386,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addfallback_group_id != nil { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5273,6 +5413,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldFallbackGroupID: + return m.AddedFallbackGroupID() } return nil, false } @@ -5338,6 +5480,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldFallbackGroupID: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddFallbackGroupID(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -5370,6 +5519,9 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if m.FieldCleared(group.FieldFallbackGroupID) { + fields = append(fields, group.FieldFallbackGroupID) + } return fields } @@ -5408,6 +5560,9 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldFallbackGroupID: + m.ClearFallbackGroupID() + return nil } return fmt.Errorf("unknown Group nullable field %s", name) } @@ -5467,6 +5622,12 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldClaudeCodeOnly: + m.ResetClaudeCodeOnly() + return nil + case group.FieldFallbackGroupID: + m.ResetFallbackGroupID() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 6ccfc6d2..fb1c948c 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -270,6 +270,10 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. + groupDescClaudeCodeOnly := groupFields[14].Descriptor() + // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. + group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) proxyMixin := schema.Proxy{}.Mixin() proxyMixinHooks1 := proxyMixin[1].Hooks() proxy.Hooks[0] = proxyMixinHooks1[0] diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 7b5f77b1..d38925b1 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -86,6 +86,15 @@ func (Group) Fields() []ent.Field { Optional(). Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + + // Claude Code 客户端限制 (added by migration 029) + field.Bool("claude_code_only"). + Default(false). + Comment("是否仅允许 Claude Code 客户端"), + field.Int64("fallback_group_id"). + Optional(). + Nillable(). + Comment("非 Claude Code 请求降级使用的分组 ID"), } } @@ -101,6 +110,8 @@ func (Group) Edges() []ent.Edge { edge.From("allowed_users", User.Type). Ref("allowed_groups"). Through("user_allowed_groups", UserAllowedGroup.Type), + // 注意:fallback_group_id 直接作为字段使用,不定义 edge + // 这样允许多个分组指向同一个降级分组(M2O 关系) } } diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d13a460a..2cc11967 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -6,6 +6,7 @@ import ( "encoding/hex" "fmt" "log" + "net/url" "os" "strings" "time" @@ -35,24 +36,25 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - JWT JWTConfig `mapstructure:"jwt"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + JWT JWTConfig `mapstructure:"jwt"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` } // UpdateConfig 在线更新相关配置 @@ -322,6 +324,30 @@ type TurnstileConfig struct { Required bool `mapstructure:"required"` } +// LinuxDoConnectConfig 用于 LinuxDo Connect OAuth 登录(终端用户 SSO)。 +// +// 注意:这与上游账号的 OAuth(例如 OpenAI/Gemini 账号接入)不是一回事。 +// 这里是用于登录 Sub2API 本身的用户体系。 +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback) + TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none + UsePKCE bool `mapstructure:"use_pkce"` + + // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 + // 为空时,服务端会尝试一组常见字段名。 + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + type DefaultConfig struct { AdminEmail string `mapstructure:"admin_email"` AdminPassword string `mapstructure:"admin_password"` @@ -388,6 +414,18 @@ func Load() (*Config, error) { cfg.Server.Mode = "debug" } cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) + cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) + cfg.LinuxDo.AuthorizeURL = strings.TrimSpace(cfg.LinuxDo.AuthorizeURL) + cfg.LinuxDo.TokenURL = strings.TrimSpace(cfg.LinuxDo.TokenURL) + cfg.LinuxDo.UserInfoURL = strings.TrimSpace(cfg.LinuxDo.UserInfoURL) + cfg.LinuxDo.Scopes = strings.TrimSpace(cfg.LinuxDo.Scopes) + cfg.LinuxDo.RedirectURL = strings.TrimSpace(cfg.LinuxDo.RedirectURL) + cfg.LinuxDo.FrontendRedirectURL = strings.TrimSpace(cfg.LinuxDo.FrontendRedirectURL) + cfg.LinuxDo.TokenAuthMethod = strings.ToLower(strings.TrimSpace(cfg.LinuxDo.TokenAuthMethod)) + cfg.LinuxDo.UserInfoEmailPath = strings.TrimSpace(cfg.LinuxDo.UserInfoEmailPath) + cfg.LinuxDo.UserInfoIDPath = strings.TrimSpace(cfg.LinuxDo.UserInfoIDPath) + cfg.LinuxDo.UserInfoUsernamePath = strings.TrimSpace(cfg.LinuxDo.UserInfoUsernamePath) cfg.CORS.AllowedOrigins = normalizeStringSlice(cfg.CORS.AllowedOrigins) cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) @@ -426,6 +464,81 @@ func Load() (*Config, error) { return &cfg, nil } +// ValidateAbsoluteHTTPURL 校验一个绝对 http(s) URL(禁止 fragment)。 +func ValidateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +// ValidateFrontendRedirectURL 校验前端回调地址: +// - 允许同源相对路径(以 / 开头) +// - 或绝对 http(s) URL(禁止 fragment) +func ValidateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { + return fmt.Errorf("empty url") + } + if strings.ContainsAny(raw, "\r\n") { + return fmt.Errorf("contains invalid characters") + } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { + return fmt.Errorf("must not start with //") + } + return nil + } + u, err := url.Parse(raw) + if err != nil { + return err + } + if !u.IsAbs() { + return fmt.Errorf("must be absolute http(s) url or relative path") + } + if !isHTTPScheme(u.Scheme) { + return fmt.Errorf("unsupported scheme: %s", u.Scheme) + } + if strings.TrimSpace(u.Host) == "" { + return fmt.Errorf("missing host") + } + if u.Fragment != "" { + return fmt.Errorf("must not include fragment") + } + return nil +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { + return + } + if strings.EqualFold(u.Scheme, "http") { + log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + } +} + func setDefaults() { viper.SetDefault("run_mode", RunModeStandard) @@ -475,6 +588,22 @@ func setDefaults() { // Turnstile viper.SetDefault("turnstile.required", false) + // LinuxDo Connect OAuth 登录(终端用户 SSO) + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viper.SetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") + // Database viper.SetDefault("database.host", "localhost") viper.SetDefault("database.port", 5432) @@ -586,6 +715,60 @@ func (c *Config) Validate() error { if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } + if c.LinuxDo.Enabled { + if strings.TrimSpace(c.LinuxDo.ClientID) == "" { + return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { + return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { + return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { + return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") + } + if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") + } + method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !c.LinuxDo.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") + } + + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { + return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { + return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { + return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) + } + if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) + } + if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + + warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) + warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) + warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) + warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) + } if c.Billing.CircuitBreaker.Enabled { if c.Billing.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f28680c6..a39d41f9 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -1,6 +1,7 @@ package config import ( + "strings" "testing" "time" @@ -90,3 +91,53 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { t.Fatalf("ResponseHeaders.Enabled = true, want false") } } + +func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "test-secret" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + cfg.LinuxDo.UsePKCE = false + + cfg.LinuxDo.FrontendRedirectURL = "javascript:alert(1)" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for javascript scheme, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.frontend_redirect_url") { + t.Fatalf("Validate() expected frontend_redirect_url error, got: %v", err) + } +} + +func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "test-client" + cfg.LinuxDo.ClientSecret = "" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "none" + cfg.LinuxDo.UsePKCE = false + + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error when token_auth_method=none and use_pkce=false, got nil") + } + if !strings.Contains(err.Error(), "linuxdo_connect.use_pkce") { + t.Fatalf("Validate() expected use_pkce error, got: %v", err) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index da9f6990..8a7270e5 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -116,6 +116,7 @@ type BulkUpdateAccountsRequest struct { Concurrency *int `json:"concurrency"` Priority *int `json:"priority"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` + Schedulable *bool `json:"schedulable"` GroupIDs *[]int64 `json:"group_ids"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` @@ -136,6 +137,11 @@ func (h *AccountHandler) List(c *gin.Context) { accountType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } accounts, total, err := h.adminService.ListAccounts(c.Request.Context(), page, pageSize, platform, accountType, status, search) if err != nil { @@ -655,6 +661,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.Concurrency != nil || req.Priority != nil || req.Status != "" || + req.Schedulable != nil || req.GroupIDs != nil || len(req.Credentials) > 0 || len(req.Extra) > 0 @@ -671,6 +678,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { Concurrency: req.Concurrency, Priority: req.Priority, Status: req.Status, + Schedulable: req.Schedulable, GroupIDs: req.GroupIDs, Credentials: req.Credentials, Extra: req.Extra, diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 182d26d0..a8bae35e 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -34,9 +35,11 @@ type CreateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` } // UpdateGroupRequest represents update group request @@ -52,9 +55,11 @@ type UpdateGroupRequest struct { WeeklyLimitUSD *float64 `json:"weekly_limit_usd"` MonthlyLimitUSD *float64 `json:"monthly_limit_usd"` // 图片生成计费配置(antigravity 和 gemini 平台使用,负数表示清除配置) - ImagePrice1K *float64 `json:"image_price_1k"` - ImagePrice2K *float64 `json:"image_price_2k"` - ImagePrice4K *float64 `json:"image_price_4k"` + ImagePrice1K *float64 `json:"image_price_1k"` + ImagePrice2K *float64 `json:"image_price_2k"` + ImagePrice4K *float64 `json:"image_price_4k"` + ClaudeCodeOnly *bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` } // List handles listing all groups with pagination @@ -63,6 +68,12 @@ func (h *GroupHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) platform := c.Query("platform") status := c.Query("status") + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } isExclusiveStr := c.Query("is_exclusive") var isExclusive *bool @@ -71,7 +82,7 @@ func (h *GroupHandler) List(c *gin.Context) { isExclusive = &val } - groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, isExclusive) + groups, total, err := h.adminService.ListGroups(c.Request.Context(), page, pageSize, platform, status, search, isExclusive) if err != nil { response.ErrorFrom(c, err) return @@ -150,6 +161,8 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, }) if err != nil { response.ErrorFrom(c, err) @@ -188,6 +201,8 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + ClaudeCodeOnly: req.ClaudeCodeOnly, + FallbackGroupID: req.FallbackGroupID, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index 99557f9a..437e9300 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -51,16 +51,21 @@ func (h *ProxyHandler) List(c *gin.Context) { protocol := c.Query("protocol") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } - proxies, total, err := h.adminService.ListProxies(c.Request.Context(), page, pageSize, protocol, status, search) + proxies, total, err := h.adminService.ListProxiesWithAccountCount(c.Request.Context(), page, pageSize, protocol, status, search) if err != nil { response.ErrorFrom(c, err) return } - out := make([]dto.Proxy, 0, len(proxies)) + out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyFromService(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) } response.Paginated(c, out, total, page, pageSize) } diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 45fae43a..5b3229b6 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -5,6 +5,7 @@ import ( "encoding/csv" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -41,6 +42,11 @@ func (h *RedeemHandler) List(c *gin.Context) { codeType := c.Query("type") status := c.Query("status") search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } codes, total, err := h.adminService.ListRedeemCodes(c.Request.Context(), page, pageSize, codeType, status, search) if err != nil { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 743c4268..d95a8980 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -2,8 +2,10 @@ package admin import ( "log" + "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -38,33 +40,37 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: settings.RegistrationEnabled, - EmailVerifyEnabled: settings.EmailVerifyEnabled, - SMTPHost: settings.SMTPHost, - SMTPPort: settings.SMTPPort, - SMTPUsername: settings.SMTPUsername, - SMTPPasswordConfigured: settings.SMTPPasswordConfigured, - SMTPFrom: settings.SMTPFrom, - SMTPFromName: settings.SMTPFromName, - SMTPUseTLS: settings.SMTPUseTLS, - TurnstileEnabled: settings.TurnstileEnabled, - TurnstileSiteKey: settings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, - SiteName: settings.SiteName, - SiteLogo: settings.SiteLogo, - SiteSubtitle: settings.SiteSubtitle, - APIBaseURL: settings.APIBaseURL, - ContactInfo: settings.ContactInfo, - DocURL: settings.DocURL, - DefaultConcurrency: settings.DefaultConcurrency, - DefaultBalance: settings.DefaultBalance, - EnableModelFallback: settings.EnableModelFallback, - FallbackModelAnthropic: settings.FallbackModelAnthropic, - FallbackModelOpenAI: settings.FallbackModelOpenAI, - FallbackModelGemini: settings.FallbackModelGemini, - FallbackModelAntigravity: settings.FallbackModelAntigravity, - EnableIdentityPatch: settings.EnableIdentityPatch, - IdentityPatchPrompt: settings.IdentityPatchPrompt, + RegistrationEnabled: settings.RegistrationEnabled, + EmailVerifyEnabled: settings.EmailVerifyEnabled, + SMTPHost: settings.SMTPHost, + SMTPPort: settings.SMTPPort, + SMTPUsername: settings.SMTPUsername, + SMTPPasswordConfigured: settings.SMTPPasswordConfigured, + SMTPFrom: settings.SMTPFrom, + SMTPFromName: settings.SMTPFromName, + SMTPUseTLS: settings.SMTPUseTLS, + TurnstileEnabled: settings.TurnstileEnabled, + TurnstileSiteKey: settings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: settings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: settings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: settings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: settings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: settings.LinuxDoConnectRedirectURL, + SiteName: settings.SiteName, + SiteLogo: settings.SiteLogo, + SiteSubtitle: settings.SiteSubtitle, + APIBaseURL: settings.APIBaseURL, + ContactInfo: settings.ContactInfo, + DocURL: settings.DocURL, + DefaultConcurrency: settings.DefaultConcurrency, + DefaultBalance: settings.DefaultBalance, + EnableModelFallback: settings.EnableModelFallback, + FallbackModelAnthropic: settings.FallbackModelAnthropic, + FallbackModelOpenAI: settings.FallbackModelOpenAI, + FallbackModelGemini: settings.FallbackModelGemini, + FallbackModelAntigravity: settings.FallbackModelAntigravity, + EnableIdentityPatch: settings.EnableIdentityPatch, + IdentityPatchPrompt: settings.IdentityPatchPrompt, }) } @@ -88,6 +94,12 @@ type UpdateSettingsRequest struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKey string `json:"turnstile_secret_key"` + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecret string `json:"linuxdo_connect_client_secret"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + // OEM设置 SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` @@ -165,34 +177,67 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } + // LinuxDo Connect 参数验证 + if req.LinuxDoConnectEnabled { + req.LinuxDoConnectClientID = strings.TrimSpace(req.LinuxDoConnectClientID) + req.LinuxDoConnectClientSecret = strings.TrimSpace(req.LinuxDoConnectClientSecret) + req.LinuxDoConnectRedirectURL = strings.TrimSpace(req.LinuxDoConnectRedirectURL) + + if req.LinuxDoConnectClientID == "" { + response.BadRequest(c, "LinuxDo Client ID is required when enabled") + return + } + if req.LinuxDoConnectRedirectURL == "" { + response.BadRequest(c, "LinuxDo Redirect URL is required when enabled") + return + } + if err := config.ValidateAbsoluteHTTPURL(req.LinuxDoConnectRedirectURL); err != nil { + response.BadRequest(c, "LinuxDo Redirect URL must be an absolute http(s) URL") + return + } + + // 如果未提供 client_secret,则保留现有值(如有)。 + if req.LinuxDoConnectClientSecret == "" { + if previousSettings.LinuxDoConnectClientSecret == "" { + response.BadRequest(c, "LinuxDo Client Secret is required when enabled") + return + } + req.LinuxDoConnectClientSecret = previousSettings.LinuxDoConnectClientSecret + } + } + settings := &service.SystemSettings{ - RegistrationEnabled: req.RegistrationEnabled, - EmailVerifyEnabled: req.EmailVerifyEnabled, - SMTPHost: req.SMTPHost, - SMTPPort: req.SMTPPort, - SMTPUsername: req.SMTPUsername, - SMTPPassword: req.SMTPPassword, - SMTPFrom: req.SMTPFrom, - SMTPFromName: req.SMTPFromName, - SMTPUseTLS: req.SMTPUseTLS, - TurnstileEnabled: req.TurnstileEnabled, - TurnstileSiteKey: req.TurnstileSiteKey, - TurnstileSecretKey: req.TurnstileSecretKey, - SiteName: req.SiteName, - SiteLogo: req.SiteLogo, - SiteSubtitle: req.SiteSubtitle, - APIBaseURL: req.APIBaseURL, - ContactInfo: req.ContactInfo, - DocURL: req.DocURL, - DefaultConcurrency: req.DefaultConcurrency, - DefaultBalance: req.DefaultBalance, - EnableModelFallback: req.EnableModelFallback, - FallbackModelAnthropic: req.FallbackModelAnthropic, - FallbackModelOpenAI: req.FallbackModelOpenAI, - FallbackModelGemini: req.FallbackModelGemini, - FallbackModelAntigravity: req.FallbackModelAntigravity, - EnableIdentityPatch: req.EnableIdentityPatch, - IdentityPatchPrompt: req.IdentityPatchPrompt, + RegistrationEnabled: req.RegistrationEnabled, + EmailVerifyEnabled: req.EmailVerifyEnabled, + SMTPHost: req.SMTPHost, + SMTPPort: req.SMTPPort, + SMTPUsername: req.SMTPUsername, + SMTPPassword: req.SMTPPassword, + SMTPFrom: req.SMTPFrom, + SMTPFromName: req.SMTPFromName, + SMTPUseTLS: req.SMTPUseTLS, + TurnstileEnabled: req.TurnstileEnabled, + TurnstileSiteKey: req.TurnstileSiteKey, + TurnstileSecretKey: req.TurnstileSecretKey, + LinuxDoConnectEnabled: req.LinuxDoConnectEnabled, + LinuxDoConnectClientID: req.LinuxDoConnectClientID, + LinuxDoConnectClientSecret: req.LinuxDoConnectClientSecret, + LinuxDoConnectRedirectURL: req.LinuxDoConnectRedirectURL, + SiteName: req.SiteName, + SiteLogo: req.SiteLogo, + SiteSubtitle: req.SiteSubtitle, + APIBaseURL: req.APIBaseURL, + ContactInfo: req.ContactInfo, + DocURL: req.DocURL, + DefaultConcurrency: req.DefaultConcurrency, + DefaultBalance: req.DefaultBalance, + EnableModelFallback: req.EnableModelFallback, + FallbackModelAnthropic: req.FallbackModelAnthropic, + FallbackModelOpenAI: req.FallbackModelOpenAI, + FallbackModelGemini: req.FallbackModelGemini, + FallbackModelAntigravity: req.FallbackModelAntigravity, + EnableIdentityPatch: req.EnableIdentityPatch, + IdentityPatchPrompt: req.IdentityPatchPrompt, } if err := h.settingService.UpdateSettings(c.Request.Context(), settings); err != nil { @@ -210,33 +255,37 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } response.Success(c, dto.SystemSettings{ - RegistrationEnabled: updatedSettings.RegistrationEnabled, - EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, - SMTPHost: updatedSettings.SMTPHost, - SMTPPort: updatedSettings.SMTPPort, - SMTPUsername: updatedSettings.SMTPUsername, - SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, - SMTPFrom: updatedSettings.SMTPFrom, - SMTPFromName: updatedSettings.SMTPFromName, - SMTPUseTLS: updatedSettings.SMTPUseTLS, - TurnstileEnabled: updatedSettings.TurnstileEnabled, - TurnstileSiteKey: updatedSettings.TurnstileSiteKey, - TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, - SiteName: updatedSettings.SiteName, - SiteLogo: updatedSettings.SiteLogo, - SiteSubtitle: updatedSettings.SiteSubtitle, - APIBaseURL: updatedSettings.APIBaseURL, - ContactInfo: updatedSettings.ContactInfo, - DocURL: updatedSettings.DocURL, - DefaultConcurrency: updatedSettings.DefaultConcurrency, - DefaultBalance: updatedSettings.DefaultBalance, - EnableModelFallback: updatedSettings.EnableModelFallback, - FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, - FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, - FallbackModelGemini: updatedSettings.FallbackModelGemini, - FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, - EnableIdentityPatch: updatedSettings.EnableIdentityPatch, - IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, + RegistrationEnabled: updatedSettings.RegistrationEnabled, + EmailVerifyEnabled: updatedSettings.EmailVerifyEnabled, + SMTPHost: updatedSettings.SMTPHost, + SMTPPort: updatedSettings.SMTPPort, + SMTPUsername: updatedSettings.SMTPUsername, + SMTPPasswordConfigured: updatedSettings.SMTPPasswordConfigured, + SMTPFrom: updatedSettings.SMTPFrom, + SMTPFromName: updatedSettings.SMTPFromName, + SMTPUseTLS: updatedSettings.SMTPUseTLS, + TurnstileEnabled: updatedSettings.TurnstileEnabled, + TurnstileSiteKey: updatedSettings.TurnstileSiteKey, + TurnstileSecretKeyConfigured: updatedSettings.TurnstileSecretKeyConfigured, + LinuxDoConnectEnabled: updatedSettings.LinuxDoConnectEnabled, + LinuxDoConnectClientID: updatedSettings.LinuxDoConnectClientID, + LinuxDoConnectClientSecretConfigured: updatedSettings.LinuxDoConnectClientSecretConfigured, + LinuxDoConnectRedirectURL: updatedSettings.LinuxDoConnectRedirectURL, + SiteName: updatedSettings.SiteName, + SiteLogo: updatedSettings.SiteLogo, + SiteSubtitle: updatedSettings.SiteSubtitle, + APIBaseURL: updatedSettings.APIBaseURL, + ContactInfo: updatedSettings.ContactInfo, + DocURL: updatedSettings.DocURL, + DefaultConcurrency: updatedSettings.DefaultConcurrency, + DefaultBalance: updatedSettings.DefaultBalance, + EnableModelFallback: updatedSettings.EnableModelFallback, + FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, + FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, + FallbackModelGemini: updatedSettings.FallbackModelGemini, + FallbackModelAntigravity: updatedSettings.FallbackModelAntigravity, + EnableIdentityPatch: updatedSettings.EnableIdentityPatch, + IdentityPatchPrompt: updatedSettings.IdentityPatchPrompt, }) } @@ -298,6 +347,18 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if req.TurnstileSecretKey != "" { changed = append(changed, "turnstile_secret_key") } + if before.LinuxDoConnectEnabled != after.LinuxDoConnectEnabled { + changed = append(changed, "linuxdo_connect_enabled") + } + if before.LinuxDoConnectClientID != after.LinuxDoConnectClientID { + changed = append(changed, "linuxdo_connect_client_id") + } + if req.LinuxDoConnectClientSecret != "" { + changed = append(changed, "linuxdo_connect_client_secret") + } + if before.LinuxDoConnectRedirectURL != after.LinuxDoConnectRedirectURL { + changed = append(changed, "linuxdo_connect_redirect_url") + } if before.SiteName != after.SiteName { changed = append(changed, "site_name") } @@ -337,6 +398,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.FallbackModelAntigravity != after.FallbackModelAntigravity { changed = append(changed, "fallback_model_antigravity") } + if before.EnableIdentityPatch != after.EnableIdentityPatch { + changed = append(changed, "enable_identity_patch") + } + if before.IdentityPatchPrompt != after.IdentityPatchPrompt { + changed = append(changed, "identity_patch_prompt") + } return changed } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index f8cd1d5a..38cc8acd 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -63,10 +64,17 @@ type UpdateBalanceRequest struct { func (h *UserHandler) List(c *gin.Context) { page, pageSize := response.ParsePagination(c) + search := c.Query("search") + // 标准化和验证 search 参数 + search = strings.TrimSpace(search) + if len(search) > 100 { + search = search[:100] + } + filters := service.UserListFilters{ Status: c.Query("status"), Role: c.Query("role"), - Search: c.Query("search"), + Search: search, Attributes: parseAttributeFilters(c), } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 8466f131..8463367e 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -15,14 +15,16 @@ type AuthHandler struct { cfg *config.Config authService *service.AuthService userService *service.UserService + settingSvc *service.SettingService } // NewAuthHandler creates a new AuthHandler -func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService) *AuthHandler { +func NewAuthHandler(cfg *config.Config, authService *service.AuthService, userService *service.UserService, settingService *service.SettingService) *AuthHandler { return &AuthHandler{ cfg: cfg, authService: authService, userService: userService, + settingSvc: settingService, } } diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go new file mode 100644 index 00000000..a16c4cc7 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -0,0 +1,679 @@ +package handler + +import ( + "context" + "encoding/base64" + "errors" + "fmt" + "log" + "net/http" + "net/url" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" + "github.com/imroc/req/v3" + "github.com/tidwall/gjson" +) + +const ( + linuxDoOAuthCookiePath = "/api/v1/auth/oauth/linuxdo" + linuxDoOAuthStateCookieName = "linuxdo_oauth_state" + linuxDoOAuthVerifierCookie = "linuxdo_oauth_verifier" + linuxDoOAuthRedirectCookie = "linuxdo_oauth_redirect" + linuxDoOAuthCookieMaxAgeSec = 10 * 60 // 10 minutes + linuxDoOAuthDefaultRedirectTo = "/dashboard" + linuxDoOAuthDefaultFrontendCB = "/auth/linuxdo/callback" + + linuxDoOAuthMaxRedirectLen = 2048 + linuxDoOAuthMaxFragmentValueLen = 512 + linuxDoOAuthMaxSubjectLen = 64 - len("linuxdo-") +) + +type linuxDoTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type linuxDoTokenExchangeError struct { + StatusCode int + ProviderError string + ProviderDescription string + Body string +} + +func (e *linuxDoTokenExchangeError) Error() string { + if e == nil { + return "" + } + parts := []string{fmt.Sprintf("token exchange status=%d", e.StatusCode)} + if strings.TrimSpace(e.ProviderError) != "" { + parts = append(parts, "error="+strings.TrimSpace(e.ProviderError)) + } + if strings.TrimSpace(e.ProviderDescription) != "" { + parts = append(parts, "error_description="+strings.TrimSpace(e.ProviderDescription)) + } + return strings.Join(parts, " ") +} + +// LinuxDoOAuthStart 启动 LinuxDo Connect OAuth 登录流程。 +// GET /api/v1/auth/oauth/linuxdo/start?redirect=/dashboard +func (h *AuthHandler) LinuxDoOAuthStart(c *gin.Context) { + cfg, err := h.getLinuxDoOAuthConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + state, err := oauth.GenerateState() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_STATE_GEN_FAILED", "failed to generate oauth state").WithCause(err)) + return + } + + redirectTo := sanitizeFrontendRedirectPath(c.Query("redirect")) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + secureCookie := isRequestHTTPS(c) + setCookie(c, linuxDoOAuthStateCookieName, encodeCookieValue(state), linuxDoOAuthCookieMaxAgeSec, secureCookie) + setCookie(c, linuxDoOAuthRedirectCookie, encodeCookieValue(redirectTo), linuxDoOAuthCookieMaxAgeSec, secureCookie) + + codeChallenge := "" + if cfg.UsePKCE { + verifier, err := oauth.GenerateCodeVerifier() + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_PKCE_GEN_FAILED", "failed to generate pkce verifier").WithCause(err)) + return + } + codeChallenge = oauth.GenerateCodeChallenge(verifier) + setCookie(c, linuxDoOAuthVerifierCookie, encodeCookieValue(verifier), linuxDoOAuthCookieMaxAgeSec, secureCookie) + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured")) + return + } + + authURL, err := buildLinuxDoAuthorizeURL(cfg, state, codeChallenge, redirectURI) + if err != nil { + response.ErrorFrom(c, infraerrors.InternalServer("OAUTH_BUILD_URL_FAILED", "failed to build oauth authorization url").WithCause(err)) + return + } + + c.Redirect(http.StatusFound, authURL) +} + +// LinuxDoOAuthCallback 处理 OAuth 回调:创建/登录用户,然后重定向到前端。 +// GET /api/v1/auth/oauth/linuxdo/callback?code=...&state=... +func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { + cfg, cfgErr := h.getLinuxDoOAuthConfig(c.Request.Context()) + if cfgErr != nil { + response.ErrorFrom(c, cfgErr) + return + } + + frontendCallback := strings.TrimSpace(cfg.FrontendRedirectURL) + if frontendCallback == "" { + frontendCallback = linuxDoOAuthDefaultFrontendCB + } + + if providerErr := strings.TrimSpace(c.Query("error")); providerErr != "" { + redirectOAuthError(c, frontendCallback, "provider_error", providerErr, c.Query("error_description")) + return + } + + code := strings.TrimSpace(c.Query("code")) + state := strings.TrimSpace(c.Query("state")) + if code == "" || state == "" { + redirectOAuthError(c, frontendCallback, "missing_params", "missing code/state", "") + return + } + + secureCookie := isRequestHTTPS(c) + defer func() { + clearCookie(c, linuxDoOAuthStateCookieName, secureCookie) + clearCookie(c, linuxDoOAuthVerifierCookie, secureCookie) + clearCookie(c, linuxDoOAuthRedirectCookie, secureCookie) + }() + + expectedState, err := readCookieDecoded(c, linuxDoOAuthStateCookieName) + if err != nil || expectedState == "" || state != expectedState { + redirectOAuthError(c, frontendCallback, "invalid_state", "invalid oauth state", "") + return + } + + redirectTo, _ := readCookieDecoded(c, linuxDoOAuthRedirectCookie) + redirectTo = sanitizeFrontendRedirectPath(redirectTo) + if redirectTo == "" { + redirectTo = linuxDoOAuthDefaultRedirectTo + } + + codeVerifier := "" + if cfg.UsePKCE { + codeVerifier, _ = readCookieDecoded(c, linuxDoOAuthVerifierCookie) + if codeVerifier == "" { + redirectOAuthError(c, frontendCallback, "missing_verifier", "missing pkce verifier", "") + return + } + } + + redirectURI := strings.TrimSpace(cfg.RedirectURL) + if redirectURI == "" { + redirectOAuthError(c, frontendCallback, "config_error", "oauth redirect url not configured", "") + return + } + + tokenResp, err := linuxDoExchangeCode(c.Request.Context(), cfg, code, redirectURI, codeVerifier) + if err != nil { + description := "" + var exchangeErr *linuxDoTokenExchangeError + if errors.As(err, &exchangeErr) && exchangeErr != nil { + log.Printf( + "[LinuxDo OAuth] token exchange failed: status=%d provider_error=%q provider_description=%q body=%s", + exchangeErr.StatusCode, + exchangeErr.ProviderError, + exchangeErr.ProviderDescription, + truncateLogValue(exchangeErr.Body, 2048), + ) + description = exchangeErr.Error() + } else { + log.Printf("[LinuxDo OAuth] token exchange failed: %v", err) + description = err.Error() + } + redirectOAuthError(c, frontendCallback, "token_exchange_failed", "failed to exchange oauth code", singleLine(description)) + return + } + + email, username, subject, err := linuxDoFetchUserInfo(c.Request.Context(), cfg, tokenResp) + if err != nil { + log.Printf("[LinuxDo OAuth] userinfo fetch failed: %v", err) + redirectOAuthError(c, frontendCallback, "userinfo_failed", "failed to fetch user info", "") + return + } + + // 安全考虑:不要把第三方返回的 email 直接映射到本地账号(可能与本地邮箱用户冲突导致账号被接管)。 + // 统一使用基于 subject 的稳定合成邮箱来做账号绑定。 + if subject != "" { + email = linuxDoSyntheticEmail(subject) + } + + jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + if err != nil { + // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 + redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) + return + } + + fragment := url.Values{} + fragment.Set("access_token", jwtToken) + fragment.Set("token_type", "Bearer") + fragment.Set("redirect", redirectTo) + redirectWithFragment(c, frontendCallback, fragment) +} + +func (h *AuthHandler) getLinuxDoOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if h != nil && h.settingSvc != nil { + return h.settingSvc.GetLinuxDoConnectOAuthConfig(ctx) + } + if h == nil || h.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + if !h.cfg.LinuxDo.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + return h.cfg.LinuxDo, nil +} + +func linuxDoExchangeCode( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + code string, + redirectURI string, + codeVerifier string, +) (*linuxDoTokenResponse, error) { + client := req.C().SetTimeout(30 * time.Second) + + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("client_id", cfg.ClientID) + form.Set("code", code) + form.Set("redirect_uri", redirectURI) + if cfg.UsePKCE { + form.Set("code_verifier", codeVerifier) + } + + r := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json") + + switch strings.ToLower(strings.TrimSpace(cfg.TokenAuthMethod)) { + case "", "client_secret_post": + form.Set("client_secret", cfg.ClientSecret) + case "client_secret_basic": + r.SetBasicAuth(cfg.ClientID, cfg.ClientSecret) + case "none": + default: + return nil, fmt.Errorf("unsupported token_auth_method: %s", cfg.TokenAuthMethod) + } + + resp, err := r.SetFormDataFromValues(form).Post(cfg.TokenURL) + if err != nil { + return nil, fmt.Errorf("request token: %w", err) + } + body := strings.TrimSpace(resp.String()) + if !resp.IsSuccessState() { + providerErr, providerDesc := parseOAuthProviderError(body) + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + ProviderError: providerErr, + ProviderDescription: providerDesc, + Body: body, + } + } + + tokenResp, ok := parseLinuxDoTokenResponse(body) + if !ok || strings.TrimSpace(tokenResp.AccessToken) == "" { + return nil, &linuxDoTokenExchangeError{ + StatusCode: resp.StatusCode, + Body: body, + } + } + if strings.TrimSpace(tokenResp.TokenType) == "" { + tokenResp.TokenType = "Bearer" + } + return tokenResp, nil +} + +func linuxDoFetchUserInfo( + ctx context.Context, + cfg config.LinuxDoConnectConfig, + token *linuxDoTokenResponse, +) (email string, username string, subject string, err error) { + client := req.C().SetTimeout(30 * time.Second) + authorization, err := buildBearerAuthorization(token.TokenType, token.AccessToken) + if err != nil { + return "", "", "", fmt.Errorf("invalid token for userinfo request: %w", err) + } + + resp, err := client.R(). + SetContext(ctx). + SetHeader("Accept", "application/json"). + SetHeader("Authorization", authorization). + Get(cfg.UserInfoURL) + if err != nil { + return "", "", "", fmt.Errorf("request userinfo: %w", err) + } + if !resp.IsSuccessState() { + return "", "", "", fmt.Errorf("userinfo status=%d", resp.StatusCode) + } + + return linuxDoParseUserInfo(resp.String(), cfg) +} + +func linuxDoParseUserInfo(body string, cfg config.LinuxDoConnectConfig) (email string, username string, subject string, err error) { + email = firstNonEmpty( + getGJSON(body, cfg.UserInfoEmailPath), + getGJSON(body, "email"), + getGJSON(body, "user.email"), + getGJSON(body, "data.email"), + getGJSON(body, "attributes.email"), + ) + username = firstNonEmpty( + getGJSON(body, cfg.UserInfoUsernamePath), + getGJSON(body, "username"), + getGJSON(body, "preferred_username"), + getGJSON(body, "name"), + getGJSON(body, "user.username"), + getGJSON(body, "user.name"), + ) + subject = firstNonEmpty( + getGJSON(body, cfg.UserInfoIDPath), + getGJSON(body, "sub"), + getGJSON(body, "id"), + getGJSON(body, "user_id"), + getGJSON(body, "uid"), + getGJSON(body, "user.id"), + ) + + subject = strings.TrimSpace(subject) + if subject == "" { + return "", "", "", errors.New("userinfo missing id field") + } + if !isSafeLinuxDoSubject(subject) { + return "", "", "", errors.New("userinfo returned invalid id field") + } + + email = strings.TrimSpace(email) + if email == "" { + // LinuxDo Connect 的 userinfo 可能不提供 email。为兼容现有用户模型(email 必填且唯一),使用稳定的合成邮箱。 + email = linuxDoSyntheticEmail(subject) + } + + username = strings.TrimSpace(username) + if username == "" { + username = "linuxdo_" + subject + } + + return email, username, subject, nil +} + +func buildLinuxDoAuthorizeURL(cfg config.LinuxDoConnectConfig, state string, codeChallenge string, redirectURI string) (string, error) { + u, err := url.Parse(cfg.AuthorizeURL) + if err != nil { + return "", fmt.Errorf("parse authorize_url: %w", err) + } + + q := u.Query() + q.Set("response_type", "code") + q.Set("client_id", cfg.ClientID) + q.Set("redirect_uri", redirectURI) + if strings.TrimSpace(cfg.Scopes) != "" { + q.Set("scope", cfg.Scopes) + } + q.Set("state", state) + if cfg.UsePKCE { + q.Set("code_challenge", codeChallenge) + q.Set("code_challenge_method", "S256") + } + + u.RawQuery = q.Encode() + return u.String(), nil +} + +func redirectOAuthError(c *gin.Context, frontendCallback string, code string, message string, description string) { + fragment := url.Values{} + fragment.Set("error", truncateFragmentValue(code)) + if strings.TrimSpace(message) != "" { + fragment.Set("error_message", truncateFragmentValue(message)) + } + if strings.TrimSpace(description) != "" { + fragment.Set("error_description", truncateFragmentValue(description)) + } + redirectWithFragment(c, frontendCallback, fragment) +} + +func redirectWithFragment(c *gin.Context, frontendCallback string, fragment url.Values) { + u, err := url.Parse(frontendCallback) + if err != nil { + // 兜底:尽力跳转到默认页面,避免卡死在回调页。 + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + if u.Scheme != "" && !strings.EqualFold(u.Scheme, "http") && !strings.EqualFold(u.Scheme, "https") { + c.Redirect(http.StatusFound, linuxDoOAuthDefaultRedirectTo) + return + } + u.Fragment = fragment.Encode() + c.Header("Cache-Control", "no-store") + c.Header("Pragma", "no-cache") + c.Redirect(http.StatusFound, u.String()) +} + +func firstNonEmpty(values ...string) string { + for _, v := range values { + v = strings.TrimSpace(v) + if v != "" { + return v + } + } + return "" +} + +func parseOAuthProviderError(body string) (providerErr string, providerDesc string) { + body = strings.TrimSpace(body) + if body == "" { + return "", "" + } + + providerErr = firstNonEmpty( + getGJSON(body, "error"), + getGJSON(body, "code"), + getGJSON(body, "error.code"), + ) + providerDesc = firstNonEmpty( + getGJSON(body, "error_description"), + getGJSON(body, "error.message"), + getGJSON(body, "message"), + getGJSON(body, "detail"), + ) + + if providerErr != "" || providerDesc != "" { + return providerErr, providerDesc + } + + values, err := url.ParseQuery(body) + if err != nil { + return "", "" + } + providerErr = firstNonEmpty(values.Get("error"), values.Get("code")) + providerDesc = firstNonEmpty(values.Get("error_description"), values.Get("error_message"), values.Get("message")) + return providerErr, providerDesc +} + +func parseLinuxDoTokenResponse(body string) (*linuxDoTokenResponse, bool) { + body = strings.TrimSpace(body) + if body == "" { + return nil, false + } + + accessToken := strings.TrimSpace(getGJSON(body, "access_token")) + if accessToken != "" { + tokenType := strings.TrimSpace(getGJSON(body, "token_type")) + refreshToken := strings.TrimSpace(getGJSON(body, "refresh_token")) + scope := strings.TrimSpace(getGJSON(body, "scope")) + expiresIn := gjson.Get(body, "expires_in").Int() + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: tokenType, + ExpiresIn: expiresIn, + RefreshToken: refreshToken, + Scope: scope, + }, true + } + + values, err := url.ParseQuery(body) + if err != nil { + return nil, false + } + accessToken = strings.TrimSpace(values.Get("access_token")) + if accessToken == "" { + return nil, false + } + expiresIn := int64(0) + if raw := strings.TrimSpace(values.Get("expires_in")); raw != "" { + if v, err := strconv.ParseInt(raw, 10, 64); err == nil { + expiresIn = v + } + } + return &linuxDoTokenResponse{ + AccessToken: accessToken, + TokenType: strings.TrimSpace(values.Get("token_type")), + ExpiresIn: expiresIn, + RefreshToken: strings.TrimSpace(values.Get("refresh_token")), + Scope: strings.TrimSpace(values.Get("scope")), + }, true +} + +func getGJSON(body string, path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + res := gjson.Get(body, path) + if !res.Exists() { + return "" + } + return res.String() +} + +func truncateLogValue(value string, maxLen int) string { + value = strings.TrimSpace(value) + if value == "" || maxLen <= 0 { + return "" + } + if len(value) <= maxLen { + return value + } + value = value[:maxLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + return value +} + +func singleLine(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + return strings.Join(strings.Fields(value), " ") +} + +func sanitizeFrontendRedirectPath(path string) string { + path = strings.TrimSpace(path) + if path == "" { + return "" + } + if len(path) > linuxDoOAuthMaxRedirectLen { + return "" + } + // 只允许同源相对路径(避免开放重定向)。 + if !strings.HasPrefix(path, "/") { + return "" + } + if strings.HasPrefix(path, "//") { + return "" + } + if strings.Contains(path, "://") { + return "" + } + if strings.ContainsAny(path, "\r\n") { + return "" + } + return path +} + +func isRequestHTTPS(c *gin.Context) bool { + if c.Request.TLS != nil { + return true + } + proto := strings.ToLower(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto"))) + return proto == "https" +} + +func encodeCookieValue(value string) string { + return base64.RawURLEncoding.EncodeToString([]byte(value)) +} + +func decodeCookieValue(value string) (string, error) { + raw, err := base64.RawURLEncoding.DecodeString(value) + if err != nil { + return "", err + } + return string(raw), nil +} + +func readCookieDecoded(c *gin.Context, name string) (string, error) { + ck, err := c.Request.Cookie(name) + if err != nil { + return "", err + } + return decodeCookieValue(ck.Value) +} + +func setCookie(c *gin.Context, name string, value string, maxAgeSec int, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: value, + Path: linuxDoOAuthCookiePath, + MaxAge: maxAgeSec, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func clearCookie(c *gin.Context, name string, secure bool) { + http.SetCookie(c.Writer, &http.Cookie{ + Name: name, + Value: "", + Path: linuxDoOAuthCookiePath, + MaxAge: -1, + HttpOnly: true, + Secure: secure, + SameSite: http.SameSiteLaxMode, + }) +} + +func truncateFragmentValue(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if len(value) > linuxDoOAuthMaxFragmentValueLen { + value = value[:linuxDoOAuthMaxFragmentValueLen] + for !utf8.ValidString(value) { + value = value[:len(value)-1] + } + } + return value +} + +func buildBearerAuthorization(tokenType, accessToken string) (string, error) { + tokenType = strings.TrimSpace(tokenType) + if tokenType == "" { + tokenType = "Bearer" + } + if !strings.EqualFold(tokenType, "Bearer") { + return "", fmt.Errorf("unsupported token_type: %s", tokenType) + } + + accessToken = strings.TrimSpace(accessToken) + if accessToken == "" { + return "", errors.New("missing access_token") + } + if strings.ContainsAny(accessToken, " \t\r\n") { + return "", errors.New("access_token contains whitespace") + } + return "Bearer " + accessToken, nil +} + +func isSafeLinuxDoSubject(subject string) bool { + subject = strings.TrimSpace(subject) + if subject == "" || len(subject) > linuxDoOAuthMaxSubjectLen { + return false + } + for _, r := range subject { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r == '_' || r == '-': + default: + return false + } + } + return true +} + +func linuxDoSyntheticEmail(subject string) string { + subject = strings.TrimSpace(subject) + if subject == "" { + return "" + } + return "linuxdo-" + subject + service.LinuxDoConnectSyntheticEmailDomain +} diff --git a/backend/internal/handler/auth_linuxdo_oauth_test.go b/backend/internal/handler/auth_linuxdo_oauth_test.go new file mode 100644 index 00000000..ff169c52 --- /dev/null +++ b/backend/internal/handler/auth_linuxdo_oauth_test.go @@ -0,0 +1,108 @@ +package handler + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSanitizeFrontendRedirectPath(t *testing.T) { + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath("/dashboard")) + require.Equal(t, "/dashboard", sanitizeFrontendRedirectPath(" /dashboard ")) + require.Equal(t, "", sanitizeFrontendRedirectPath("dashboard")) + require.Equal(t, "", sanitizeFrontendRedirectPath("//evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("https://evil.com")) + require.Equal(t, "", sanitizeFrontendRedirectPath("/\nfoo")) + + long := "/" + strings.Repeat("a", linuxDoOAuthMaxRedirectLen) + require.Equal(t, "", sanitizeFrontendRedirectPath(long)) +} + +func TestBuildBearerAuthorization(t *testing.T) { + auth, err := buildBearerAuthorization("", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + auth, err = buildBearerAuthorization("bearer", "token123") + require.NoError(t, err) + require.Equal(t, "Bearer token123", auth) + + _, err = buildBearerAuthorization("MAC", "token123") + require.Error(t, err) + + _, err = buildBearerAuthorization("Bearer", "token 123") + require.Error(t, err) +} + +func TestLinuxDoParseUserInfoParsesIDAndUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":123,"username":"alice"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "alice", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoDefaultsUsername(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + email, username, subject, err := linuxDoParseUserInfo(`{"id":"123"}`, cfg) + require.NoError(t, err) + require.Equal(t, "123", subject) + require.Equal(t, "linuxdo_123", username) + require.Equal(t, "linuxdo-123@linuxdo-connect.invalid", email) +} + +func TestLinuxDoParseUserInfoRejectsUnsafeSubject(t *testing.T) { + cfg := config.LinuxDoConnectConfig{ + UserInfoURL: "https://connect.linux.do/api/user", + } + + _, _, _, err := linuxDoParseUserInfo(`{"id":"123@456"}`, cfg) + require.Error(t, err) + + tooLong := strings.Repeat("a", linuxDoOAuthMaxSubjectLen+1) + _, _, _, err = linuxDoParseUserInfo(`{"id":"`+tooLong+`"}`, cfg) + require.Error(t, err) +} + +func TestParseOAuthProviderErrorJSON(t *testing.T) { + code, desc := parseOAuthProviderError(`{"error":"invalid_client","error_description":"bad secret"}`) + require.Equal(t, "invalid_client", code) + require.Equal(t, "bad secret", desc) +} + +func TestParseOAuthProviderErrorForm(t *testing.T) { + code, desc := parseOAuthProviderError("error=invalid_request&error_description=Missing+code_verifier") + require.Equal(t, "invalid_request", code) + require.Equal(t, "Missing code_verifier", desc) +} + +func TestParseLinuxDoTokenResponseJSON(t *testing.T) { + token, ok := parseLinuxDoTokenResponse(`{"access_token":"t1","token_type":"Bearer","expires_in":3600,"scope":"user"}`) + require.True(t, ok) + require.Equal(t, "t1", token.AccessToken) + require.Equal(t, "Bearer", token.TokenType) + require.Equal(t, int64(3600), token.ExpiresIn) + require.Equal(t, "user", token.Scope) +} + +func TestParseLinuxDoTokenResponseForm(t *testing.T) { + token, ok := parseLinuxDoTokenResponse("access_token=t2&token_type=bearer&expires_in=60") + require.True(t, ok) + require.Equal(t, "t2", token.AccessToken) + require.Equal(t, "bearer", token.TokenType) + require.Equal(t, int64(60), token.ExpiresIn) +} + +func TestSingleLineStripsWhitespace(t *testing.T) { + require.Equal(t, "hello world", singleLine("hello\r\nworld")) + require.Equal(t, "", singleLine("\n\t\r")) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index bf15e9dc..9a672064 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -85,6 +85,8 @@ func GroupFromServiceShallow(g *service.Group) *Group { ImagePrice1K: g.ImagePrice1K, ImagePrice2K: g.ImagePrice2K, ImagePrice4K: g.ImagePrice4K, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, AccountCount: g.AccountCount, @@ -280,6 +282,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary) *Usag FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + UserAgent: l.UserAgent, CreatedAt: l.CreatedAt, User: UserFromServiceShallow(l.User), APIKey: APIKeyFromService(l.APIKey), diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 4c50cedf..dab5eb75 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -17,6 +17,11 @@ type SystemSettings struct { TurnstileSiteKey string `json:"turnstile_site_key"` TurnstileSecretKeyConfigured bool `json:"turnstile_secret_key_configured"` + LinuxDoConnectEnabled bool `json:"linuxdo_connect_enabled"` + LinuxDoConnectClientID string `json:"linuxdo_connect_client_id"` + LinuxDoConnectClientSecretConfigured bool `json:"linuxdo_connect_client_secret_configured"` + LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` + SiteName string `json:"site_name"` SiteLogo string `json:"site_logo"` SiteSubtitle string `json:"site_subtitle"` @@ -50,5 +55,6 @@ type PublicSettings struct { APIBaseURL string `json:"api_base_url"` ContactInfo string `json:"contact_info"` DocURL string `json:"doc_url"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version"` } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index bb953fae..03f7080b 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -52,6 +52,10 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Claude Code 客户端限制 + ClaudeCodeOnly bool `json:"claude_code_only"` + FallbackGroupID *int64 `json:"fallback_group_id"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` @@ -180,6 +184,9 @@ type UsageLog struct { ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + // User-Agent + UserAgent *string `json:"user_agent"` + CreatedAt time.Time `json:"created_at"` User *User `json:"user,omitempty"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 2d8ff957..48a827f3 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -96,6 +96,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { reqModel := parsedReq.Model reqStream := parsedReq.Stream + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + // 验证 model 必填 if reqModel == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") @@ -229,7 +232,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } @@ -357,7 +360,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } @@ -683,6 +686,9 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + // 验证 model 必填 if parsedReq.Model == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 5de519c7..0393f954 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -2,6 +2,7 @@ package handler import ( "context" + "encoding/json" "fmt" "math/rand" "net/http" @@ -13,6 +14,26 @@ import ( "github.com/gin-gonic/gin" ) +// claudeCodeValidator is a singleton validator for Claude Code client detection +var claudeCodeValidator = service.NewClaudeCodeValidator() + +// SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 +// 返回更新后的 context +func SetClaudeCodeClientContext(c *gin.Context, body []byte) { + // 解析请求体为 map + var bodyMap map[string]any + if len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + + // 验证是否为 Claude Code 客户端 + isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + + // 更新 request context + ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) + c.Request = c.Request.WithContext(ctx) +} + // 并发槽位等待相关常量 // // 性能优化说明: diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index fc8c7cd6..0cbe44f2 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -203,6 +203,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 3) select account (sticky session based on request body) parsedReq, _ := service.ParseGatewayRequest(body) + + // 设置 Claude Code 客户端标识到 context(用于分组限制检查) + SetClaudeCodeClientContext(c, body) + sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) sessionKey := sessionHash if sessionHash != "" { @@ -262,7 +266,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusTooManyRequests, err.Error()) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionKey, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index f76a9851..70131417 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -206,7 +206,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleConcurrencyError(c, err, "account", streamStarted) return } - if err := h.gatewayService.BindStickySession(c.Request.Context(), sessionHash, account.ID); err != nil { + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { log.Printf("Bind sticky session failed: %v", err) } } diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index 3cae7a7f..e1b20c8c 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -42,6 +42,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { APIBaseURL: settings.APIBaseURL, ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, + LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) } diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 8ff75f57..1248be95 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -5,8 +5,11 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" + "log" + "net" "net/http" "net/url" "strings" @@ -22,10 +25,10 @@ func resolveHost(urlStr string) string { return parsed.Host } -// NewAPIRequest 创建 Antigravity API 请求(v1internal 端点) -func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { +// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) +func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 - apiURL := fmt.Sprintf("%s/v1internal:%s", BaseURL, action) + apiURL := fmt.Sprintf("%s/v1internal:%s", baseURL, action) isStream := action == "streamGenerateContent" if isStream { apiURL += "?alt=sse" @@ -53,11 +56,15 @@ func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) req.Host = host } - // 注意:requestType 已在 JSON body 的 V1InternalRequest 中设置,不需要 HTTP Header - return req, nil } +// NewAPIRequest 使用默认 URL 创建 Antigravity API 请求(v1internal 端点) +// 向后兼容:仅使用默认 BaseURL +func NewAPIRequest(ctx context.Context, action, accessToken string, body []byte) (*http.Request, error) { + return NewAPIRequestWithURL(ctx, BaseURL, action, accessToken, body) +} + // TokenResponse Google OAuth token 响应 type TokenResponse struct { AccessToken string `json:"access_token"` @@ -164,6 +171,38 @@ func NewClient(proxyURL string) *Client { } } +// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + if errors.As(err, &opErr) { + return true + } + + // 检查 URL 错误 + var urlErr *url.Error + return errors.As(err, &urlErr) +} + +// shouldFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldFallbackToNextURL(err error, statusCode int) bool { + if isConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { params := url.Values{} @@ -272,6 +311,7 @@ func (c *Client) GetUserInfo(ctx context.Context, accessToken string) (*UserInfo } // LoadCodeAssist 获取账户信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadCodeAssistResponse, map[string]any, error) { reqBody := LoadCodeAssistRequest{} reqBody.Metadata.IDEType = "ANTIGRAVITY" @@ -281,40 +321,65 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - url := BaseURL + "/v1internal:loadCodeAssist" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("loadCodeAssist 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:loadCodeAssist" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("loadCodeAssist 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var loadResp LoadCodeAssistResponse + if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &loadResp, rawResp, nil } - var loadResp LoadCodeAssistResponse - if err := json.Unmarshal(respBodyBytes, &loadResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &loadResp, rawResp, nil + return nil, nil, lastErr } // ModelQuotaInfo 模型配额信息 @@ -339,6 +404,7 @@ type FetchAvailableModelsResponse struct { } // FetchAvailableModels 获取可用模型和配额信息,返回解析后的结构体和原始 JSON +// 支持 URL fallback:sandbox → daily → prod func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectID string) (*FetchAvailableModelsResponse, map[string]any, error) { reqBody := FetchAvailableModelsRequest{Project: projectID} bodyBytes, err := json.Marshal(reqBody) @@ -346,38 +412,63 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - apiURL := BaseURL + "/v1internal:fetchAvailableModels" - req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) - if err != nil { - return nil, nil, fmt.Errorf("创建请求失败: %w", err) - } - req.Header.Set("Authorization", "Bearer "+accessToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", UserAgent) - - resp, err := c.httpClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("fetchAvailableModels 请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - respBodyBytes, err := io.ReadAll(resp.Body) - if err != nil { - return nil, nil, fmt.Errorf("读取响应失败: %w", err) + // 获取可用的 URL 列表 + availableURLs := DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode != http.StatusOK { - return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + apiURL := baseURL + "/v1internal:fetchAvailableModels" + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, strings.NewReader(string(bodyBytes))) + if err != nil { + lastErr = fmt.Errorf("创建请求失败: %w", err) + continue + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", UserAgent) + + resp, err := c.httpClient.Do(req) + if err != nil { + lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) + if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, nil, lastErr + } + + respBodyBytes, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode != http.StatusOK { + return nil, nil, fmt.Errorf("fetchAvailableModels 失败 (HTTP %d): %s", resp.StatusCode, string(respBodyBytes)) + } + + var modelsResp FetchAvailableModelsResponse + if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { + return nil, nil, fmt.Errorf("响应解析失败: %w", err) + } + + // 解析原始 JSON 为 map + var rawResp map[string]any + _ = json.Unmarshal(respBodyBytes, &rawResp) + + return &modelsResp, rawResp, nil } - var modelsResp FetchAvailableModelsResponse - if err := json.Unmarshal(respBodyBytes, &modelsResp); err != nil { - return nil, nil, fmt.Errorf("响应解析失败: %w", err) - } - - // 解析原始 JSON 为 map - var rawResp map[string]any - _ = json.Unmarshal(respBodyBytes, &rawResp) - - return &modelsResp, rawResp, nil + return nil, nil, lastErr } diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index e88c203b..736c45df 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,17 +32,79 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // API 端点 - // 优先使用 sandbox daily URL,配额更宽松 - BaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com" - // User-Agent(模拟官方客户端) UserAgent = "antigravity/1.104.0 darwin/arm64" // Session 过期时间 SessionTTL = 30 * time.Minute + + // URL 可用性 TTL(不可用 URL 的恢复时间) + URLAvailabilityTTL = 5 * time.Minute ) +// BaseURLs 定义 Antigravity API 端点,按优先级排序 +// fallback 顺序: sandbox → daily → prod +var BaseURLs = []string{ + "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox + "https://daily-cloudcode-pa.googleapis.com", // daily + "https://cloudcode-pa.googleapis.com", // prod +} + +// BaseURL 默认 URL(保持向后兼容) +var BaseURL = BaseURLs[0] + +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +type URLAvailability struct { + mu sync.RWMutex + unavailable map[string]time.Time // URL -> 恢复时间 + ttl time.Duration +} + +// DefaultURLAvailability 全局 URL 可用性管理器 +var DefaultURLAvailability = NewURLAvailability(URLAvailabilityTTL) + +// NewURLAvailability 创建 URL 可用性管理器 +func NewURLAvailability(ttl time.Duration) *URLAvailability { + return &URLAvailability{ + unavailable: make(map[string]time.Time), + ttl: ttl, + } +} + +// MarkUnavailable 标记 URL 临时不可用 +func (u *URLAvailability) MarkUnavailable(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.unavailable[url] = time.Now().Add(u.ttl) +} + +// IsAvailable 检查 URL 是否可用 +func (u *URLAvailability) IsAvailable(url string) bool { + u.mu.RLock() + defer u.mu.RUnlock() + expiry, exists := u.unavailable[url] + if !exists { + return true + } + return time.Now().After(expiry) +} + +// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +func (u *URLAvailability) GetAvailableURLs() []string { + u.mu.RLock() + defer u.mu.RUnlock() + + now := time.Now() + result := make([]string, 0, len(BaseURLs)) + for _, url := range BaseURLs { + expiry, exists := u.unavailable[url] + if !exists || now.After(expiry) { + result = append(result, url) + } + } + return result +} + // OAuthSession 保存 OAuth 授权流程的临时状态 type OAuthSession struct { State string `json:"state"` diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 8920ea69..3add78de 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -7,4 +7,6 @@ type Key string const ( // ForcePlatform 强制平台(用于 /antigravity 路由),由 middleware.ForcePlatform 设置 ForcePlatform Key = "ctx_force_platform" + // IsClaudeCodeClient 是否为 Claude Code 客户端,由中间件设置 + IsClaudeCodeClient Key = "ctx_is_claude_code_client" ) diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index 6d7e5a5d..d4d52116 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -27,10 +27,9 @@ const ( // https://www.googleapis.com/auth/generative-language.retriever (often with cloud-platform). DefaultAIStudioScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever" - // DefaultScopes for Google One (personal Google accounts with Gemini access) - // Only used when a custom OAuth client is configured. When using the built-in Gemini CLI client, - // Google One uses DefaultCodeAssistScopes (same as code_assist) because the built-in client - // cannot request restricted scopes like generative-language.retriever or drive.readonly. + // DefaultGoogleOneScopes (DEPRECATED, no longer used) + // Google One now always uses the built-in Gemini CLI client with DefaultCodeAssistScopes. + // This constant is kept for backward compatibility but is not actively used. DefaultGoogleOneScopes = "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" // GeminiCLIRedirectURI is the redirect URI used by Gemini CLI for Code Assist OAuth. diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index 473017a2..c71e8aad 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -185,13 +185,9 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error effective.Scopes = DefaultAIStudioScopes } case "google_one": - // Google One uses built-in Gemini CLI client (same as code_assist) - // Built-in client can't request restricted scopes like generative-language.retriever - if isBuiltinClient { - effective.Scopes = DefaultCodeAssistScopes - } else { - effective.Scopes = DefaultGoogleOneScopes - } + // Google One always uses built-in Gemini CLI client (same as code_assist) + // Built-in client can't request restricted scopes like generative-language.retriever or drive.readonly + effective.Scopes = DefaultCodeAssistScopes default: // Default to Code Assist scopes effective.Scopes = DefaultCodeAssistScopes diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0520f0f2..0770730a 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -23,14 +23,14 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with custom client", + name: "Google One always uses built-in client (even if custom credentials passed)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultGoogleOneScopes, + wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client wantErr: false, }, { diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 30a783bc..04ca7052 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -886,6 +886,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.Status) idx++ } + if updates.Schedulable != nil { + setClauses = append(setClauses, "schedulable = $"+itoa(idx)) + args = append(args, *updates.Schedulable) + idx++ + } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 4384bff5..f3b07616 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -325,6 +325,8 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, DefaultValidityDays: g.DefaultValidityDays, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 4ed47e9b..40a9ad05 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,6 +2,7 @@ package repository import ( "context" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -18,17 +19,23 @@ func NewGatewayCache(rdb *redis.Client) service.GatewayCache { return &gatewayCache{rdb: rdb} } -func (c *gatewayCache) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { - key := stickySessionPrefix + sessionHash +// buildSessionKey 构建 session key,包含 groupID 实现分组隔离 +// 格式: sticky_session:{groupID}:{sessionHash} +func buildSessionKey(groupID int64, sessionHash string) string { + return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash) +} + +func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Get(ctx, key).Int64() } -func (c *gatewayCache) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { - key := stickySessionPrefix + sessionHash +func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Set(ctx, key, accountID, ttl).Err() } -func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { - key := stickySessionPrefix + sessionHash +func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + key := buildSessionKey(groupID, sessionHash) return c.rdb.Expire(ctx, key, ttl).Err() } diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 170f4074..d8885bca 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -24,18 +24,19 @@ func (s *GatewayCacheSuite) SetupTest() { } func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() { - _, err := s.cache.GetSessionAccountID(s.ctx, "nonexistent") + _, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent") require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session") } func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { sessionID := "s1" accountID := int64(99) + groupID := int64(1) sessionTTL := 1 * time.Minute - require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID") - sid, err := s.cache.GetSessionAccountID(s.ctx, sessionID) + sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID) require.NoError(s.T(), err, "GetSessionAccountID") require.Equal(s.T(), accountID, sid, "session id mismatch") } @@ -43,11 +44,12 @@ func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() { func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { sessionID := "s2" accountID := int64(100) + groupID := int64(1) sessionTTL := 1 * time.Minute - require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, sessionTTL), "SetSessionAccountID") + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID") - sessionKey := stickySessionPrefix + sessionID + sessionKey := buildSessionKey(groupID, sessionID) ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() require.NoError(s.T(), err, "TTL sessionKey after Set") s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL) @@ -56,14 +58,15 @@ func (s *GatewayCacheSuite) TestSessionAccountID_TTL() { func (s *GatewayCacheSuite) TestRefreshSessionTTL() { sessionID := "s3" accountID := int64(101) + groupID := int64(1) initialTTL := 1 * time.Minute refreshTTL := 3 * time.Minute - require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, sessionID, accountID, initialTTL), "SetSessionAccountID") + require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID") - require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, sessionID, refreshTTL), "RefreshSessionTTL") + require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL") - sessionKey := stickySessionPrefix + sessionID + sessionKey := buildSessionKey(groupID, sessionID) ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result() require.NoError(s.T(), err, "TTL after Refresh") s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL) @@ -71,18 +74,19 @@ func (s *GatewayCacheSuite) TestRefreshSessionTTL() { func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() { // RefreshSessionTTL on a missing key should not error (no-op) - err := s.cache.RefreshSessionTTL(s.ctx, "missing-session", 1*time.Minute) + err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute) require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error") } func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { sessionID := "corrupted" - sessionKey := stickySessionPrefix + sessionID + groupID := int64(1) + sessionKey := buildSessionKey(groupID, sessionID) // Set a non-integer value require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value") - _, err := s.cache.GetSessionAccountID(s.ctx, sessionID) + _, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID) require.Error(s.T(), err, "expected error for corrupted value") require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index 14ecfc89..8b7fe625 100644 --- a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -30,14 +30,15 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client + // - google_one: always use built-in Gemini CLI OAuth client (public) // - ai_studio: requires a user-provided OAuth client oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } @@ -78,7 +79,8 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh ClientSecret: c.cfg.Gemini.OAuth.ClientSecret, Scopes: c.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfgInput.ClientID = "" oauthCfgInput.ClientSecret = "" } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 729c1404..a54f3116 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -46,7 +46,9 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). - SetDefaultValidityDays(groupIn.DefaultValidityDays) + SetDefaultValidityDays(groupIn.DefaultValidityDays). + SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). + SetNillableFallbackGroupID(groupIn.FallbackGroupID) created, err := builder.Save(ctx) if err == nil { @@ -72,7 +74,7 @@ func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group } func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error { - updated, err := r.client.Group.UpdateOneID(groupIn.ID). + builder := r.client.Group.UpdateOneID(groupIn.ID). SetName(groupIn.Name). SetDescription(groupIn.Description). SetPlatform(groupIn.Platform). @@ -87,7 +89,16 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). SetDefaultValidityDays(groupIn.DefaultValidityDays). - Save(ctx) + SetClaudeCodeOnly(groupIn.ClaudeCodeOnly) + + // 处理 FallbackGroupID:nil 时清除,否则设置 + if groupIn.FallbackGroupID != nil { + builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID) + } else { + builder = builder.ClearFallbackGroupID() + } + + updated, err := builder.Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists) } @@ -101,10 +112,10 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { } func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) { - return r.ListWithFilters(ctx, params, "", "", nil) + return r.ListWithFilters(ctx, params, "", "", "", nil) } -func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { q := r.client.Group.Query() if platform != "" { @@ -113,6 +124,12 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination if status != "" { q = q.Where(group.StatusEQ(status)) } + if search != "" { + q = q.Where(group.Or( + group.NameContainsFold(search), + group.DescriptionContainsFold(search), + )) + } if isExclusive != nil { q = q.Where(group.IsExclusiveEQ(*isExclusive)) } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index b9079d7a..660618a6 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -131,6 +131,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", + "", nil, ) s.Require().NoError(err, "ListWithFilters base") @@ -152,7 +153,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Platform() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil) s.Require().NoError(err) s.Require().Len(groups, len(baseGroups)+1) // Verify all groups are OpenAI platform @@ -179,7 +180,7 @@ func (s *GroupRepoSuite) TestListWithFilters_Status() { SubscriptionType: service.SubscriptionTypeStandard, })) - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, nil) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().Equal(service.StatusDisabled, groups[0].Status) @@ -204,12 +205,117 @@ func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() { })) isExclusive := true - groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", &isExclusive) + groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive) s.Require().NoError(err) s.Require().Len(groups, 1) s.Require().True(groups[0].IsExclusive) } +func (s *GroupRepoSuite) TestListWithFilters_Search() { + newRepo := func() (*groupRepository, context.Context) { + tx := testEntTx(s.T()) + return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background() + } + + containsID := func(groups []service.Group, id int64) bool { + for i := range groups { + if groups[i].ID == id { + return true + } + } + return false + } + + mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group { + s.Require().NoError(repo.Create(ctx, g)) + s.Require().NotZero(g.ID) + return g + } + + newGroup := func(name string) *service.Group { + return &service.Group{ + Name: name, + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + } + + s.Run("search_name_should_match", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("it-group-search-name-target")) + other := mustCreate(repo, ctx, newGroup("it-group-search-name-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by name") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_description_should_match", func() { + repo, ctx := newRepo() + + target := newGroup("it-group-search-desc-target") + target.Description = "something about desc-needle in here" + target = mustCreate(repo, ctx, target) + + other := newGroup("it-group-search-desc-other") + other.Description = "nothing to see here" + other = mustCreate(repo, ctx, other) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected target group to match by description") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_nonexistent_should_return_empty", func() { + repo, ctx := newRepo() + + _ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline")) + + search := s.T().Name() + "__no_such_group__" + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil) + s.Require().NoError(err) + s.Require().Empty(groups) + }) + + s.Run("search_should_be_case_insensitive", func() { + repo, ctx := newRepo() + + target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle")) + other := mustCreate(repo, ctx, newGroup("it-group-search-case-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, target.ID), "expected case-insensitive match") + s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out") + }) + + s.Run("search_should_escape_like_wildcards", func() { + repo, ctx := newRepo() + + percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target")) + percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other")) + + groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match") + s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard") + + underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target")) + underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other")) + + groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil) + s.Require().NoError(err) + s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match") + s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard") + }) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", @@ -244,7 +350,7 @@ func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { s.Require().NoError(err) isExclusive := true - groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, &isExclusive) + groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive) s.Require().NoError(err, "ListWithFilters") s.Require().Equal(int64(1), page.Total) s.Require().Len(groups, 1) diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go index c24b2e2c..622b0aeb 100644 --- a/backend/internal/repository/proxy_repo.go +++ b/backend/internal/repository/proxy_repo.go @@ -133,6 +133,55 @@ func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination return outProxies, paginationResultFromTotal(int64(total), params), nil } +// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy +func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) { + q := r.client.Proxy.Query() + if protocol != "" { + q = q.Where(proxy.ProtocolEQ(protocol)) + } + if status != "" { + q = q.Where(proxy.StatusEQ(status)) + } + if search != "" { + q = q.Where(proxy.NameContainsFold(search)) + } + + total, err := q.Count(ctx) + if err != nil { + return nil, nil, err + } + + proxies, err := q. + Offset(params.Offset()). + Limit(params.Limit()). + Order(dbent.Desc(proxy.FieldID)). + All(ctx) + if err != nil { + return nil, nil, err + } + + // Get account counts + counts, err := r.GetAccountCountsForProxies(ctx) + if err != nil { + return nil, nil, err + } + + // Build result with account counts + result := make([]service.ProxyWithAccountCount, 0, len(proxies)) + for i := range proxies { + proxyOut := proxyEntityToService(proxies[i]) + if proxyOut == nil { + continue + } + result = append(result, service.ProxyWithAccountCount{ + Proxy: *proxyOut, + AccountCount: counts[proxyOut.ID], + }) + } + + return result, paginationResultFromTotal(int64(total), params), nil +} + func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) { proxies, err := r.client.Proxy.Query(). Where(proxy.StatusEQ(service.StatusActive)). diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index bd3278c8..20e82be8 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -243,7 +243,8 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, - "created_at": "2025-01-02T03:04:05Z" + "created_at": "2025-01-02T03:04:05Z", + "user_agent": null } ], "total": 1, @@ -303,6 +304,10 @@ func TestAPIContracts(t *testing.T) { "turnstile_enabled": true, "turnstile_site_key": "site-key", "turnstile_secret_key_configured": true, + "linuxdo_connect_enabled": false, + "linuxdo_connect_client_id": "", + "linuxdo_connect_client_secret_configured": false, + "linuxdo_connect_redirect_url": "", "site_name": "Sub2API", "site_logo": "", "site_subtitle": "Subtitle", @@ -389,7 +394,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - authHandler := handler.NewAuthHandler(cfg, nil, userService) + authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil) @@ -582,7 +587,7 @@ func (stubGroupRepo) List(ctx context.Context, params pagination.PaginationParam return nil, nil, errors.New("not implemented") } -func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { +func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 196d8bdb..e61d3939 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -19,6 +19,8 @@ func RegisterAuthRoutes( auth.POST("/register", h.Auth.Register) auth.POST("/login", h.Auth.Login) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + auth.GET("/oauth/linuxdo/start", h.Auth.LinuxDoOAuthStart) + auth.GET("/oauth/linuxdo/callback", h.Auth.LinuxDoOAuthCallback) } // 公开设置(无需认证) diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index de32cfeb..2f138b81 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -68,6 +68,7 @@ type AccountBulkUpdate struct { Concurrency *int Priority *int Status *string + Schedulable *bool Credentials map[string]any Extra map[string]any } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 7121a13d..8419c2b4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -661,13 +661,7 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } if candidates, ok := data["candidates"].([]any); ok && len(candidates) > 0 { if candidate, ok := candidates[0].(map[string]any); ok { - // Check for completion - if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { - s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) - return nil - } - - // Extract content + // Extract content first (before checking completion) if content, ok := candidate["content"].(map[string]any); ok { if parts, ok := content["parts"].([]any); ok { for _, part := range parts { @@ -679,6 +673,12 @@ func (s *AccountTestService) processGeminiStream(c *gin.Context, body io.Reader) } } } + + // Check for completion after extracting content + if finishReason, ok := candidate["finishReason"].(string); ok && finishReason != "" { + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 80acd440..4288381c 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -24,7 +24,7 @@ type AdminService interface { GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) // Group management - ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) + ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) GetAllGroups(ctx context.Context) ([]Group, error) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]Group, error) GetGroup(ctx context.Context, id int64) (*Group, error) @@ -47,6 +47,7 @@ type AdminService interface { // Proxy management ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]Proxy, int64, error) + ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) GetAllProxies(ctx context.Context) ([]Proxy, error) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) GetProxy(ctx context.Context, id int64) (*Proxy, error) @@ -99,9 +100,11 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID } type UpdateGroupInput struct { @@ -116,9 +119,11 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID } type CreateAccountInput struct { @@ -163,6 +168,7 @@ type BulkUpdateAccountsInput struct { Concurrency *int Priority *int Status string + Schedulable *bool GroupIDs *[]int64 Credentials map[string]any Extra map[string]any @@ -473,9 +479,9 @@ func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, } // Group management implementations -func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status string, isExclusive *bool) ([]Group, int64, error) { +func (s *adminServiceImpl) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]Group, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} - groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, isExclusive) + groups, result, err := s.groupRepo.ListWithFilters(ctx, params, platform, status, search, isExclusive) if err != nil { return nil, 0, err } @@ -515,6 +521,13 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + // 校验降级分组 + if input.FallbackGroupID != nil { + if err := s.validateFallbackGroup(ctx, 0, *input.FallbackGroupID); err != nil { + return nil, err + } + } + group := &Group{ Name: input.Name, Description: input.Description, @@ -529,6 +542,8 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, + ClaudeCodeOnly: input.ClaudeCodeOnly, + FallbackGroupID: input.FallbackGroupID, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -552,6 +567,29 @@ func normalizePrice(price *float64) *float64 { return price } +// validateFallbackGroup 校验降级分组的有效性 +// currentGroupID: 当前分组 ID(新建时为 0) +// fallbackGroupID: 降级分组 ID +func (s *adminServiceImpl) validateFallbackGroup(ctx context.Context, currentGroupID, fallbackGroupID int64) error { + // 不能将自己设置为降级分组 + if currentGroupID > 0 && currentGroupID == fallbackGroupID { + return fmt.Errorf("cannot set self as fallback group") + } + + // 检查降级分组是否存在 + fallbackGroup, err := s.groupRepo.GetByID(ctx, fallbackGroupID) + if err != nil { + return fmt.Errorf("fallback group not found: %w", err) + } + + // 降级分组不能启用 claude_code_only,否则会造成死循环 + if fallbackGroup.ClaudeCodeOnly { + return fmt.Errorf("fallback group cannot have claude_code_only enabled") + } + + return nil +} + func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) { group, err := s.groupRepo.GetByID(ctx, id) if err != nil { @@ -602,6 +640,23 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + // Claude Code 客户端限制 + if input.ClaudeCodeOnly != nil { + group.ClaudeCodeOnly = *input.ClaudeCodeOnly + } + if input.FallbackGroupID != nil { + // 校验降级分组 + if *input.FallbackGroupID > 0 { + if err := s.validateFallbackGroup(ctx, id, *input.FallbackGroupID); err != nil { + return nil, err + } + group.FallbackGroupID = input.FallbackGroupID + } else { + // 传入 0 或负数表示清除降级分组 + group.FallbackGroupID = nil + } + } + if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } @@ -856,6 +911,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.Status != "" { repoUpdates.Status = &input.Status } + if input.Schedulable != nil { + repoUpdates.Schedulable = input.Schedulable + } // Run bulk update for column/jsonb fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { @@ -950,6 +1008,15 @@ func (s *adminServiceImpl) ListProxies(ctx context.Context, page, pageSize int, return proxies, result.Total, nil } +func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]ProxyWithAccountCount, int64, error) { + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + proxies, result, err := s.proxyRepo.ListWithFiltersAndAccountCount(ctx, params, protocol, status, search) + if err != nil { + return nil, 0, err + } + return proxies, result.Total, nil +} + func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) { return s.proxyRepo.ListActive(ctx) } diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index 8aeaab43..351f64e8 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -124,7 +124,7 @@ func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationPa panic("unexpected List call") } -func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { panic("unexpected ListWithFilters call") } @@ -186,6 +186,10 @@ func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]Proxy panic("unexpected ListActiveWithAccountCount call") } +func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("unexpected ListWithFiltersAndAccountCount call") +} + func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { panic("unexpected ExistsByHostPortAuth call") } diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 3171de11..26d6eedf 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -16,6 +16,16 @@ type groupRepoStubForAdmin struct { updated *Group // 记录 Update 调用的参数 getByID *Group // GetByID 返回值 getErr error // GetByID 返回的错误 + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersIsExclusive *bool + listWithFiltersGroups []Group + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error } func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error { @@ -47,8 +57,28 @@ func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationP panic("unexpected List call") } -func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) { - panic("unexpected ListWithFilters call") +func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + s.listWithFiltersIsExclusive = isExclusive + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersGroups)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersGroups, result, nil } func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) { @@ -195,3 +225,68 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持 require.Nil(t, repo.updated.ImagePrice4K) } + +func TestAdminService_ListGroups_WithSearch(t *testing.T) { + // 测试: + // 1. search 参数正常传递到 repository 层 + // 2. search 为空字符串时的行为 + // 3. search 与其他过滤条件组合使用 + + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 1}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil) + require.NoError(t, err) + require.Equal(t, int64(1), total) + require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, "alpha", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 为空字符串时传递空字符串", func(t *testing.T) { + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{}, + listWithFiltersResult: &pagination.PaginationResult{Total: 0}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil) + require.NoError(t, err) + require.Empty(t, groups) + require.Equal(t, int64(0), total) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams) + require.Equal(t, "", repo.listWithFiltersSearch) + require.Nil(t, repo.listWithFiltersIsExclusive) + }) + + t.Run("search 与其他过滤条件组合使用", func(t *testing.T) { + isExclusive := true + repo := &groupRepoStubForAdmin{ + listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 42}, + } + svc := &adminServiceImpl{groupRepo: repo} + + groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive) + require.NoError(t, err) + require.Equal(t, int64(42), total) + require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "beta", repo.listWithFiltersSearch) + require.NotNil(t, repo.listWithFiltersIsExclusive) + require.True(t, *repo.listWithFiltersIsExclusive) + }) +} diff --git a/backend/internal/service/admin_service_search_test.go b/backend/internal/service/admin_service_search_test.go new file mode 100644 index 00000000..7506c6db --- /dev/null +++ b/backend/internal/service/admin_service_search_test.go @@ -0,0 +1,238 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type accountRepoStubForAdminList struct { + accountRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersPlatform string + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersAccounts []Account + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string) ([]Account, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersPlatform = platform + s.listWithFiltersType = accountType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAccounts)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAccounts, result, nil +} + +type proxyRepoStubForAdminList struct { + proxyRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersProtocol string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersProxies []Proxy + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error + + listWithFiltersAndAccountCountCalls int + listWithFiltersAndAccountCountParams pagination.PaginationParams + listWithFiltersAndAccountCountProtocol string + listWithFiltersAndAccountCountStatus string + listWithFiltersAndAccountCountSearch string + listWithFiltersAndAccountCountProxies []ProxyWithAccountCount + listWithFiltersAndAccountCountResult *pagination.PaginationResult + listWithFiltersAndAccountCountErr error +} + +func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersProtocol = protocol + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersProxies, result, nil +} + +func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + s.listWithFiltersAndAccountCountCalls++ + s.listWithFiltersAndAccountCountParams = params + s.listWithFiltersAndAccountCountProtocol = protocol + s.listWithFiltersAndAccountCountStatus = status + s.listWithFiltersAndAccountCountSearch = search + + if s.listWithFiltersAndAccountCountErr != nil { + return nil, nil, s.listWithFiltersAndAccountCountErr + } + + result := s.listWithFiltersAndAccountCountResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersAndAccountCountProxies)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersAndAccountCountProxies, result, nil +} + +type redeemRepoStubForAdminList struct { + redeemRepoStub + + listWithFiltersCalls int + listWithFiltersParams pagination.PaginationParams + listWithFiltersType string + listWithFiltersStatus string + listWithFiltersSearch string + listWithFiltersCodes []RedeemCode + listWithFiltersResult *pagination.PaginationResult + listWithFiltersErr error +} + +func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) { + s.listWithFiltersCalls++ + s.listWithFiltersParams = params + s.listWithFiltersType = codeType + s.listWithFiltersStatus = status + s.listWithFiltersSearch = search + + if s.listWithFiltersErr != nil { + return nil, nil, s.listWithFiltersErr + } + + result := s.listWithFiltersResult + if result == nil { + result = &pagination.PaginationResult{ + Total: int64(len(s.listWithFiltersCodes)), + Page: params.Page, + PageSize: params.PageSize, + } + } + + return s.listWithFiltersCodes, result, nil +} + +func TestAdminService_ListAccounts_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &accountRepoStubForAdminList{ + listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 10}, + } + svc := &adminServiceImpl{accountRepo: repo} + + accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc") + require.NoError(t, err) + require.Equal(t, int64(10), total) + require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform) + require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "acc", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxies_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 7}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1") + require.NoError(t, err) + require.Equal(t, int64(7), total) + require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams) + require.Equal(t, "http", repo.listWithFiltersProtocol) + require.Equal(t, StatusActive, repo.listWithFiltersStatus) + require.Equal(t, "p1", repo.listWithFiltersSearch) + }) +} + +func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &proxyRepoStubForAdminList{ + listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, + listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9}, + } + svc := &adminServiceImpl{proxyRepo: repo} + + proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2") + require.NoError(t, err) + require.Equal(t, int64(9), total) + require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies) + + require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls) + require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams) + require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol) + require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus) + require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch) + }) +} + +func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) { + t.Run("search 参数正常传递到 repository 层", func(t *testing.T) { + repo := &redeemRepoStubForAdminList{ + listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}}, + listWithFiltersResult: &pagination.PaginationResult{Total: 3}, + } + svc := &adminServiceImpl{redeemCodeRepo: repo} + + codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC") + require.NoError(t, err) + require.Equal(t, int64(3), total) + require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes) + + require.Equal(t, 1, repo.listWithFiltersCalls) + require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams) + require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType) + require.Equal(t, StatusUnused, repo.listWithFiltersStatus) + require.Equal(t, "ABC", repo.listWithFiltersSearch) + }) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index fe4eb621..4fd55757 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -10,6 +10,7 @@ import ( "io" "log" mathrand "math/rand" + "net" "net/http" "strings" "sync/atomic" @@ -27,6 +28,32 @@ const ( antigravityRetryMaxDelay = 16 * time.Second ) +// isAntigravityConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) +func isAntigravityConnectionError(err error) bool { + if err == nil { + return false + } + + // 检查超时错误 + var netErr net.Error + if errors.As(err, &netErr) && netErr.Timeout() { + return true + } + + // 检查连接错误(DNS 失败、连接拒绝) + var opErr *net.OpError + return errors.As(err, &opErr) +} + +// shouldAntigravityFallbackToNextURL 判断是否应切换到下一个 URL +// 仅连接错误和 HTTP 429 触发 URL 降级 +func shouldAntigravityFallbackToNextURL(err error, statusCode int) bool { + if isAntigravityConnectionError(err) { + return true + } + return statusCode == http.StatusTooManyRequests +} + // getSessionID 从 gin.Context 获取 session_id(用于日志追踪) func getSessionID(c *gin.Context) string { if c == nil { @@ -182,45 +209,70 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account return nil, fmt.Errorf("构建请求失败: %w", err) } - // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) - req, err := antigravity.NewAPIRequest(ctx, "streamGenerateContent", accessToken, requestBody) - if err != nil { - return nil, err - } - - // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) - // 代理 URL proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - // 发送请求 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - return nil, fmt.Errorf("请求失败: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // 读取响应 - respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - if err != nil { - return nil, fmt.Errorf("读取响应失败: %w", err) + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 } - if resp.StatusCode >= 400 { - return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + var lastErr error + for urlIdx, baseURL := range availableURLs { + // 构建 HTTP 请求(总是使用流式 endpoint,与官方客户端一致) + req, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, "streamGenerateContent", accessToken, requestBody) + if err != nil { + lastErr = err + continue + } + + // 调试日志:Test 请求信息 + log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + lastErr = fmt.Errorf("请求失败: %w", err) + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + continue + } + return nil, lastErr + } + + // 读取响应 + respBody, err := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() // 立即关闭,避免循环内 defer 导致的资源泄漏 + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 检查是否需要 URL 降级 + if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + continue + } + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("API 返回 %d: %s", resp.StatusCode, string(respBody)) + } + + // 解析流式响应,提取文本 + text := extractTextFromSSEResponse(respBody) + + return &TestConnectionResult{ + Text: text, + MappedModel: mappedModel, + }, nil } - // 解析流式响应,提取文本 - text := extractTextFromSSEResponse(respBody) - - return &TestConnectionResult{ - Text: text, - MappedModel: mappedModel, - }, nil + return nil, lastErr } // buildGeminiTestRequest 构建 Gemini 格式测试请求 @@ -486,62 +538,86 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, action, accessToken, geminiBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) + } + // 最后一次尝试也失败 + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { _ = resp.Body.Close() }() @@ -1006,61 +1082,85 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" + // URL fallback 循环 + availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() + if len(availableURLs) == 0 { + availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 + } + // 重试循环 var resp *http.Response - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } +urlFallbackLoop: + for urlIdx, baseURL := range availableURLs { + for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { + // 检查 context 是否已取消(客户端断开连接) + select { + case <-ctx.Done(): + log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) + return nil, ctx.Err() + default: + } - upstreamReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } + upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) + if err != nil { + return nil, err + } - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + // 检查是否应触发 URL 降级 + if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) + continue urlFallbackLoop } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() + if attempt < antigravityMaxRetries { + log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue } - continue + log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break - } - break + // 检查是否应触发 URL 降级(仅 429) + if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) + log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) + continue urlFallbackLoop + } + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + + if attempt < antigravityMaxRetries { + log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) + if !sleepAntigravityBackoffWithContext(ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", prefix) + return nil, ctx.Err() + } + continue + } + // 所有重试都失败,标记限流状态 + if resp.StatusCode == 429 { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) + } + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop + } + + break urlFallbackLoop + } } defer func() { if resp != nil && resp.Body != nil { diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 85772e75..e232deb3 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -2,9 +2,13 @@ package service import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log" + "net/mail" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,6 +22,7 @@ var ( ErrInvalidCredentials = infraerrors.Unauthorized("INVALID_CREDENTIALS", "invalid email or password") ErrUserNotActive = infraerrors.Forbidden("USER_NOT_ACTIVE", "user is not active") ErrEmailExists = infraerrors.Conflict("EMAIL_EXISTS", "email already exists") + ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") @@ -75,21 +80,30 @@ func (s *AuthService) Register(ctx context.Context, email, password string) (str // RegisterWithVerification 用户注册(支持邮件验证),返回token和用户 func (s *AuthService) RegisterWithVerification(ctx context.Context, email, password, verifyCode string) (string, *User, error) { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭:settingService 未配置时不允许注册) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return "", nil, ErrRegDisabled } + // 防止用户注册 LinuxDo OAuth 合成邮箱,避免第三方登录与本地账号发生碰撞。 + if isReservedEmail(email) { + return "", nil, ErrEmailReserved + } + // 检查是否需要邮件验证 if s.settingService != nil && s.settingService.IsEmailVerifyEnabled(ctx) { + // 如果邮件验证已开启但邮件服务未配置,拒绝注册 + // 这是一个配置错误,不应该允许绕过验证 + if s.emailService == nil { + log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + return "", nil, ErrServiceUnavailable + } if verifyCode == "" { return "", nil, ErrEmailVerifyRequired } // 验证邮箱验证码 - if s.emailService != nil { - if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { - return "", nil, fmt.Errorf("verify code: %w", err) - } + if err := s.emailService.VerifyCode(ctx, email, verifyCode); err != nil { + return "", nil, fmt.Errorf("verify code: %w", err) } } @@ -128,6 +142,10 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw } if err := s.userRepo.Create(ctx, user); err != nil { + // 优先检查邮箱冲突错误(竞态条件下可能发生) + if errors.Is(err, ErrEmailExists) { + return "", nil, ErrEmailExists + } log.Printf("[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } @@ -148,11 +166,15 @@ type SendVerifyCodeResult struct { // SendVerifyCode 发送邮箱验证码(同步方式) func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { return ErrRegDisabled } + if isReservedEmail(email) { + return ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -181,12 +203,16 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) - // 检查是否开放注册 - if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + // 检查是否开放注册(默认关闭) + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { log.Println("[Auth] Registration is disabled") return nil, ErrRegDisabled } + if isReservedEmail(email) { + return nil, ErrEmailReserved + } + // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { @@ -266,7 +292,7 @@ func (s *AuthService) IsTurnstileEnabled(ctx context.Context) bool { // IsRegistrationEnabled 检查是否开放注册 func (s *AuthService) IsRegistrationEnabled(ctx context.Context) bool { if s.settingService == nil { - return true + return false // 安全默认:settingService 未配置时关闭注册 } return s.settingService.IsRegistrationEnabled(ctx) } @@ -311,6 +337,102 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return token, user, nil } +// LoginOrRegisterOAuth 用于第三方 OAuth/SSO 登录: +// - 如果邮箱已存在:直接登录(不需要本地密码) +// - 如果邮箱不存在:创建新用户并登录 +// +// 注意:该函数用于“终端用户登录 Sub2API 本身”的场景(不同于上游账号的 OAuth,例如 OpenAI/Gemini)。 +// 为了满足现有数据库约束(需要密码哈希),新用户会生成随机密码并进行哈希保存。 +func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username string) (string, *User, error) { + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return "", nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册。 + if s.settingService != nil && !s.settingService.IsRegistrationEnabled(ctx) { + return "", nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return "", nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return "", nil, fmt.Errorf("hash password: %w", err) + } + + // 新用户默认值。 + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + // 并发场景:GetByEmail 与 Create 之间用户被创建。 + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return "", nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return "", nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return "", nil, ErrUserNotActive + } + + // 尽力补全:当用户名为空时,使用第三方返回的用户名回填。 + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + token, err := s.GenerateToken(user) + if err != nil { + return "", nil, fmt.Errorf("generate token: %w", err) + } + return token, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -336,6 +458,11 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { if err != nil { if errors.Is(err, jwt.ErrTokenExpired) { + // token 过期但仍返回 claims(用于 RefreshToken 等场景) + // jwt-go 在解析时即使遇到过期错误,token.Claims 仍会被填充 + if claims, ok := token.Claims.(*JWTClaims); ok { + return claims, ErrTokenExpired + } return nil, ErrTokenExpired } return nil, ErrInvalidToken @@ -348,6 +475,22 @@ func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { return nil, ErrInvalidToken } +func randomHexString(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 16 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return hex.EncodeToString(buf), nil +} + +func isReservedEmail(email string) bool { + normalized := strings.ToLower(strings.TrimSpace(email)) + return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) +} + // GenerateToken 生成JWT token func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index cd6e2808..ab1f20a0 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -113,13 +113,36 @@ func TestAuthService_Register_Disabled(t *testing.T) { require.ErrorIs(t, err, ErrRegDisabled) } -func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { +func TestAuthService_Register_DisabledByDefault(t *testing.T) { + // 当 settings 为 nil(设置项不存在)时,注册应该默认关闭 repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrRegDisabled) +} + +func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) { + repo := &userRepoStub{} + // 邮件验证开启但 emailCache 为 nil(emailService 未配置) service := newAuthService(repo, map[string]string{ SettingKeyRegistrationEnabled: "true", SettingKeyEmailVerifyEnabled: "true", }, nil) + // 应返回服务不可用错误,而不是允许绕过验证 + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code") + require.ErrorIs(t, err, ErrServiceUnavailable) +} + +func TestAuthService_Register_EmailVerifyRequired(t *testing.T) { + repo := &userRepoStub{} + cache := &emailCacheStub{} // 配置 emailService + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + SettingKeyEmailVerifyEnabled: "true", + }, cache) + _, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "") require.ErrorIs(t, err, ErrEmailVerifyRequired) } @@ -141,7 +164,9 @@ func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) { func TestAuthService_Register_EmailExists(t *testing.T) { repo := &userRepoStub{exists: true} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrEmailExists) @@ -149,23 +174,50 @@ func TestAuthService_Register_EmailExists(t *testing.T) { func TestAuthService_Register_CheckEmailError(t *testing.T) { repo := &userRepoStub{existsErr: errors.New("db down")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_ReservedEmail(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password") + require.ErrorIs(t, err, ErrEmailReserved) +} + func TestAuthService_Register_CreateError(t *testing.T) { repo := &userRepoStub{createErr: errors.New("create failed")} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) _, _, err := service.Register(context.Background(), "user@test.com", "password") require.ErrorIs(t, err, ErrServiceUnavailable) } +func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) { + // 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败 + repo := &userRepoStub{createErr: ErrEmailExists} + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) + + _, _, err := service.Register(context.Background(), "user@test.com", "password") + require.ErrorIs(t, err, ErrEmailExists) +} + func TestAuthService_Register_Success(t *testing.T) { repo := &userRepoStub{nextID: 5} - service := newAuthService(repo, nil, nil) + service := newAuthService(repo, map[string]string{ + SettingKeyRegistrationEnabled: "true", + }, nil) token, user, err := service.Register(context.Background(), "user@test.com", "password") require.NoError(t, err) @@ -180,3 +232,63 @@ func TestAuthService_Register_Success(t *testing.T) { require.Len(t, repo.created, 1) require.True(t, user.CheckPassword("password")) } + +func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) { + repo := &userRepoStub{} + service := newAuthService(repo, nil, nil) + + // 创建用户并生成 token + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + token, err := service.GenerateToken(user) + require.NoError(t, err) + + // 验证有效 token + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.Equal(t, int64(1), claims.UserID) + + // 模拟过期 token(通过创建一个过期很久的 token) + service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 // 恢复 + + // 验证过期 token 应返回 claims 和 ErrTokenExpired + claims, err = service.ValidateToken(expiredToken) + require.ErrorIs(t, err, ErrTokenExpired) + require.NotNil(t, claims, "claims should not be nil when token is expired") + require.Equal(t, int64(1), claims.UserID) + require.Equal(t, "test@test.com", claims.Email) +} + +func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { + user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + repo := &userRepoStub{user: user} + service := newAuthService(repo, nil, nil) + + // 创建过期 token + service.cfg.JWT.ExpireHour = -1 + expiredToken, err := service.GenerateToken(user) + require.NoError(t, err) + service.cfg.JWT.ExpireHour = 1 + + // RefreshToken 使用过期 token 不应 panic + require.NotPanics(t, func() { + newToken, err := service.RefreshToken(context.Background(), expiredToken) + require.NoError(t, err) + require.NotEmpty(t, newToken) + }) +} diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go new file mode 100644 index 00000000..ab86f1e8 --- /dev/null +++ b/backend/internal/service/claude_code_validator.go @@ -0,0 +1,265 @@ +package service + +import ( + "context" + "net/http" + "regexp" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端 +// 完全学习自 claude-relay-service 项目的验证逻辑 +type ClaudeCodeValidator struct{} + +var ( + // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) + claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) + + // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} + userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) + + // System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致) + systemPromptThreshold = 0.5 +) + +// Claude Code 官方 System Prompt 模板 +// 从 claude-relay-service/src/utils/contents.js 提取 +var claudeCodeSystemPrompts = []string{ + // claudeOtherSystemPrompt1 - Primary + "You are Claude Code, Anthropic's official CLI for Claude.", + + // claudeOtherSystemPrompt3 - Agent SDK + "You are a Claude agent, built on Anthropic's Claude Agent SDK.", + + // claudeOtherSystemPrompt4 - Compact Agent SDK + "You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.", + + // exploreAgentSystemPrompt + "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", + + // claudeOtherSystemPromptCompact - Compact (用于对话摘要) + "You are a helpful AI assistant tasked with summarizing conversations.", + + // claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分) + "You are an interactive CLI tool that helps users", +} + +// NewClaudeCodeValidator 创建验证器实例 +func NewClaudeCodeValidator() *ClaudeCodeValidator { + return &ClaudeCodeValidator{} +} + +// Validate 验证请求是否来自 Claude Code CLI +// 采用与 claude-relay-service 完全一致的验证策略: +// +// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x +// Step 2: 对于非 messages 路径,只要 UA 匹配就通过 +// Step 3: 对于 messages 路径,进行严格验证: +// - System prompt 相似度检查 +// - X-App header 检查 +// - anthropic-beta header 检查 +// - anthropic-version header 检查 +// - metadata.user_id 格式验证 +func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool { + // Step 1: User-Agent 检查 + ua := r.Header.Get("User-Agent") + if !claudeCodeUAPattern.MatchString(ua) { + return false + } + + // Step 2: 非 messages 路径,只要 UA 匹配就通过 + path := r.URL.Path + if !strings.Contains(path, "messages") { + return true + } + + // Step 3: messages 路径,进行严格验证 + + // 3.1 检查 system prompt 相似度 + if !v.hasClaudeCodeSystemPrompt(body) { + return false + } + + // 3.2 检查必需的 headers(值不为空即可) + xApp := r.Header.Get("X-App") + if xApp == "" { + return false + } + + anthropicBeta := r.Header.Get("anthropic-beta") + if anthropicBeta == "" { + return false + } + + anthropicVersion := r.Header.Get("anthropic-version") + if anthropicVersion == "" { + return false + } + + // 3.3 验证 metadata.user_id + if body == nil { + return false + } + + metadata, ok := body["metadata"].(map[string]any) + if !ok { + return false + } + + userID, ok := metadata["user_id"].(string) + if !ok || userID == "" { + return false + } + + if !userIDPattern.MatchString(userID) { + return false + } + + return true +} + +// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词 +// 使用字符串相似度匹配(Dice coefficient) +func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool { + if body == nil { + return false + } + + // 检查 model 字段 + if _, ok := body["model"].(string); !ok { + return false + } + + // 获取 system 字段 + systemEntries, ok := body["system"].([]any) + if !ok { + return false + } + + // 检查每个 system entry + for _, entry := range systemEntries { + entryMap, ok := entry.(map[string]any) + if !ok { + continue + } + + text, ok := entryMap["text"].(string) + if !ok || text == "" { + continue + } + + // 计算与所有模板的最佳相似度 + bestScore := v.bestSimilarityScore(text) + if bestScore >= systemPromptThreshold { + return true + } + } + + return false +} + +// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度 +func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 { + normalizedText := normalizePrompt(text) + bestScore := 0.0 + + for _, template := range claudeCodeSystemPrompts { + normalizedTemplate := normalizePrompt(template) + score := diceCoefficient(normalizedText, normalizedTemplate) + if score > bestScore { + bestScore = score + } + } + + return bestScore +} + +// normalizePrompt 标准化提示词文本(去除多余空白) +func normalizePrompt(text string) string { + // 将所有空白字符替换为单个空格,并去除首尾空白 + return strings.Join(strings.Fields(text), " ") +} + +// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient) +// 这是 string-similarity 库使用的算法 +// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|) +func diceCoefficient(a, b string) float64 { + if a == b { + return 1.0 + } + + if len(a) < 2 || len(b) < 2 { + return 0.0 + } + + // 生成 bigrams + bigramsA := getBigrams(a) + bigramsB := getBigrams(b) + + if len(bigramsA) == 0 || len(bigramsB) == 0 { + return 0.0 + } + + // 计算交集大小 + intersection := 0 + for bigram, countA := range bigramsA { + if countB, exists := bigramsB[bigram]; exists { + if countA < countB { + intersection += countA + } else { + intersection += countB + } + } + } + + // 计算总 bigram 数量 + totalA := 0 + for _, count := range bigramsA { + totalA += count + } + totalB := 0 + for _, count := range bigramsB { + totalB += count + } + + return float64(2*intersection) / float64(totalA+totalB) +} + +// getBigrams 获取字符串的所有 bigrams(相邻字符对) +func getBigrams(s string) map[string]int { + bigrams := make(map[string]int) + runes := []rune(strings.ToLower(s)) + + for i := 0; i < len(runes)-1; i++ { + bigram := string(runes[i : i+2]) + bigrams[bigram]++ + } + + return bigrams +} + +// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景) +func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool { + return claudeCodeUAPattern.MatchString(ua) +} + +// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词 +// 只要存在匹配的系统提示词就返回 true(用于宽松检测) +func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool { + return v.hasClaudeCodeSystemPrompt(body) +} + +// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识 +func IsClaudeCodeClient(ctx context.Context) bool { + if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok { + return v + } + return false +} + +// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中 +func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context { + return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 9c61ea2e..df34e167 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -105,7 +105,17 @@ const ( // Request identity patch (Claude -> Gemini systemInstruction injection) SettingKeyEnableIdentityPatch = "enable_identity_patch" SettingKeyIdentityPatchPrompt = "identity_patch_prompt" + + // LinuxDo Connect OAuth 登录(终端用户 SSO) + SettingKeyLinuxDoConnectEnabled = "linuxdo_connect_enabled" + SettingKeyLinuxDoConnectClientID = "linuxdo_connect_client_id" + SettingKeyLinuxDoConnectClientSecret = "linuxdo_connect_client_secret" + SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" ) +// LinuxDoConnectSyntheticEmailDomain 是 LinuxDo Connect 用户的合成邮箱后缀(RFC 保留域名)。 +// 目的:避免第三方登录返回的用户标识与本地真实邮箱发生碰撞,进而造成账号被接管的风险。 +const LinuxDoConnectSyntheticEmailDomain = "@linuxdo-connect.invalid" + // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). const AdminAPIKeyPrefix = "admin-" diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index afd8907c..55e137d6 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -5,6 +5,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "log" "math/big" "net/smtp" "strconv" @@ -256,7 +257,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证码不匹配 if data.Code != code { data.Attempts++ - _ = s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL) + if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { + log.Printf("[Email] Failed to update verification attempt count: %v", err) + } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts } @@ -264,7 +267,9 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error } // 验证成功,删除验证码 - _ = s.cache.DeleteVerificationCode(ctx, email) + if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { + log.Printf("[Email] Failed to delete verification code after success: %v", err) + } return nil } diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 8f29e07c..da7c311c 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -172,14 +172,14 @@ type mockGatewayCacheForPlatform struct { sessionBindings map[string]int64 } -func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { +func (m *mockGatewayCacheForPlatform) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { if id, ok := m.sessionBindings[sessionHash]; ok { return id, nil } return 0, errors.New("not found") } -func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { +func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { if m.sessionBindings == nil { m.sessionBindings = make(map[string]int64) } @@ -187,7 +187,7 @@ func (m *mockGatewayCacheForPlatform) SetSessionAccountID(ctx context.Context, s return nil } -func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { +func (m *mockGatewayCacheForPlatform) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 6da9b565..5871fddb 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -56,6 +56,9 @@ var ( } ) +// ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 +var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") + // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -80,9 +83,17 @@ var allowedHeaders = map[string]bool{ // GatewayCache defines cache operations for gateway service type GatewayCache interface { - GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) - SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error - RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error + GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) + SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error + RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error +} + +// derefGroupID safely dereferences *int64 to int64, returning 0 if nil +func derefGroupID(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID } type AccountWaitPlan struct { @@ -225,11 +236,11 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { } // BindStickySession sets session -> account binding with standard TTL. -func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { +func (s *GatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 || s.cache == nil { return nil } - return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, accountID, stickySessionTTL) } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { @@ -356,6 +367,21 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context return nil, fmt.Errorf("get group failed: %w", err) } platform = group.Platform + + // 检查 Claude Code 客户端限制 + if group.ClaudeCodeOnly { + isClaudeCode := IsClaudeCodeClient(ctx) + if !isClaudeCode { + // 非 Claude Code 客户端,检查是否有降级分组 + if group.FallbackGroupID != nil { + // 使用降级分组重新调度 + fallbackGroupID := *group.FallbackGroupID + return s.SelectAccountForModelWithExclusions(ctx, &fallbackGroupID, sessionHash, requestedModel, excludedIDs) + } + // 无降级分组,拒绝访问 + return nil, ErrClaudeCodeOnly + } + } } else { // 无分组时只使用原生 anthropic 平台 platform = PlatformAnthropic @@ -377,10 +403,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { stickyAccountID = accountID } } + + // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) + groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) + if err != nil { + return nil, err + } + if s.concurrencyService == nil || !cfg.LoadBatchEnabled { account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) if err != nil { @@ -443,7 +476,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // ============ Layer 1: 粘性会话优先 ============ if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && s.isAccountInGroup(account, groupID) && @@ -452,7 +485,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -509,7 +542,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil } } else { @@ -559,7 +592,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -587,7 +620,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -595,7 +628,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -622,6 +655,42 @@ func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig { } } +// checkClaudeCodeRestriction 检查分组的 Claude Code 客户端限制 +// 如果分组启用了 claude_code_only 且请求不是来自 Claude Code 客户端: +// - 有降级分组:返回降级分组的 ID +// - 无降级分组:返回 ErrClaudeCodeOnly 错误 +func (s *GatewayService) checkClaudeCodeRestriction(ctx context.Context, groupID *int64) (*int64, error) { + if groupID == nil { + return groupID, nil + } + + // 强制平台模式不检查 Claude Code 限制 + if _, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string); hasForcePlatform { + return groupID, nil + } + + group, err := s.groupRepo.GetByID(ctx, *groupID) + if err != nil { + return nil, fmt.Errorf("get group failed: %w", err) + } + + if !group.ClaudeCodeOnly { + return groupID, nil + } + + // 分组启用了 Claude Code 限制 + if IsClaudeCodeClient(ctx) { + return groupID, nil + } + + // 非 Claude Code 客户端,检查降级分组 + if group.FallbackGroupID != nil { + return group.FallbackGroupID, nil + } + + return nil, ErrClaudeCodeOnly +} + func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) { forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string) if hasForcePlatform && forcePlatform != "" { @@ -741,13 +810,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, preferOAuth := platform == PlatformGemini // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台) if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil @@ -817,7 +886,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -833,14 +902,14 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 1. 查询粘性会话 if sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度 if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL); err != nil { + if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) } return account, nil @@ -912,7 +981,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { - if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil { + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 13f644c8..2b500072 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -109,7 +109,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co cacheKey := "gemini:" + sessionHash if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, cacheKey) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) @@ -133,7 +133,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } } if usable { - _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, geminiStickySessionTTL) return account, nil } } @@ -220,7 +220,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co } if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, cacheKey, selected.ID, geminiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, geminiStickySessionTTL) } return selected, nil diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 794e56a7..d9df5f4c 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -172,7 +172,7 @@ func (m *mockGroupRepoForGemini) DeleteCascade(ctx context.Context, id int64) ([ func (m *mockGroupRepoForGemini) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } -func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { +func (m *mockGroupRepoForGemini) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) { return nil, nil, nil } func (m *mockGroupRepoForGemini) ListActive(ctx context.Context) ([]Group, error) { return nil, nil } @@ -196,14 +196,14 @@ type mockGatewayCacheForGemini struct { sessionBindings map[string]int64 } -func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, sessionHash string) (int64, error) { +func (m *mockGatewayCacheForGemini) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { if id, ok := m.sessionBindings[sessionHash]; ok { return id, nil } return 0, errors.New("not found") } -func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, sessionHash string, accountID int64, ttl time.Duration) error { +func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { if m.sessionBindings == nil { m.sessionBindings = make(map[string]int64) } @@ -211,7 +211,7 @@ func (m *mockGatewayCacheForGemini) SetSessionAccountID(ctx context.Context, ses return nil } -func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error { +func (m *mockGatewayCacheForGemini) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { return nil } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 48d31da9..bc84baeb 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -120,15 +120,16 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 } // OAuth client selection: - // - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret. - // - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client. - // - ai_studio: requires a user-provided OAuth client. + // - code_assist: always use built-in Gemini CLI OAuth client (public) + // - google_one: always use built-in Gemini CLI OAuth client (public) + // - ai_studio: requires a user-provided OAuth client oauthCfg := geminicli.OAuthConfig{ ClientID: s.cfg.Gemini.OAuth.ClientID, ClientSecret: s.cfg.Gemini.OAuth.ClientSecret, Scopes: s.cfg.Gemini.OAuth.Scopes, } - if oauthType == "code_assist" { + if oauthType == "code_assist" || oauthType == "google_one" { + // Force use of built-in Gemini CLI OAuth client oauthCfg.ClientID = "" oauthCfg.ClientSecret = "" } @@ -576,6 +577,20 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch case "google_one": log.Printf("[GeminiOAuth] Processing google_one OAuth type") + + // Google One accounts use cloudaicompanion API, which requires a project_id. + // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. + if projectID == "" { + log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + var err error + projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) + if err != nil { + log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) + } + log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + } + log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index eb3d86e6..5591eb39 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -40,7 +40,7 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { wantProjectID: "", }, { - name: "google_one uses custom client when configured and redirects to localhost", + name: "google_one always forces built-in client even when custom client configured", cfg: &config.Config{ Gemini: config.GeminiConfig{ OAuth: config.GeminiOAuthConfig{ @@ -50,9 +50,9 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }, }, oauthType: "google_one", - wantClientID: "custom-client-id", - wantRedirect: geminicli.AIStudioOAuthRedirectURI, - wantScope: geminicli.DefaultGoogleOneScopes, + wantClientID: geminicli.GeminiCLIOAuthClientID, + wantRedirect: geminicli.GeminiCLIRedirectURI, + wantScope: geminicli.DefaultCodeAssistScopes, wantProjectID: "", }, { diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 01b6b513..80d89074 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -22,6 +22,10 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Claude Code 客户端限制 + ClaudeCodeOnly bool + FallbackGroupID *int64 + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index 403636e8..a444556f 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -21,7 +21,7 @@ type GroupRepository interface { DeleteCascade(ctx context.Context, id int64) ([]int64, error) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) - ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) + ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Group, error) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index d744bfab..42e98585 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -134,11 +134,11 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { } // BindStickySession sets session -> account binding with standard TTL. -func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { +func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, accountID, openaiStickySessionTTL) + return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) } // SelectAccount selects an OpenAI account with sticky session support @@ -155,13 +155,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { // 1. Check sticky session if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 { if _, excluded := excludedIDs[accountID]; !excluded { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { // Refresh sticky session TTL - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) return account, nil } } @@ -227,7 +227,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. Set sticky session if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -238,7 +238,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash); err == nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { stickyAccountID = accountID } } @@ -298,14 +298,14 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, "openai:"+sessionHash) + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID) if err == nil && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -362,7 +362,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -415,7 +415,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -540,10 +540,19 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco bodyModified = true } - // For OAuth accounts using ChatGPT internal API, add store: false + // For OAuth accounts using ChatGPT internal API: + // 1. Add store: false + // 2. Normalize input format for Codex API compatibility if account.Type == AccountTypeOAuth { reqBody["store"] = false bodyModified = true + + // Normalize input format: convert AI SDK multi-part content format to simplified format + // AI SDK sends: {"content": [{"type": "input_text", "text": "..."}]} + // Codex API expects: {"content": "..."} + if normalizeInputForCodexAPI(reqBody) { + bodyModified = true + } } // Re-serialize body only if modified @@ -1085,6 +1094,101 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel return newBody } +// normalizeInputForCodexAPI converts AI SDK multi-part content format to simplified format +// that the ChatGPT internal Codex API expects. +// +// AI SDK sends content as an array of typed objects: +// +// {"content": [{"type": "input_text", "text": "hello"}]} +// +// ChatGPT Codex API expects content as a simple string: +// +// {"content": "hello"} +// +// This function modifies reqBody in-place and returns true if any modification was made. +func normalizeInputForCodexAPI(reqBody map[string]any) bool { + input, ok := reqBody["input"] + if !ok { + return false + } + + // Handle case where input is a simple string (already compatible) + if _, isString := input.(string); isString { + return false + } + + // Handle case where input is an array of messages + inputArray, ok := input.([]any) + if !ok { + return false + } + + modified := false + for _, item := range inputArray { + message, ok := item.(map[string]any) + if !ok { + continue + } + + content, ok := message["content"] + if !ok { + continue + } + + // If content is already a string, no conversion needed + if _, isString := content.(string); isString { + continue + } + + // If content is an array (AI SDK format), convert to string + contentArray, ok := content.([]any) + if !ok { + continue + } + + // Extract text from content array + var textParts []string + for _, part := range contentArray { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + + // Handle different content types + partType, _ := partMap["type"].(string) + switch partType { + case "input_text", "text": + // Extract text from input_text or text type + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + case "input_image", "image": + // For images, we need to preserve the original format + // as ChatGPT Codex API may support images in a different way + // For now, skip image parts (they will be lost in conversion) + // TODO: Consider preserving image data or handling it separately + continue + case "input_file", "file": + // Similar to images, file inputs may need special handling + continue + default: + // For unknown types, try to extract text if available + if text, ok := partMap["text"].(string); ok { + textParts = append(textParts, text) + } + } + } + + // Convert content array to string + if len(textParts) > 0 { + message["content"] = strings.Join(textParts, "\n") + modified = true + } + } + + return modified +} + // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { Result *OpenAIForwardResult diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go index 044f9ffc..58408d04 100644 --- a/backend/internal/service/proxy_service.go +++ b/backend/internal/service/proxy_service.go @@ -20,6 +20,7 @@ type ProxyRepository interface { List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) + ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) ListActive(ctx context.Context) ([]Proxy, error) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 6ce8ba2b..d25698de 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -64,6 +65,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyAPIBaseURL, SettingKeyContactInfo, SettingKeyDocURL, + SettingKeyLinuxDoConnectEnabled, } settings, err := s.settingRepo.GetMultiple(ctx, keys) @@ -71,6 +73,13 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings return nil, fmt.Errorf("get public settings: %w", err) } + linuxDoEnabled := false + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + linuxDoEnabled = raw == "true" + } else { + linuxDoEnabled = s.cfg != nil && s.cfg.LinuxDo.Enabled + } + return &PublicSettings{ RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true", EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true", @@ -82,6 +91,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings APIBaseURL: settings[SettingKeyAPIBaseURL], ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], + LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -111,6 +121,14 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyTurnstileSecretKey] = settings.TurnstileSecretKey } + // LinuxDo Connect OAuth 登录(终端用户 SSO) + updates[SettingKeyLinuxDoConnectEnabled] = strconv.FormatBool(settings.LinuxDoConnectEnabled) + updates[SettingKeyLinuxDoConnectClientID] = settings.LinuxDoConnectClientID + updates[SettingKeyLinuxDoConnectRedirectURL] = settings.LinuxDoConnectRedirectURL + if settings.LinuxDoConnectClientSecret != "" { + updates[SettingKeyLinuxDoConnectClientSecret] = settings.LinuxDoConnectClientSecret + } + // OEM设置 updates[SettingKeySiteName] = settings.SiteName updates[SettingKeySiteLogo] = settings.SiteLogo @@ -141,8 +159,8 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) if err != nil { - // 默认开放注册 - return true + // 安全默认:如果设置不存在或查询出错,默认关闭注册 + return false } return value == "true" } @@ -271,6 +289,38 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.SMTPPassword = settings[SettingKeySMTPPassword] result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey] + // LinuxDo Connect 设置: + // - 兼容 config.yaml/env(避免老部署因为未迁移到数据库设置而被意外关闭) + // - 支持在后台“系统设置”中覆盖并持久化(存储于 DB) + linuxDoBase := config.LinuxDoConnectConfig{} + if s.cfg != nil { + linuxDoBase = s.cfg.LinuxDo + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + result.LinuxDoConnectEnabled = raw == "true" + } else { + result.LinuxDoConnectEnabled = linuxDoBase.Enabled + } + + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectClientID = strings.TrimSpace(v) + } else { + result.LinuxDoConnectClientID = linuxDoBase.ClientID + } + + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + result.LinuxDoConnectRedirectURL = strings.TrimSpace(v) + } else { + result.LinuxDoConnectRedirectURL = linuxDoBase.RedirectURL + } + + result.LinuxDoConnectClientSecret = strings.TrimSpace(settings[SettingKeyLinuxDoConnectClientSecret]) + if result.LinuxDoConnectClientSecret == "" { + result.LinuxDoConnectClientSecret = strings.TrimSpace(linuxDoBase.ClientSecret) + } + result.LinuxDoConnectClientSecretConfigured = result.LinuxDoConnectClientSecret != "" + // Model fallback settings result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true" result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022") @@ -289,6 +339,99 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin return result } +// GetLinuxDoConnectOAuthConfig 返回用于登录的“最终生效” LinuxDo Connect 配置。 +// +// 优先级: +// - 若对应系统设置键存在,则覆盖 config.yaml/env 的值 +// - 否则回退到 config.yaml/env 的值 +func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (config.LinuxDoConnectConfig, error) { + if s == nil || s.cfg == nil { + return config.LinuxDoConnectConfig{}, infraerrors.ServiceUnavailable("CONFIG_NOT_READY", "config not loaded") + } + + effective := s.cfg.LinuxDo + + keys := []string{ + SettingKeyLinuxDoConnectEnabled, + SettingKeyLinuxDoConnectClientID, + SettingKeyLinuxDoConnectClientSecret, + SettingKeyLinuxDoConnectRedirectURL, + } + settings, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return config.LinuxDoConnectConfig{}, fmt.Errorf("get linuxdo connect settings: %w", err) + } + + if raw, ok := settings[SettingKeyLinuxDoConnectEnabled]; ok { + effective.Enabled = raw == "true" + } + if v, ok := settings[SettingKeyLinuxDoConnectClientID]; ok && strings.TrimSpace(v) != "" { + effective.ClientID = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectClientSecret]; ok && strings.TrimSpace(v) != "" { + effective.ClientSecret = strings.TrimSpace(v) + } + if v, ok := settings[SettingKeyLinuxDoConnectRedirectURL]; ok && strings.TrimSpace(v) != "" { + effective.RedirectURL = strings.TrimSpace(v) + } + + if !effective.Enabled { + return config.LinuxDoConnectConfig{}, infraerrors.NotFound("OAUTH_DISABLED", "oauth login is disabled") + } + + // 基础健壮性校验(避免把用户重定向到一个必然失败或不安全的 OAuth 流程里)。 + if strings.TrimSpace(effective.ClientID) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client id not configured") + } + if strings.TrimSpace(effective.AuthorizeURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url not configured") + } + if strings.TrimSpace(effective.TokenURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url not configured") + } + if strings.TrimSpace(effective.UserInfoURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url not configured") + } + if strings.TrimSpace(effective.RedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url not configured") + } + if strings.TrimSpace(effective.FrontendRedirectURL) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url not configured") + } + + if err := config.ValidateAbsoluteHTTPURL(effective.AuthorizeURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth authorize url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.TokenURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.UserInfoURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth userinfo url invalid") + } + if err := config.ValidateAbsoluteHTTPURL(effective.RedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth redirect url invalid") + } + if err := config.ValidateFrontendRedirectURL(effective.FrontendRedirectURL); err != nil { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth frontend redirect url invalid") + } + + method := strings.ToLower(strings.TrimSpace(effective.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic": + if strings.TrimSpace(effective.ClientSecret) == "" { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth client secret not configured") + } + case "none": + if !effective.UsePKCE { + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth pkce must be enabled when token_auth_method=none") + } + default: + return config.LinuxDoConnectConfig{}, infraerrors.InternalServer("OAUTH_CONFIG_INVALID", "oauth token_auth_method invalid") + } + + return effective, nil +} + // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index de0331f7..26051418 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -18,6 +18,13 @@ type SystemSettings struct { TurnstileSecretKey string TurnstileSecretKeyConfigured bool + // LinuxDo Connect OAuth 登录(终端用户 SSO) + LinuxDoConnectEnabled bool + LinuxDoConnectClientID string + LinuxDoConnectClientSecret string + LinuxDoConnectClientSecretConfigured bool + LinuxDoConnectRedirectURL string + SiteName string SiteLogo string SiteSubtitle string @@ -51,5 +58,6 @@ type PublicSettings struct { APIBaseURL string ContactInfo string DocURL string + LinuxDoOAuthEnabled bool Version string } diff --git a/backend/migrations/029_add_group_claude_code_restriction.sql b/backend/migrations/029_add_group_claude_code_restriction.sql new file mode 100644 index 00000000..6185704d --- /dev/null +++ b/backend/migrations/029_add_group_claude_code_restriction.sql @@ -0,0 +1,21 @@ +-- 029_add_group_claude_code_restriction.sql +-- 添加分组级别的 Claude Code 客户端限制功能 + +-- 添加 claude_code_only 字段:是否仅允许 Claude Code 客户端 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS claude_code_only BOOLEAN NOT NULL DEFAULT FALSE; + +-- 添加 fallback_group_id 字段:非 Claude Code 请求降级到的分组 +ALTER TABLE groups +ADD COLUMN IF NOT EXISTS fallback_group_id BIGINT REFERENCES groups(id) ON DELETE SET NULL; + +-- 添加索引优化查询 +CREATE INDEX IF NOT EXISTS idx_groups_claude_code_only +ON groups(claude_code_only) WHERE deleted_at IS NULL; + +CREATE INDEX IF NOT EXISTS idx_groups_fallback_group_id +ON groups(fallback_group_id) WHERE deleted_at IS NULL AND fallback_group_id IS NOT NULL; + +-- 添加字段注释 +COMMENT ON COLUMN groups.claude_code_only IS '是否仅允许 Claude Code 客户端访问此分组'; +COMMENT ON COLUMN groups.fallback_group_id IS '非 Claude Code 请求降级使用的分组 ID'; diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 60d79377..87ff3148 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -234,6 +234,31 @@ jwt: # 令牌过期时间(小时,最大 24) expire_hour: 24 +# ============================================================================= +# LinuxDo Connect OAuth Login (SSO) +# LinuxDo Connect OAuth 登录(用于 Sub2API 用户登录) +# ============================================================================= +linuxdo_connect: + enabled: false + client_id: "" + client_secret: "" + authorize_url: "https://connect.linux.do/oauth2/authorize" + token_url: "https://connect.linux.do/oauth2/token" + userinfo_url: "https://connect.linux.do/api/user" + scopes: "user" + # 示例: "https://your-domain.com/api/v1/auth/oauth/linuxdo/callback" + redirect_url: "" + # 安全提示: + # - 建议使用同源相对路径(以 / 开头),避免把 token 重定向到意外的第三方域名 + # - 该地址不应包含 #fragment(本实现使用 URL fragment 传递 access_token) + frontend_redirect_url: "/auth/linuxdo/callback" + token_auth_method: "client_secret_post" # client_secret_post | client_secret_basic | none + # 注意:当 token_auth_method=none(public client)时,必须启用 PKCE + use_pkce: false + userinfo_email_path: "" + userinfo_id_path: "" + userinfo_username_path: "" + # ============================================================================= # Default Settings # 默认设置 diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 6a370e9a..484df3a8 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -173,11 +173,12 @@ services: volumes: - redis_data:/data command: > - redis-server - --save 60 1 - --appendonly yes - --appendfsync everysec - ${REDIS_PASSWORD:+--requirepass ${REDIS_PASSWORD}} + sh -c ' + redis-server + --save 60 1 + --appendonly yes + --appendfsync everysec + ${REDIS_PASSWORD:+--requirepass "$REDIS_PASSWORD"}' environment: - TZ=${TZ:-Asia/Shanghai} # REDISCLI_AUTH is used by redis-cli for authentication (safer than -a flag) diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 23db9104..44eebc99 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -16,7 +16,7 @@ import type { * List all groups with pagination * @param page - Page number (default: 1) * @param pageSize - Items per page (default: 20) - * @param filters - Optional filters (platform, status, is_exclusive) + * @param filters - Optional filters (platform, status, is_exclusive, search) * @returns Paginated list of groups */ export async function list( @@ -26,6 +26,7 @@ export async function list( platform?: GroupPlatform status?: 'active' | 'inactive' is_exclusive?: boolean + search?: string }, options?: { signal?: AbortSignal diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 6b46de7d..2f6991e7 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -34,6 +34,11 @@ export interface SystemSettings { turnstile_enabled: boolean turnstile_site_key: string turnstile_secret_key_configured: boolean + // LinuxDo Connect OAuth 登录(终端用户 SSO) + linuxdo_connect_enabled: boolean + linuxdo_connect_client_id: string + linuxdo_connect_client_secret_configured: boolean + linuxdo_connect_redirect_url: string // Identity patch configuration (Claude -> Gemini) enable_identity_patch: boolean identity_patch_prompt: string @@ -60,6 +65,10 @@ export interface UpdateSettingsRequest { turnstile_enabled?: boolean turnstile_site_key?: string turnstile_secret_key?: string + linuxdo_connect_enabled?: boolean + linuxdo_connect_client_id?: string + linuxdo_connect_client_secret?: string + linuxdo_connect_redirect_url?: string enable_identity_patch?: boolean identity_patch_prompt?: string } diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index e90bec6c..5833632b 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -166,7 +166,7 @@ >
+ {{ t('admin.groups.claudeCode.tooltip') }} +
+ +{{ t('admin.groups.claudeCode.fallbackHint') }}
++ {{ t('admin.settings.linuxdo.description') }} +
++ {{ t('admin.settings.linuxdo.enableHint') }} +
++ {{ t('admin.settings.linuxdo.clientIdHint') }} +
++ {{ + form.linuxdo_connect_client_secret_configured + ? t('admin.settings.linuxdo.clientSecretConfiguredHint') + : t('admin.settings.linuxdo.clientSecretHint') + }} +
+
+ {{ linuxdoRedirectUrlSuggestion }}
+
+ + {{ t('admin.settings.linuxdo.redirectUrlHint') }} +
++ {{ isProcessing ? t('auth.linuxdo.callbackProcessing') : t('auth.linuxdo.callbackHint') }} +
++ {{ errorMessage }} +
+