diff --git a/.gitignore b/.gitignore index bfa6bb1b..2062600f 100644 --- a/.gitignore +++ b/.gitignore @@ -130,5 +130,4 @@ deploy/docker-compose.override.yml vite.config.js docs/* .serena/ - -frontend/coverage \ No newline at end of file +frontend/coverage/ diff --git a/DEV_GUIDE.md b/DEV_GUIDE.md new file mode 100644 index 00000000..541bf1fa --- /dev/null +++ b/DEV_GUIDE.md @@ -0,0 +1,323 @@ +# sub2api 项目开发指南 + +> 本文档记录项目环境配置、常见坑点和注意事项,供 Claude Code 和团队成员参考。 + +## 一、项目基本信息 + +| 项目 | 说明 | +|------|------| +| **上游仓库** | Wei-Shaw/sub2api | +| **Fork 仓库** | bayma888/sub2api-bmai | +| **技术栈** | Go 后端 (Ent ORM + Gin) + Vue3 前端 (pnpm) | +| **数据库** | PostgreSQL 16 + Redis | +| **包管理** | 后端: go modules, 前端: **pnpm**(不是 npm) | + +## 二、本地环境配置 + +### PostgreSQL 16 (Windows 服务) + +| 配置项 | 值 | +|--------|-----| +| 端口 | 5432 | +| psql 路径 | `C:\Program Files\PostgreSQL\16\bin\psql.exe` | +| pg_hba.conf | `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` | +| 数据库凭据 | user=`sub2api`, password=`sub2api`, dbname=`sub2api` | +| 超级用户 | user=`postgres`, password=`postgres` | + +### Redis + +| 配置项 | 值 | +|--------|-----| +| 端口 | 6379 | +| 密码 | 无 | + +### 开发工具 + +```bash +# golangci-lint v2.7 +go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.7 + +# pnpm (前端包管理) +npm install -g pnpm +``` + +## 三、CI/CD 流水线 + +### GitHub Actions Workflows + +| Workflow | 触发条件 | 检查内容 | +|----------|----------|----------| +| **backend-ci.yml** | push, pull_request | 单元测试 + 集成测试 + golangci-lint v2.7 | +| **security-scan.yml** | push, pull_request, 每周一 | govulncheck + gosec + pnpm audit | +| **release.yml** | tag `v*` | 构建发布(PR 不触发) | + +### CI 要求 + +- Go 版本必须是 **1.25.7** +- 前端使用 `pnpm install --frozen-lockfile`,必须提交 `pnpm-lock.yaml` + +### 本地测试命令 + +```bash +# 后端单元测试 +cd backend && go test -tags=unit ./... + +# 后端集成测试 +cd backend && go test -tags=integration ./... + +# 代码质量检查 +cd backend && golangci-lint run ./... + +# 前端依赖安装(必须用 pnpm) +cd frontend && pnpm install +``` + +## 四、常见坑点 & 解决方案 + +### 坑 1:pnpm-lock.yaml 必须同步提交 + +**问题**:`package.json` 新增依赖后,CI 的 `pnpm install --frozen-lockfile` 失败。 + +**原因**:上游 CI 使用 pnpm,lock 文件不同步会报错。 + +**解决**: +```bash +cd frontend +pnpm install # 更新 pnpm-lock.yaml +git add pnpm-lock.yaml +git commit -m "chore: update pnpm-lock.yaml" +``` + +--- + +### 坑 2:npm 和 pnpm 的 node_modules 冲突 + +**问题**:之前用 npm 装过 `node_modules`,pnpm install 报 `EPERM` 错误。 + +**解决**: +```bash +cd frontend +rm -rf node_modules # 或 PowerShell: Remove-Item -Recurse -Force node_modules +pnpm install +``` + +--- + +### 坑 3:PowerShell 中 bcrypt hash 的 `$` 被转义 + +**问题**:bcrypt hash 格式如 `$2a$10$xxx...`,PowerShell 把 `$2a` 当变量解析,导致数据丢失。 + +**解决**:将 SQL 写入文件,用 `psql -f` 执行: +```bash +# 错误示范(PowerShell 会吃掉 $) +psql -c "INSERT INTO users ... VALUES ('$2a$10$...')" + +# 正确做法 +echo "INSERT INTO users ... VALUES ('\$2a\$10\$...')" > temp.sql +psql -U sub2api -h 127.0.0.1 -d sub2api -f temp.sql +``` + +--- + +### 坑 4:psql 不支持中文路径 + +**问题**:`psql -f "D:\中文路径\file.sql"` 报错找不到文件。 + +**解决**:复制到纯英文路径再执行: +```bash +cp "D:\中文路径\file.sql" "C:\temp.sql" +psql -f "C:\temp.sql" +``` + +--- + +### 坑 5:PostgreSQL 密码重置流程 + +**场景**:忘记 PostgreSQL 密码。 + +**步骤**: +1. 修改 `C:\Program Files\PostgreSQL\16\data\pg_hba.conf` + ``` + # 将 scram-sha-256 改为 trust + host all all 127.0.0.1/32 trust + ``` +2. 重启 PostgreSQL 服务 + ```powershell + Restart-Service postgresql-x64-16 + ``` +3. 无密码登录并重置 + ```bash + psql -U postgres -h 127.0.0.1 + ALTER USER sub2api WITH PASSWORD 'sub2api'; + ALTER USER postgres WITH PASSWORD 'postgres'; + ``` +4. 改回 `scram-sha-256` 并重启 + +--- + +### 坑 6:Go interface 新增方法后 test stub 必须补全 + +**问题**:给 interface 新增方法后,编译报错 `does not implement interface (missing method XXX)`。 + +**原因**:所有测试文件中实现该 interface 的 stub/mock 都必须补上新方法。 + +**解决**: +```bash +# 搜索所有实现该 interface 的 struct +cd backend +grep -r "type.*Stub.*struct" internal/ +grep -r "type.*Mock.*struct" internal/ + +# 逐一补全新方法 +``` + +--- + +### 坑 7:Windows 上 psql 连 localhost 的 IPv6 问题 + +**问题**:psql 连 `localhost` 先尝试 IPv6 (::1),可能报错后再回退 IPv4。 + +**建议**:直接用 `127.0.0.1` 代替 `localhost`。 + +--- + +### 坑 8:Windows 没有 make 命令 + +**问题**:CI 里用 `make test-unit`,本地 Windows 没有 make。 + +**解决**:直接用 Makefile 里的原始命令: +```bash +# 代替 make test-unit +go test -tags=unit ./... + +# 代替 make test-integration +go test -tags=integration ./... +``` + +--- + +### 坑 9:Ent Schema 修改后必须重新生成 + +**问题**:修改 `ent/schema/*.go` 后,代码不生效。 + +**解决**: +```bash +cd backend +go generate ./ent # 重新生成 ent 代码 +git add ent/ # 生成的文件也要提交 +``` + +--- + +### 坑 10:PR 提交前检查清单 + +提交 PR 前务必本地验证: + +- [ ] `go test -tags=unit ./...` 通过 +- [ ] `go test -tags=integration ./...` 通过 +- [ ] `golangci-lint run ./...` 无新增问题 +- [ ] `pnpm-lock.yaml` 已同步(如果改了 package.json) +- [ ] 所有 test stub 补全新接口方法(如果改了 interface) +- [ ] Ent 生成的代码已提交(如果改了 schema) + +## 五、常用命令速查 + +### 数据库操作 + +```bash +# 连接数据库 +psql -U sub2api -h 127.0.0.1 -d sub2api + +# 查看所有用户 +psql -U postgres -h 127.0.0.1 -c "\du" + +# 查看所有数据库 +psql -U postgres -h 127.0.0.1 -c "\l" + +# 执行 SQL 文件 +psql -U sub2api -h 127.0.0.1 -d sub2api -f migration.sql +``` + +### Git 操作 + +```bash +# 同步上游 +git fetch upstream +git checkout main +git merge upstream/main +git push origin main + +# 创建功能分支 +git checkout -b feature/xxx + +# Rebase 到最新 main +git fetch upstream +git rebase upstream/main +``` + +### 前端操作 + +```bash +# 安装依赖(必须用 pnpm) +cd frontend +pnpm install + +# 开发服务器 +pnpm dev + +# 构建 +pnpm build +``` + +### 后端操作 + +```bash +# 运行服务器 +cd backend +go run ./cmd/server/ + +# 生成 Ent 代码 +go generate ./ent + +# 运行测试 +go test -tags=unit ./... +go test -tags=integration ./... + +# Lint 检查 +golangci-lint run ./... +``` + +## 六、项目结构速览 + +``` +sub2api-bmai/ +├── backend/ +│ ├── cmd/server/ # 主程序入口 +│ ├── ent/ # Ent ORM 生成代码 +│ │ └── schema/ # 数据库 Schema 定义 +│ ├── internal/ +│ │ ├── handler/ # HTTP 处理器 +│ │ ├── service/ # 业务逻辑 +│ │ ├── repository/ # 数据访问层 +│ │ └── server/ # 服务器配置 +│ ├── migrations/ # 数据库迁移脚本 +│ └── config.yaml # 配置文件 +├── frontend/ +│ ├── src/ +│ │ ├── api/ # API 调用 +│ │ ├── components/ # Vue 组件 +│ │ ├── views/ # 页面视图 +│ │ ├── types/ # TypeScript 类型 +│ │ └── i18n/ # 国际化 +│ ├── package.json # 依赖配置 +│ └── pnpm-lock.yaml # pnpm 锁文件(必须提交) +└── .claude/ + └── CLAUDE.md # 本文档 +``` + +## 七、参考资源 + +- [上游仓库](https://github.com/Wei-Shaw/sub2api) +- [Ent 文档](https://entgo.io/docs/getting-started) +- [Vue3 文档](https://vuejs.org/) +- [pnpm 文档](https://pnpm.io/) diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 8d293fd1..a69656a3 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.74.1 +0.1.74.2 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 341da381..8fb34a63 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -103,7 +103,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) - adminUserHandler := admin.NewUserHandler(adminService) + concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) + concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) + adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() oAuthService := service.NewOAuthService(proxyRepository, claudeOAuthClient) @@ -127,13 +129,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) - antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig) + antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService) antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, schedulerSnapshotService, antigravityTokenProvider, rateLimitService, httpUpstream, settingService) accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) - concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) - concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) @@ -155,7 +155,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + digestSessionStore := service.NewDigestSessionStore() + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/ent/group.go b/backend/ent/group.go index 8bfdca42..79ec5bf5 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -74,6 +74,8 @@ type Group struct { McpXMLInject bool `json:"mcp_xml_inject,omitempty"` // 支持的模型系列:claude, gemini_text, gemini_image SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` + // 分组显示排序,数值越小越靠前 + SortOrder int `json:"sort_order,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -186,7 +188,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -399,6 +401,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field supported_model_scopes: %w", err) } } + case group.FieldSortOrder: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sort_order", values[i]) + } else if value.Valid { + _m.SortOrder = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -586,6 +594,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("supported_model_scopes=") builder.WriteString(fmt.Sprintf("%v", _m.SupportedModelScopes)) + builder.WriteString(", ") + builder.WriteString("sort_order=") + builder.WriteString(fmt.Sprintf("%v", _m.SortOrder)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 7bafc615..133123a1 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -71,6 +71,8 @@ const ( FieldMcpXMLInject = "mcp_xml_inject" // FieldSupportedModelScopes holds the string denoting the supported_model_scopes field in the database. FieldSupportedModelScopes = "supported_model_scopes" + // FieldSortOrder holds the string denoting the sort_order field in the database. + FieldSortOrder = "sort_order" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -174,6 +176,7 @@ var Columns = []string{ FieldModelRoutingEnabled, FieldMcpXMLInject, FieldSupportedModelScopes, + FieldSortOrder, } var ( @@ -237,6 +240,8 @@ var ( DefaultMcpXMLInject bool // DefaultSupportedModelScopes holds the default value on creation for the "supported_model_scopes" field. DefaultSupportedModelScopes []string + // DefaultSortOrder holds the default value on creation for the "sort_order" field. + DefaultSortOrder int ) // OrderOption defines the ordering options for the Group queries. @@ -377,6 +382,11 @@ func ByMcpXMLInject(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldMcpXMLInject, opts...).ToFunc() } +// BySortOrder orders the results by the sort_order field. +func BySortOrder(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSortOrder, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index fb30fe86..127d4ae9 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -185,6 +185,11 @@ func McpXMLInject(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldMcpXMLInject, v)) } +// SortOrder applies equality check predicate on the "sort_order" field. It's identical to SortOrderEQ. +func SortOrder(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1380,6 +1385,46 @@ func McpXMLInjectNEQ(v bool) predicate.Group { return predicate.Group(sql.FieldNEQ(FieldMcpXMLInject, v)) } +// SortOrderEQ applies the EQ predicate on the "sort_order" field. +func SortOrderEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSortOrder, v)) +} + +// SortOrderNEQ applies the NEQ predicate on the "sort_order" field. +func SortOrderNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSortOrder, v)) +} + +// SortOrderIn applies the In predicate on the "sort_order" field. +func SortOrderIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSortOrder, vs...)) +} + +// SortOrderNotIn applies the NotIn predicate on the "sort_order" field. +func SortOrderNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSortOrder, vs...)) +} + +// SortOrderGT applies the GT predicate on the "sort_order" field. +func SortOrderGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSortOrder, v)) +} + +// SortOrderGTE applies the GTE predicate on the "sort_order" field. +func SortOrderGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSortOrder, v)) +} + +// SortOrderLT applies the LT predicate on the "sort_order" field. +func SortOrderLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSortOrder, v)) +} + +// SortOrderLTE applies the LTE predicate on the "sort_order" field. +func SortOrderLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSortOrder, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 2ce0f730..4416516b 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -396,6 +396,20 @@ func (_c *GroupCreate) SetSupportedModelScopes(v []string) *GroupCreate { return _c } +// SetSortOrder sets the "sort_order" field. +func (_c *GroupCreate) SetSortOrder(v int) *GroupCreate { + _c.mutation.SetSortOrder(v) + return _c +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSortOrder(v *int) *GroupCreate { + if v != nil { + _c.SetSortOrder(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -577,6 +591,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultSupportedModelScopes _c.mutation.SetSupportedModelScopes(v) } + if _, ok := _c.mutation.SortOrder(); !ok { + v := group.DefaultSortOrder + _c.mutation.SetSortOrder(v) + } return nil } @@ -641,6 +659,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.SupportedModelScopes(); !ok { return &ValidationError{Name: "supported_model_scopes", err: errors.New(`ent: missing required field "Group.supported_model_scopes"`)} } + if _, ok := _c.mutation.SortOrder(); !ok { + return &ValidationError{Name: "sort_order", err: errors.New(`ent: missing required field "Group.sort_order"`)} + } return nil } @@ -780,6 +801,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSupportedModelScopes, field.TypeJSON, value) _node.SupportedModelScopes = value } + if value, ok := _c.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + _node.SortOrder = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1434,6 +1459,24 @@ func (u *GroupUpsert) UpdateSupportedModelScopes() *GroupUpsert { return u } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsert) SetSortOrder(v int) *GroupUpsert { + u.Set(group.FieldSortOrder, v) + return u +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSortOrder() *GroupUpsert { + u.SetExcluded(group.FieldSortOrder) + return u +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsert) AddSortOrder(v int) *GroupUpsert { + u.Add(group.FieldSortOrder, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2060,6 +2103,27 @@ func (u *GroupUpsertOne) UpdateSupportedModelScopes() *GroupUpsertOne { }) } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsertOne) SetSortOrder(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSortOrder(v) + }) +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsertOne) AddSortOrder(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSortOrder(v) + }) +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSortOrder() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSortOrder() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2852,6 +2916,27 @@ func (u *GroupUpsertBulk) UpdateSupportedModelScopes() *GroupUpsertBulk { }) } +// SetSortOrder sets the "sort_order" field. +func (u *GroupUpsertBulk) SetSortOrder(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSortOrder(v) + }) +} + +// AddSortOrder adds v to the "sort_order" field. +func (u *GroupUpsertBulk) AddSortOrder(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSortOrder(v) + }) +} + +// UpdateSortOrder sets the "sort_order" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSortOrder() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSortOrder() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index f2142ce4..db510e05 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -583,6 +583,27 @@ func (_u *GroupUpdate) AppendSupportedModelScopes(v []string) *GroupUpdate { return _u } +// SetSortOrder sets the "sort_order" field. +func (_u *GroupUpdate) SetSortOrder(v int) *GroupUpdate { + _u.mutation.ResetSortOrder() + _u.mutation.SetSortOrder(v) + return _u +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSortOrder(v *int) *GroupUpdate { + if v != nil { + _u.SetSortOrder(*v) + } + return _u +} + +// AddSortOrder adds value to the "sort_order" field. +func (_u *GroupUpdate) AddSortOrder(v int) *GroupUpdate { + _u.mutation.AddSortOrder(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1056,6 +1077,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { sqljson.Append(u, group.FieldSupportedModelScopes, value) }) } + if value, ok := _u.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSortOrder(); ok { + _spec.AddField(group.FieldSortOrder, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1918,6 +1945,27 @@ func (_u *GroupUpdateOne) AppendSupportedModelScopes(v []string) *GroupUpdateOne return _u } +// SetSortOrder sets the "sort_order" field. +func (_u *GroupUpdateOne) SetSortOrder(v int) *GroupUpdateOne { + _u.mutation.ResetSortOrder() + _u.mutation.SetSortOrder(v) + return _u +} + +// SetNillableSortOrder sets the "sort_order" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSortOrder(v *int) *GroupUpdateOne { + if v != nil { + _u.SetSortOrder(*v) + } + return _u +} + +// AddSortOrder adds value to the "sort_order" field. +func (_u *GroupUpdateOne) AddSortOrder(v int) *GroupUpdateOne { + _u.mutation.AddSortOrder(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2421,6 +2469,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) sqljson.Append(u, group.FieldSupportedModelScopes, value) }) } + if value, ok := _u.mutation.SortOrder(); ok { + _spec.SetField(group.FieldSortOrder, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedSortOrder(); ok { + _spec.AddField(group.FieldSortOrder, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index d4c27870..f24db53e 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -376,6 +376,7 @@ var ( {Name: "model_routing_enabled", Type: field.TypeBool, Default: false}, {Name: "mcp_xml_inject", Type: field.TypeBool, Default: true}, {Name: "supported_model_scopes", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "sort_order", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -408,6 +409,11 @@ var ( Unique: false, Columns: []*schema.Column{GroupsColumns[3]}, }, + { + Name: "group_sort_order", + Unique: false, + Columns: []*schema.Column{GroupsColumns[29]}, + }, }, } // PromoCodesColumns holds the columns for the "promo_codes" table. diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 296db78d..6721866a 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -7067,6 +7067,8 @@ type GroupMutation struct { mcp_xml_inject *bool supported_model_scopes *[]string appendsupported_model_scopes []string + sort_order *int + addsort_order *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -8699,6 +8701,62 @@ func (m *GroupMutation) ResetSupportedModelScopes() { m.appendsupported_model_scopes = nil } +// SetSortOrder sets the "sort_order" field. +func (m *GroupMutation) SetSortOrder(i int) { + m.sort_order = &i + m.addsort_order = nil +} + +// SortOrder returns the value of the "sort_order" field in the mutation. +func (m *GroupMutation) SortOrder() (r int, exists bool) { + v := m.sort_order + if v == nil { + return + } + return *v, true +} + +// OldSortOrder returns the old "sort_order" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSortOrder(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSortOrder is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSortOrder requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSortOrder: %w", err) + } + return oldValue.SortOrder, nil +} + +// AddSortOrder adds i to the "sort_order" field. +func (m *GroupMutation) AddSortOrder(i int) { + if m.addsort_order != nil { + *m.addsort_order += i + } else { + m.addsort_order = &i + } +} + +// AddedSortOrder returns the value that was added to the "sort_order" field in this mutation. +func (m *GroupMutation) AddedSortOrder() (r int, exists bool) { + v := m.addsort_order + if v == nil { + return + } + return *v, true +} + +// ResetSortOrder resets all changes to the "sort_order" field. +func (m *GroupMutation) ResetSortOrder() { + m.sort_order = nil + m.addsort_order = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -9057,7 +9115,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 28) + fields := make([]string, 0, 29) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -9142,6 +9200,9 @@ func (m *GroupMutation) Fields() []string { if m.supported_model_scopes != nil { fields = append(fields, group.FieldSupportedModelScopes) } + if m.sort_order != nil { + fields = append(fields, group.FieldSortOrder) + } return fields } @@ -9206,6 +9267,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.McpXMLInject() case group.FieldSupportedModelScopes: return m.SupportedModelScopes() + case group.FieldSortOrder: + return m.SortOrder() } return nil, false } @@ -9271,6 +9334,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldMcpXMLInject(ctx) case group.FieldSupportedModelScopes: return m.OldSupportedModelScopes(ctx) + case group.FieldSortOrder: + return m.OldSortOrder(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -9476,6 +9541,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSupportedModelScopes(v) return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSortOrder(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -9526,6 +9598,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addfallback_group_id_on_invalid_request != nil { fields = append(fields, group.FieldFallbackGroupIDOnInvalidRequest) } + if m.addsort_order != nil { + fields = append(fields, group.FieldSortOrder) + } return fields } @@ -9562,6 +9637,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: return m.AddedFallbackGroupIDOnInvalidRequest() + case group.FieldSortOrder: + return m.AddedSortOrder() } return nil, false } @@ -9669,6 +9746,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddFallbackGroupIDOnInvalidRequest(v) return nil + case group.FieldSortOrder: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSortOrder(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -9873,6 +9957,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSupportedModelScopes: m.ResetSupportedModelScopes() return nil + case group.FieldSortOrder: + m.ResetSortOrder() + return nil } return fmt.Errorf("unknown Group field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 81c8c800..6d32fc26 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -409,6 +409,10 @@ func init() { groupDescSupportedModelScopes := groupFields[24].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) + // groupDescSortOrder is the schema descriptor for sort_order field. + groupDescSortOrder := groupFields[25].Descriptor() + // group.DefaultSortOrder holds the default value on creation for the sort_order field. + group.DefaultSortOrder = groupDescSortOrder.Default.(int) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index cb1e5eec..fddf23ce 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -139,6 +139,11 @@ func (Group) Fields() []ent.Field { Default([]string{"claude", "gemini_text", "gemini_image"}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("支持的模型系列:claude, gemini_text, gemini_image"), + + // 分组排序 (added by migration 052) + field.Int("sort_order"). + Default(0). + Comment("分组显示排序,数值越小越靠前"), } } @@ -167,5 +172,6 @@ func (Group) Indexes() []ent.Index { index.Fields("subscription_type"), index.Fields("is_exclusive"), index.Fields("deleted_at"), + index.Fields("sort_order"), } } diff --git a/backend/go.mod b/backend/go.mod index 6916057f..30a0041c 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -75,6 +75,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -103,6 +104,7 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -143,6 +145,7 @@ require ( golang.org/x/mod v0.31.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/tools v0.40.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 171995c7..f6fdb851 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -116,6 +116,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -135,6 +137,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -170,6 +174,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -203,10 +209,14 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= @@ -230,6 +240,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -252,6 +264,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 55536ba9..c031d6d6 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -425,10 +425,17 @@ type TestAccountRequest struct { } type SyncFromCRSRequest struct { - BaseURL string `json:"base_url" binding:"required"` - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` - SyncProxies *bool `json:"sync_proxies"` + BaseURL string `json:"base_url" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` + SyncProxies *bool `json:"sync_proxies"` + SelectedAccountIDs []string `json:"selected_account_ids"` +} + +type PreviewFromCRSRequest struct { + BaseURL string `json:"base_url" binding:"required"` + Username string `json:"username" binding:"required"` + Password string `json:"password" binding:"required"` } // Test handles testing account connectivity with SSE streaming @@ -467,10 +474,11 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { } result, err := h.crsSyncService.SyncFromCRS(c.Request.Context(), service.SyncFromCRSInput{ - BaseURL: req.BaseURL, - Username: req.Username, - Password: req.Password, - SyncProxies: syncProxies, + BaseURL: req.BaseURL, + Username: req.Username, + Password: req.Password, + SyncProxies: syncProxies, + SelectedAccountIDs: req.SelectedAccountIDs, }) if err != nil { // Provide detailed error message for CRS sync failures @@ -481,6 +489,28 @@ func (h *AccountHandler) SyncFromCRS(c *gin.Context) { response.Success(c, result) } +// PreviewFromCRS handles previewing accounts from CRS before sync +// POST /api/v1/admin/accounts/sync/crs/preview +func (h *AccountHandler) PreviewFromCRS(c *gin.Context) { + var req PreviewFromCRSRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + result, err := h.crsSyncService.PreviewFromCRS(c.Request.Context(), service.SyncFromCRSInput{ + BaseURL: req.BaseURL, + Username: req.Username, + Password: req.Password, + }) + if err != nil { + response.InternalError(c, "CRS preview failed: "+err.Error()) + return + } + + response.Success(c, result) +} + // Refresh handles refreshing account credentials // POST /api/v1/admin/accounts/:id/refresh func (h *AccountHandler) Refresh(c *gin.Context) { diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index e0f731e1..20a25222 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -16,7 +16,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router := gin.New() adminSvc := newStubAdminService() - userHandler := NewUserHandler(adminSvc) + userHandler := NewUserHandler(adminSvc, nil) groupHandler := NewGroupHandler(adminSvc) proxyHandler := NewProxyHandler(adminSvc) redeemHandler := NewRedeemHandler(adminSvc) diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 77d288f9..cbbfe942 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -357,5 +357,9 @@ func (s *stubAdminService) GetUserBalanceHistory(ctx context.Context, userID int return s.redeems, int64(len(s.redeems)), 100.0, nil } +func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + return nil +} + // Ensure stub implements interface. var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 20a20767..25ff3c96 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -318,3 +318,36 @@ func (h *GroupHandler) GetGroupAPIKeys(c *gin.Context) { } response.Paginated(c, outKeys, total, page, pageSize) } + +// UpdateSortOrderRequest represents the request to update group sort orders +type UpdateSortOrderRequest struct { + Updates []struct { + ID int64 `json:"id" binding:"required"` + SortOrder int `json:"sort_order"` + } `json:"updates" binding:"required,min=1"` +} + +// UpdateSortOrder handles updating group sort orders +// PUT /api/v1/admin/groups/sort-order +func (h *GroupHandler) UpdateSortOrder(c *gin.Context) { + var req UpdateSortOrderRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + updates := make([]service.GroupSortOrderUpdate, 0, len(req.Updates)) + for _, u := range req.Updates { + updates = append(updates, service.GroupSortOrderUpdate{ + ID: u.ID, + SortOrder: u.SortOrder, + }) + } + + if err := h.adminService.UpdateGroupSortOrders(c.Request.Context(), updates); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Sort order updated successfully"}) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 0427e77e..248caa4b 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -11,15 +11,23 @@ import ( "github.com/gin-gonic/gin" ) +// UserWithConcurrency wraps AdminUser with current concurrency info +type UserWithConcurrency struct { + dto.AdminUser + CurrentConcurrency int `json:"current_concurrency"` +} + // UserHandler handles admin user management type UserHandler struct { - adminService service.AdminService + adminService service.AdminService + concurrencyService *service.ConcurrencyService } // NewUserHandler creates a new admin user handler -func NewUserHandler(adminService service.AdminService) *UserHandler { +func NewUserHandler(adminService service.AdminService, concurrencyService *service.ConcurrencyService) *UserHandler { return &UserHandler{ - adminService: adminService, + adminService: adminService, + concurrencyService: concurrencyService, } } @@ -87,10 +95,30 @@ func (h *UserHandler) List(c *gin.Context) { return } - out := make([]dto.AdminUser, 0, len(users)) - for i := range users { - out = append(out, *dto.UserFromServiceAdmin(&users[i])) + // Batch get current concurrency (nil map if unavailable) + var loadInfo map[int64]*service.UserLoadInfo + if len(users) > 0 && h.concurrencyService != nil { + usersConcurrency := make([]service.UserWithConcurrency, len(users)) + for i := range users { + usersConcurrency[i] = service.UserWithConcurrency{ + ID: users[i].ID, + MaxConcurrency: users[i].Concurrency, + } + } + loadInfo, _ = h.concurrencyService.GetUsersLoadBatch(c.Request.Context(), usersConcurrency) } + + // Build response with concurrency info + out := make([]UserWithConcurrency, len(users)) + for i := range users { + out[i] = UserWithConcurrency{ + AdminUser: *dto.UserFromServiceAdmin(&users[i]), + } + if info := loadInfo[users[i].ID]; info != nil { + out[i].CurrentConcurrency = info.CurrentConcurrency + } + } + response.Paginated(c, out, total, page, pageSize) } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index b72ab6ff..3c216d65 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -115,6 +115,7 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { MCPXMLInject: g.MCPXMLInject, SupportedModelScopes: g.SupportedModelScopes, AccountCount: g.AccountCount, + SortOrder: g.SortOrder, } if len(g.AccountGroups) > 0 { out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups)) diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 97f3f81a..daac42bd 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -2,11 +2,6 @@ package dto import "time" -type ScopeRateLimitInfo struct { - ResetAt time.Time `json:"reset_at"` - RemainingSec int64 `json:"remaining_sec"` -} - type User struct { ID int64 `json:"id"` Email string `json:"email"` @@ -104,6 +99,9 @@ type AdminGroup struct { SupportedModelScopes []string `json:"supported_model_scopes"` AccountGroups []AccountGroup `json:"account_groups,omitempty"` AccountCount int64 `json:"account_count,omitempty"` + + // 分组排序 + SortOrder int `json:"sort_order"` } type Account struct { @@ -132,9 +130,6 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` OverloadUntil *time.Time `json:"overload_until"` - // Antigravity scope 级限流状态(从 extra 提取) - ScopeRateLimits map[string]ScopeRateLimitInfo `json:"scope_rate_limits,omitempty"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until"` TempUnschedulableReason string `json:"temp_unschedulable_reason"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index e3b0a9b5..af20e318 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -13,6 +13,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -116,7 +117,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -205,6 +206,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 计算粘性会话hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 获取平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则使用分组平台 @@ -336,7 +342,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -345,6 +351,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // 错误响应已在Forward中处理,这里只记录日志 @@ -484,7 +495,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) } else { result, err = h.gatewayService.Forward(requestCtx, c, account, parsedReq) @@ -532,7 +543,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -541,6 +552,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // 错误响应已在Forward中处理,这里只记录日志 @@ -814,6 +830,27 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } +// needForceCacheBilling 判断 failover 时是否需要强制缓存计费 +// 粘性会话切换账号、或上游明确标记时,将 input_tokens 转为 cache_read 计费 +func needForceCacheBilling(hasBoundSession bool, failoverErr *service.UpstreamFailoverError) bool { + return hasBoundSession || (failoverErr != nil && failoverErr.ForceCacheBilling) +} + +// sleepFailoverDelay 账号切换线性递增延时:第1次0s、第2次1s、第3次2s… +// 返回 false 表示 context 已取消。 +func sleepFailoverDelay(ctx context.Context, switchCount int) bool { + delay := time.Duration(switchCount-1) * time.Second + if delay <= 0 { + return true + } + select { + case <-ctx.Done(): + return false + case <-time.After(delay): + return true + } +} + func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { statusCode := failoverErr.StatusCode responseBody := failoverErr.ResponseBody @@ -947,7 +984,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { setOpsRequestContext(c, "", false, body) - parsedReq, err := service.ParseGatewayRequest(body) + parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) if err != nil { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return @@ -975,6 +1012,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { } // 计算粘性会话 hash + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } sessionHash := h.gatewayService.GenerateSessionHash(parsedReq) // 选择支持该模型的账号 diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index b1477ac6..d5149f22 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -14,6 +14,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" @@ -30,13 +31,6 @@ import ( // 匹配格式: /Users/xxx/.gemini/tmp/[64位十六进制哈希] var geminiCLITmpDirRegex = regexp.MustCompile(`/\.gemini/tmp/([A-Fa-f0-9]{64})`) -func isGeminiCLIRequest(c *gin.Context, body []byte) bool { - if strings.TrimSpace(c.GetHeader("x-gemini-api-privileged-user-id")) != "" { - return true - } - return geminiCLITmpDirRegex.Match(body) -} - // GeminiV1BetaListModels proxies: // GET /v1beta/models func (h *GatewayHandler) GeminiV1BetaListModels(c *gin.Context) { @@ -239,7 +233,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { sessionHash := extractGeminiCLISessionHash(c, body) if sessionHash == "" { // Fallback: 使用通用的会话哈希生成逻辑(适用于其他客户端) - parsedReq, _ := service.ParseGatewayRequest(body) + parsedReq, _ := service.ParseGatewayRequest(body, domain.PlatformGemini) + if parsedReq != nil { + parsedReq.SessionContext = &service.SessionContext{ + ClientIP: ip.GetClientIP(c), + UserAgent: c.GetHeader("User-Agent"), + APIKeyID: apiKey.ID, + } + } sessionHash = h.gatewayService.GenerateSessionHash(parsedReq) } sessionKey := sessionHash @@ -258,6 +259,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var geminiDigestChain string var geminiPrefixHash string var geminiSessionUUID string + var matchedDigestChain string useDigestFallback := sessionBoundAccountID == 0 if useDigestFallback { @@ -284,13 +286,14 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { ) // 查找会话 - foundUUID, foundAccountID, found := h.gatewayService.FindGeminiSession( + foundUUID, foundAccountID, foundMatchedChain, found := h.gatewayService.FindGeminiSession( c.Request.Context(), derefGroupID(apiKey.GroupID), geminiPrefixHash, geminiDigestChain, ) if found { + matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", @@ -316,7 +319,6 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 - isCLI := isGeminiCLIRequest(c, body) cleanedForUnknownBinding := false maxAccountSwitches := h.maxAccountSwitchesGemini @@ -344,10 +346,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) body = service.CleanGeminiNativeThoughtSignatures(body) sessionBoundAccountID = account.ID - } else if sessionKey != "" && sessionBoundAccountID == 0 && isCLI && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { - // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,CLI 继续携带旧签名。 + } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { + // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 - log.Printf("[Gemini] Sticky session binding missing for CLI request, cleaning thoughtSignature proactively") + log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively") body = service.CleanGeminiNativeThoughtSignatures(body) cleanedForUnknownBinding = true sessionBoundAccountID = account.ID @@ -410,7 +412,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if switchCount > 0 { requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, switchCount) } - if account.Platform == service.PlatformAntigravity { + if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) } else { result, err = h.geminiCompatService.ForwardNative(requestCtx, c, account, modelName, action, stream, body) @@ -422,7 +424,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - if failoverErr.ForceCacheBilling { + if needForceCacheBilling(hasBoundSession, failoverErr) { forceCacheBilling = true } if switchCount >= maxAccountSwitches { @@ -433,6 +435,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + if account.Platform == service.PlatformAntigravity { + if !sleepFailoverDelay(c.Request.Context(), switchCount) { + return + } + } continue } // ForwardNative already wrote the response @@ -453,6 +460,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { geminiDigestChain, geminiSessionUUID, account.ID, + matchedDigestChain, ); err != nil { log.Printf("[Gemini] Failed to save digest session: %v", err) } diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 562d2feb..7e6fddfb 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -282,6 +282,34 @@ func (r *accountRepository) GetByCRSAccountID(ctx context.Context, crsAccountID return &accounts[0], nil } +func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT id, extra->>'crs_account_id' + FROM accounts + WHERE deleted_at IS NULL + AND extra->>'crs_account_id' IS NOT NULL + AND extra->>'crs_account_id' != '' + `) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[string]int64) + for rows.Next() { + var id int64 + var crsID string + if err := rows.Scan(&id, &crsID); err != nil { + return nil, err + } + result[crsID] = id + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { if account == nil { return nil @@ -798,53 +826,6 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return nil } -func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - now := time.Now().UTC() - payload := map[string]string{ - "rate_limited_at": now.Format(time.RFC3339), - "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339), - } - raw, err := json.Marshal(payload) - if err != nil { - return err - } - - scopeKey := string(scope) - client := clientFromContext(ctx, r.client) - result, err := client.ExecContext( - ctx, - `UPDATE accounts SET - extra = jsonb_set( - jsonb_set(COALESCE(extra, '{}'::jsonb), '{antigravity_quota_scopes}'::text[], COALESCE(extra->'antigravity_quota_scopes', '{}'::jsonb), true), - ARRAY['antigravity_quota_scopes', $1]::text[], - $2::jsonb, - true - ), - updated_at = NOW(), - last_used_at = NOW() - WHERE id = $3 AND deleted_at IS NULL`, - scopeKey, - raw, - id, - ) - if err != nil { - return err - } - - affected, err := result.RowsAffected() - if err != nil { - return err - } - if affected == 0 { - return service.ErrAccountNotFound - } - - if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue quota scope failed: account=%d err=%v", id, err) - } - return nil -} - func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { if scope == "" { return nil diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 220e63d2..9dcf0fe6 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -476,6 +476,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { ModelRoutingEnabled: g.ModelRoutingEnabled, MCPXMLInject: g.McpXMLInject, SupportedModelScopes: g.SupportedModelScopes, + SortOrder: g.SortOrder, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 9365252a..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -11,63 +11,6 @@ import ( const stickySessionPrefix = "sticky_session:" -// Gemini Trie Lua 脚本 -const ( - // geminiTrieFindScript 查找最长前缀匹配的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain (如 "u:a-m:b-u:c-m:d") - // ARGV[2] = TTL seconds (用于刷新) - // 返回: 最长匹配的 value (uuid:accountID) 或 nil - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - geminiTrieFindScript = ` -local chain = ARGV[1] -local ttl = tonumber(ARGV[2]) -local lastMatch = nil -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part - local val = redis.call('HGET', KEYS[1], path) - if val and val ~= "" then - lastMatch = val - end -end - -if lastMatch then - redis.call('EXPIRE', KEYS[1], ttl) -end - -return lastMatch -` - - // geminiTrieSaveScript 保存会话到 Trie 的 Lua 脚本 - // KEYS[1] = trie key - // ARGV[1] = digestChain - // ARGV[2] = value (uuid:accountID) - // ARGV[3] = TTL seconds - geminiTrieSaveScript = ` -local chain = ARGV[1] -local value = ARGV[2] -local ttl = tonumber(ARGV[3]) -local path = "" - -for part in string.gmatch(chain, "[^-]+") do - path = path == "" and part or path .. "-" .. part -end -redis.call('HSET', KEYS[1], path, value) -redis.call('EXPIRE', KEYS[1], ttl) -return "OK" -` -) - -// 模型负载统计相关常量 -const ( - modelLoadKeyPrefix = "ag:model_load:" // 模型调用次数 key 前缀 - modelLastUsedKeyPrefix = "ag:model_last_used:" // 模型最后调度时间 key 前缀 - modelLoadTTL = 24 * time.Hour // 调用次数 TTL(24 小时无调用后清零) - modelLastUsedTTL = 24 * time.Hour // 最后调度时间 TTL -) - type gatewayCache struct { rdb *redis.Client } @@ -108,133 +51,3 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } - -// ============ Antigravity 模型负载统计方法 ============ - -// modelLoadKey 构建模型调用次数 key -// 格式: ag:model_load:{accountID}:{model} -func modelLoadKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLoadKeyPrefix, accountID, model) -} - -// modelLastUsedKey 构建模型最后调度时间 key -// 格式: ag:model_last_used:{accountID}:{model} -func modelLastUsedKey(accountID int64, model string) string { - return fmt.Sprintf("%s%d:%s", modelLastUsedKeyPrefix, accountID, model) -} - -// IncrModelCallCount 增加模型调用次数并更新最后调度时间 -// 返回更新后的调用次数 -func (c *gatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - loadKey := modelLoadKey(accountID, model) - lastUsedKey := modelLastUsedKey(accountID, model) - - pipe := c.rdb.Pipeline() - incrCmd := pipe.Incr(ctx, loadKey) - pipe.Expire(ctx, loadKey, modelLoadTTL) // 每次调用刷新 TTL - pipe.Set(ctx, lastUsedKey, time.Now().Unix(), modelLastUsedTTL) - if _, err := pipe.Exec(ctx); err != nil { - return 0, err - } - return incrCmd.Val(), nil -} - -// GetModelLoadBatch 批量获取账号的模型负载信息 -func (c *gatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*service.ModelLoadInfo, error) { - if len(accountIDs) == 0 { - return make(map[int64]*service.ModelLoadInfo), nil - } - - loadCmds, lastUsedCmds := c.pipelineModelLoadGet(ctx, accountIDs, model) - return c.parseModelLoadResults(accountIDs, loadCmds, lastUsedCmds), nil -} - -// pipelineModelLoadGet 批量获取模型负载的 Pipeline 操作 -func (c *gatewayCache) pipelineModelLoadGet( - ctx context.Context, - accountIDs []int64, - model string, -) (map[int64]*redis.StringCmd, map[int64]*redis.StringCmd) { - pipe := c.rdb.Pipeline() - loadCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - lastUsedCmds := make(map[int64]*redis.StringCmd, len(accountIDs)) - - for _, id := range accountIDs { - loadCmds[id] = pipe.Get(ctx, modelLoadKey(id, model)) - lastUsedCmds[id] = pipe.Get(ctx, modelLastUsedKey(id, model)) - } - _, _ = pipe.Exec(ctx) // 忽略错误,key 不存在是正常的 - return loadCmds, lastUsedCmds -} - -// parseModelLoadResults 解析 Pipeline 结果 -func (c *gatewayCache) parseModelLoadResults( - accountIDs []int64, - loadCmds map[int64]*redis.StringCmd, - lastUsedCmds map[int64]*redis.StringCmd, -) map[int64]*service.ModelLoadInfo { - result := make(map[int64]*service.ModelLoadInfo, len(accountIDs)) - for _, id := range accountIDs { - result[id] = &service.ModelLoadInfo{ - CallCount: getInt64OrZero(loadCmds[id]), - LastUsedAt: getTimeOrZero(lastUsedCmds[id]), - } - } - return result -} - -// getInt64OrZero 从 StringCmd 获取 int64 值,失败返回 0 -func getInt64OrZero(cmd *redis.StringCmd) int64 { - val, _ := cmd.Int64() - return val -} - -// getTimeOrZero 从 StringCmd 获取 time.Time,失败返回零值 -func getTimeOrZero(cmd *redis.StringCmd) time.Time { - val, err := cmd.Int64() - if err != nil { - return time.Time{} - } - return time.Unix(val, 0) -} - -// ============ Gemini 会话 Fallback 方法 (Trie 实现) ============ - -// FindGeminiSession 查找 Gemini 会话(使用 Trie + Lua 脚本实现 O(L) 查询) -// 返回最长匹配的会话信息,匹配成功时自动刷新 TTL -func (c *gatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" { - return "", 0, false - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - // 使用 Lua 脚本在 Redis 端执行 Trie 查找,O(L) 次 HGET,1 次网络往返 - // 查找成功时自动刷新 TTL,防止活跃会话意外过期 - result, err := c.rdb.Eval(ctx, geminiTrieFindScript, []string{trieKey}, digestChain, ttlSeconds).Result() - if err != nil || result == nil { - return "", 0, false - } - - value, ok := result.(string) - if !ok || value == "" { - return "", 0, false - } - - uuid, accountID, ok = service.ParseGeminiSessionValue(value) - return uuid, accountID, ok -} - -// SaveGeminiSession 保存 Gemini 会话(使用 Trie + Lua 脚本) -func (c *gatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" { - return nil - } - - trieKey := service.BuildGeminiTrieKey(groupID, prefixHash) - value := service.FormatGeminiSessionValue(uuid, accountID) - ttlSeconds := int(service.GeminiSessionTTL().Seconds()) - - return c.rdb.Eval(ctx, geminiTrieSaveScript, []string{trieKey}, digestChain, value, ttlSeconds).Err() -} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index fc8e7372..2fdaa3d1 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,157 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } -// ============ Gemini Trie 会话测试 ============ - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_SaveAndFind() { - groupID := int64(1) - prefixHash := "testprefix" - digestChain := "u:hash1-m:hash2-u:hash3" - uuid := "test-uuid-123" - accountID := int64(42) - - // 保存会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, uuid, accountID) - require.NoError(s.T(), err, "SaveGeminiSession") - - // 精确匹配查找 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, digestChain) - require.True(s.T(), found, "should find exact match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_PrefixMatch() { - groupID := int64(1) - prefixHash := "prefixmatch" - shortChain := "u:a-m:b" - longChain := "u:a-m:b-u:c-m:d" - uuid := "uuid-prefix" - accountID := int64(100) - - // 保存短链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, shortChain, uuid, accountID) - require.NoError(s.T(), err) - - // 用长链查找,应该匹配到短链(前缀匹配) - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, longChain) - require.True(s.T(), found, "should find prefix match") - require.Equal(s.T(), uuid, foundUUID) - require.Equal(s.T(), accountID, foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_LongestPrefixMatch() { - groupID := int64(1) - prefixHash := "longestmatch" - - // 保存多个不同长度的链 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a", "uuid-short", 1) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b", "uuid-medium", 2) - require.NoError(s.T(), err) - err = s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c", "uuid-long", 3) - require.NoError(s.T(), err) - - // 查找更长的链,应该匹配到最长的前缀 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:c-m:d-u:e") - require.True(s.T(), found, "should find longest prefix match") - require.Equal(s.T(), "uuid-long", foundUUID) - require.Equal(s.T(), int64(3), foundAccountID) - - // 查找中等长度的链 - foundUUID, foundAccountID, found = s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:a-m:b-u:x") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-medium", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_NoMatch() { - groupID := int64(1) - prefixHash := "nomatch" - digestChain := "u:a-m:b" - - // 保存一个会话 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, digestChain, "uuid", 1) - require.NoError(s.T(), err) - - // 用不同的链查找,应该找不到 - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:x-m:y") - require.False(s.T(), found, "should not find non-matching chain") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentPrefixHash() { - groupID := int64(1) - digestChain := "u:a-m:b" - - // 保存到 prefixHash1 - err := s.cache.SaveGeminiSession(s.ctx, groupID, "prefix1", digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 prefixHash2 查找,应该找不到(不同用户/客户端隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, "prefix2", digestChain) - require.False(s.T(), found, "different prefixHash should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_DifferentGroupID() { - prefixHash := "sameprefix" - digestChain := "u:a-m:b" - - // 保存到 groupID 1 - err := s.cache.SaveGeminiSession(s.ctx, 1, prefixHash, digestChain, "uuid1", 1) - require.NoError(s.T(), err) - - // 用 groupID 2 查找,应该找不到(分组隔离) - _, _, found := s.cache.FindGeminiSession(s.ctx, 2, prefixHash, digestChain) - require.False(s.T(), found, "different groupID should be isolated") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_EmptyDigestChain() { - groupID := int64(1) - prefixHash := "emptytest" - - // 空链不应该保存 - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, "", "uuid", 1) - require.NoError(s.T(), err, "empty chain should not error") - - // 空链查找应该返回 false - _, _, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "") - require.False(s.T(), found, "empty chain should not match") -} - -func (s *GatewayCacheSuite) TestGeminiSessionTrie_MultipleSessions() { - groupID := int64(1) - prefixHash := "multisession" - - // 保存多个不同会话(模拟 1000 个并发会话的场景) - sessions := []struct { - chain string - uuid string - accountID int64 - }{ - {"u:session1", "uuid-1", 1}, - {"u:session2-m:reply2", "uuid-2", 2}, - {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, - } - - for _, sess := range sessions { - err := s.cache.SaveGeminiSession(s.ctx, groupID, prefixHash, sess.chain, sess.uuid, sess.accountID) - require.NoError(s.T(), err) - } - - // 验证每个会话都能正确查找 - for _, sess := range sessions { - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, sess.chain) - require.True(s.T(), found, "should find session: %s", sess.chain) - require.Equal(s.T(), sess.uuid, foundUUID) - require.Equal(s.T(), sess.accountID, foundAccountID) - } - - // 验证继续对话的场景 - foundUUID, foundAccountID, found := s.cache.FindGeminiSession(s.ctx, groupID, prefixHash, "u:session2-m:reply2-u:newmsg") - require.True(s.T(), found) - require.Equal(s.T(), "uuid-2", foundUUID) - require.Equal(s.T(), int64(2), foundAccountID) -} func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) diff --git a/backend/internal/repository/gateway_cache_model_load_integration_test.go b/backend/internal/repository/gateway_cache_model_load_integration_test.go deleted file mode 100644 index de6fa5ae..00000000 --- a/backend/internal/repository/gateway_cache_model_load_integration_test.go +++ /dev/null @@ -1,234 +0,0 @@ -//go:build integration - -package repository - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/require" - "github.com/stretchr/testify/suite" -) - -// ============ Gateway Cache 模型负载统计集成测试 ============ - -type GatewayCacheModelLoadSuite struct { - suite.Suite -} - -func TestGatewayCacheModelLoadSuite(t *testing.T) { - suite.Run(t, new(GatewayCacheModelLoadSuite)) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_Basic() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(123) - model := "claude-sonnet-4-20250514" - - // 首次调用应返回 1 - count1, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - // 第二次调用应返回 2 - count2, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(2), count2) - - // 第三次调用应返回 3 - count3, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - require.Equal(t, int64(3), count3) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentModels() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(456) - model1 := "claude-sonnet-4-20250514" - model2 := "claude-opus-4-5-20251101" - - // 不同模型应该独立计数 - count1, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, accountID, model2) - require.NoError(t, err) - require.Equal(t, int64(1), count2) - - count1Again, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - require.Equal(t, int64(2), count1Again) -} - -func (s *GatewayCacheModelLoadSuite) TestIncrModelCallCount_DifferentAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - account1 := int64(111) - account2 := int64(222) - model := "gemini-2.5-pro" - - // 不同账号应该独立计数 - count1, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - require.Equal(t, int64(1), count1) - - count2, err := cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - require.Equal(t, int64(1), count2) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_Empty() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - result, err := cache.GetModelLoadBatch(ctx, []int64{}, "any-model") - require.NoError(t, err) - require.NotNil(t, result) - require.Empty(t, result) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_NonExistent() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - // 查询不存在的账号应返回零值 - result, err := cache.GetModelLoadBatch(ctx, []int64{9999, 9998}, "claude-sonnet-4-20250514") - require.NoError(t, err) - require.Len(t, result, 2) - - require.Equal(t, int64(0), result[9999].CallCount) - require.True(t, result[9999].LastUsedAt.IsZero()) - require.Equal(t, int64(0), result[9998].CallCount) - require.True(t, result[9998].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_AfterIncrement() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(789) - model := "claude-sonnet-4-20250514" - - // 先增加调用次数 - beforeIncr := time.Now() - _, err := cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, accountID, model) - require.NoError(t, err) - afterIncr := time.Now() - - // 获取负载信息 - result, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model) - require.NoError(t, err) - require.Len(t, result, 1) - - loadInfo := result[accountID] - require.NotNil(t, loadInfo) - require.Equal(t, int64(3), loadInfo.CallCount) - require.False(t, loadInfo.LastUsedAt.IsZero()) - // LastUsedAt 应该在 beforeIncr 和 afterIncr 之间 - require.True(t, loadInfo.LastUsedAt.After(beforeIncr.Add(-time.Second)) || loadInfo.LastUsedAt.Equal(beforeIncr)) - require.True(t, loadInfo.LastUsedAt.Before(afterIncr.Add(time.Second)) || loadInfo.LastUsedAt.Equal(afterIncr)) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_MultipleAccounts() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - model := "claude-opus-4-5-20251101" - account1 := int64(1001) - account2 := int64(1002) - account3 := int64(1003) // 不调用 - - // account1 调用 2 次 - _, err := cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - _, err = cache.IncrModelCallCount(ctx, account1, model) - require.NoError(t, err) - - // account2 调用 5 次 - for i := 0; i < 5; i++ { - _, err = cache.IncrModelCallCount(ctx, account2, model) - require.NoError(t, err) - } - - // 批量获取 - result, err := cache.GetModelLoadBatch(ctx, []int64{account1, account2, account3}, model) - require.NoError(t, err) - require.Len(t, result, 3) - - require.Equal(t, int64(2), result[account1].CallCount) - require.False(t, result[account1].LastUsedAt.IsZero()) - - require.Equal(t, int64(5), result[account2].CallCount) - require.False(t, result[account2].LastUsedAt.IsZero()) - - require.Equal(t, int64(0), result[account3].CallCount) - require.True(t, result[account3].LastUsedAt.IsZero()) -} - -func (s *GatewayCacheModelLoadSuite) TestGetModelLoadBatch_ModelIsolation() { - t := s.T() - rdb := testRedis(t) - cache := &gatewayCache{rdb: rdb} - ctx := context.Background() - - accountID := int64(2001) - model1 := "claude-sonnet-4-20250514" - model2 := "gemini-2.5-pro" - - // 对 model1 调用 3 次 - for i := 0; i < 3; i++ { - _, err := cache.IncrModelCallCount(ctx, accountID, model1) - require.NoError(t, err) - } - - // 获取 model1 的负载 - result1, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model1) - require.NoError(t, err) - require.Equal(t, int64(3), result1[accountID].CallCount) - - // 获取 model2 的负载(应该为 0) - result2, err := cache.GetModelLoadBatch(ctx, []int64{accountID}, model2) - require.NoError(t, err) - require.Equal(t, int64(0), result2[accountID].CallCount) -} - -// ============ 辅助函数测试 ============ - -func (s *GatewayCacheModelLoadSuite) TestModelLoadKey_Format() { - t := s.T() - - key := modelLoadKey(123, "claude-sonnet-4") - require.Equal(t, "ag:model_load:123:claude-sonnet-4", key) -} - -func (s *GatewayCacheModelLoadSuite) TestModelLastUsedKey_Format() { - t := s.T() - - key := modelLastUsedKey(456, "gemini-2.5-pro") - require.Equal(t, "ag:model_last_used:456:gemini-2.5-pro", key) -} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 234a4526..6c414efa 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -199,7 +199,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination groups, err := q. Offset(params.Offset()). Limit(params.Limit()). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, nil, err @@ -226,7 +226,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where(group.StatusEQ(service.StatusActive)). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, err @@ -253,7 +253,7 @@ func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, erro func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { groups, err := r.client.Group.Query(). Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)). - Order(dbent.Asc(group.FieldID)). + Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)). All(ctx) if err != nil { return nil, err @@ -505,3 +505,29 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64 return nil } + +// UpdateSortOrders 批量更新分组排序 +func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + if len(updates) == 0 { + return nil + } + + // 使用事务批量更新 + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + for _, u := range updates { + if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil { + return translatePersistenceError(err, service.ErrGroupNotFound, nil) + } + } + + if err := tx.Commit(); err != nil { + return err + } + + return nil +} diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 6851e71a..c574219b 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -901,6 +901,10 @@ func (stubGroupRepo) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int return nil, errors.New("not implemented") } +func (stubGroupRepo) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error { + return nil +} + type stubAccountRepo struct { bulkUpdateIDs []int64 } @@ -1013,10 +1017,6 @@ func (s *stubAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt return errors.New("not implemented") } -func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope service.AntigravityQuotaScope, resetAt time.Time) error { - return errors.New("not implemented") -} - func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return errors.New("not implemented") } @@ -1058,6 +1058,10 @@ func (s *stubAccountRepo) BulkUpdate(ctx context.Context, ids []int64, updates s return int64(len(ids)), nil } +func (s *stubAccountRepo) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + return nil, errors.New("not implemented") +} + type stubProxyRepo struct{} func (stubProxyRepo) Create(ctx context.Context, proxy *service.Proxy) error { diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 14815262..39c5d2fc 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -192,6 +192,7 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { groups.GET("", h.Admin.Group.List) groups.GET("/all", h.Admin.Group.GetAll) + groups.PUT("/sort-order", h.Admin.Group.UpdateSortOrder) groups.GET("/:id", h.Admin.Group.GetByID) groups.POST("", h.Admin.Group.Create) groups.PUT("/:id", h.Admin.Group.Update) @@ -208,6 +209,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.GET("/:id", h.Admin.Account.GetByID) accounts.POST("", h.Admin.Account.Create) accounts.POST("/sync/crs", h.Admin.Account.SyncFromCRS) + accounts.POST("/sync/crs/preview", h.Admin.Account.PreviewFromCRS) accounts.PUT("/:id", h.Admin.Account.Update) accounts.DELETE("/:id", h.Admin.Account.Delete) accounts.POST("/:id/test", h.Admin.Account.Test) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index a6ae8a68..138d5bcb 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -425,6 +425,22 @@ func (a *Account) GetBaseURL() string { if baseURL == "" { return "https://api.anthropic.com" } + if a.Platform == PlatformAntigravity { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } + return baseURL +} + +// GetGeminiBaseURL 返回 Gemini 兼容端点的 base URL。 +// Antigravity 平台的 APIKey 账号自动拼接 /antigravity。 +func (a *Account) GetGeminiBaseURL(defaultBaseURL string) string { + baseURL := strings.TrimSpace(a.GetCredential("base_url")) + if baseURL == "" { + return defaultBaseURL + } + if a.Platform == PlatformAntigravity && a.Type == AccountTypeAPIKey { + return strings.TrimRight(baseURL, "/") + "/antigravity" + } return baseURL } diff --git a/backend/internal/service/account_base_url_test.go b/backend/internal/service/account_base_url_test.go new file mode 100644 index 00000000..a1322193 --- /dev/null +++ b/backend/internal/service/account_base_url_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "testing" +) + +func TestGetBaseURL(t *testing.T) { + tests := []struct { + name string + account Account + expected string + }{ + { + name: "non-apikey type returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAnthropic, + }, + expected: "", + }, + { + name: "apikey without base_url returns default anthropic", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{}, + }, + expected: "https://api.anthropic.com", + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAnthropic, + Credentials: map[string]any{"base_url": "https://custom.example.com"}, + }, + expected: "https://custom.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash before appending", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity non-apikey returns empty", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetBaseURL() + if result != tt.expected { + t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} + +func TestGetGeminiBaseURL(t *testing.T) { + const defaultGeminiURL = "https://generativelanguage.googleapis.com" + + tests := []struct { + name string + account Account + expected string + }{ + { + name: "apikey without base_url returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "apikey with custom base_url", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"}, + }, + expected: "https://custom-gemini.example.com", + }, + { + name: "antigravity apikey auto-appends /antigravity", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity apikey trims trailing slash", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com/"}, + }, + expected: "https://upstream.example.com/antigravity", + }, + { + name: "antigravity oauth does NOT append /antigravity", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{"base_url": "https://upstream.example.com"}, + }, + expected: "https://upstream.example.com", + }, + { + name: "oauth without base_url returns default", + account: Account{ + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{}, + }, + expected: defaultGeminiURL, + }, + { + name: "nil credentials returns default", + account: Account{ + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + expected: defaultGeminiURL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.account.GetGeminiBaseURL(defaultGeminiURL) + if result != tt.expected { + t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected) + } + }) + } +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index a261fb21..3cddd2c7 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -28,6 +28,9 @@ type AccountRepository interface { // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) + // ListCRSAccountIDs returns a map of crs_account_id -> local account ID + // for all accounts that have been synced from CRS. + ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) Update(ctx context.Context, account *Account) error Delete(ctx context.Context, id int64) error @@ -53,7 +56,6 @@ type AccountRepository interface { ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error - SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error SetOverloaded(ctx context.Context, id int64, until time.Time) error SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index f4e03e8e..414b3678 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,10 +54,14 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } -func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { +func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { panic("unexpected FindByExtraField call") } +func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { + panic("unexpected ListCRSAccountIDs call") +} + func (s *accountRepoStub) Update(ctx context.Context, account *Account) error { panic("unexpected Update call") } @@ -147,10 +151,6 @@ func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt panic("unexpected SetRateLimited call") } -func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - panic("unexpected SetAntigravityQuotaScopeLimit call") -} - func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { panic("unexpected SetModelRateLimit call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index acb6eb69..093f7d4d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -250,7 +250,6 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set common headers req.Header.Set("Content-Type", "application/json") req.Header.Set("anthropic-version", "2023-06-01") - req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) // Apply Claude Code client headers for key, value := range claude.DefaultHeaders { @@ -259,8 +258,10 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account // Set authentication header if useBearer { + req.Header.Set("anthropic-beta", claude.DefaultBetaHeader) req.Header.Set("Authorization", "Bearer "+authToken) } else { + req.Header.Set("anthropic-beta", claude.APIKeyBetaHeader) req.Header.Set("x-api-key", authToken) } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 2b69aff3..6be73fda 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -36,6 +36,7 @@ type AdminService interface { UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) + UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // Account management ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) @@ -1048,6 +1049,10 @@ func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, p return keys, result.Total, nil } +func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return s.groupRepo.UpdateSortOrders(ctx, updates) +} + // Account management implementations func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error) { params := pagination.PaginationParams{Page: page, PageSize: pageSize} diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go index c775749d..60fa3d77 100644 --- a/backend/internal/service/admin_service_delete_test.go +++ b/backend/internal/service/admin_service_delete_test.go @@ -172,6 +172,10 @@ func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs [] panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + type proxyRepoStub struct { deleteErr error countErr error diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index d921a086..ef77a980 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -116,6 +116,10 @@ func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []i panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + // TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递 func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) { repo := &groupRepoStubForAdmin{} @@ -395,6 +399,10 @@ func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Contex panic("unexpected GetAccountIDsByGroupIDs call") } +func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + type groupRepoStubForInvalidRequestFallback struct { groups map[int64]*Group created *Group @@ -466,6 +474,10 @@ func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.C panic("unexpected BindAccountsToGroup call") } +func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error { + return nil +} + func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) { fallbackID := int64(10) repo := &groupRepoStubForInvalidRequestFallback{ diff --git a/backend/internal/service/anthropic_session.go b/backend/internal/service/anthropic_session.go new file mode 100644 index 00000000..26544c68 --- /dev/null +++ b/backend/internal/service/anthropic_session.go @@ -0,0 +1,79 @@ +package service + +import ( + "encoding/json" + "strings" + "time" +) + +// Anthropic 会话 Fallback 相关常量 +const ( + // anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟) + anthropicSessionTTLSeconds = 300 + + // anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀 + anthropicDigestSessionKeyPrefix = "anthropic:digest:" +) + +// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL +func AnthropicSessionTTL() time.Duration { + return anthropicSessionTTLSeconds * time.Second +} + +// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链 +// 格式: s:-u:-a:-u:-... +// s = system, u = user, a = assistant +func BuildAnthropicDigestChain(parsed *ParsedRequest) string { + if parsed == nil { + return "" + } + + var parts []string + + // 1. system prompt + if parsed.System != nil { + systemData, _ := json.Marshal(parsed.System) + if len(systemData) > 0 && string(systemData) != "null" { + parts = append(parts, "s:"+shortHash(systemData)) + } + } + + // 2. messages + for _, msg := range parsed.Messages { + msgMap, ok := msg.(map[string]any) + if !ok { + continue + } + role, _ := msgMap["role"].(string) + prefix := rolePrefix(role) + content := msgMap["content"] + contentData, _ := json.Marshal(content) + parts = append(parts, prefix+":"+shortHash(contentData)) + } + + return strings.Join(parts, "-") +} + +// rolePrefix 将 Anthropic 的 role 映射为单字符前缀 +func rolePrefix(role string) string { + switch role { + case "assistant": + return "a" + default: + return "u" + } +} + +// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey +// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey +func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string { + prefix := prefixHash + if len(prefixHash) >= 8 { + prefix = prefixHash[:8] + } + uuidPart := uuid + if len(uuid) >= 8 { + uuidPart = uuid[:8] + } + return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart +} diff --git a/backend/internal/service/anthropic_session_test.go b/backend/internal/service/anthropic_session_test.go new file mode 100644 index 00000000..10406643 --- /dev/null +++ b/backend/internal/service/anthropic_session_test.go @@ -0,0 +1,320 @@ +package service + +import ( + "strings" + "testing" +) + +func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) { + result := BuildAnthropicDigestChain(nil) + if result != "" { + t.Errorf("expected empty string for nil request, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{}, + } + result := BuildAnthropicDigestChain(parsed) + if result != "" { + t.Errorf("expected empty string for empty messages, got: %s", result) + } +} + +func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "a:") { + t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) { + parsed := &ParsedRequest{ + System: "You are a helpful assistant", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } + if !strings.HasPrefix(parts[1], "u:") { + t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1]) + } +} + +func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) { + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant"}, + }, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 2 { + t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "s:") { + t.Errorf("part[0] expected prefix 's:', got: %s", parts[0]) + } +} + +func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) { + // 核心测试:验证对话增长时链的前缀关系 + // 上一轮的完整链一定是下一轮链的前缀 + system := "You are a helpful assistant" + + // 第 1 轮: system + user + round1 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + chain1 := BuildAnthropicDigestChain(round1) + + // 第 2 轮: system + user + assistant + user + round2 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + }, + } + chain2 := BuildAnthropicDigestChain(round2) + + // 第 3 轮: system + user + assistant + user + assistant + user + round3 := &ParsedRequest{ + System: system, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi there"}, + map[string]any{"role": "user", "content": "how are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well"}, + map[string]any{"role": "user", "content": "great"}, + }, + } + chain3 := BuildAnthropicDigestChain(round3) + + t.Logf("Chain1: %s", chain1) + t.Logf("Chain2: %s", chain2) + t.Logf("Chain3: %s", chain3) + + // chain1 是 chain2 的前缀 + if !strings.HasPrefix(chain2, chain1) { + t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2) + } + + // chain2 是 chain3 的前缀 + if !strings.HasPrefix(chain3, chain2) { + t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3) + } + + // chain1 也是 chain3 的前缀(传递性) + if !strings.HasPrefix(chain3, chain1) { + t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3) + } +} + +func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + System: "System A", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "System B", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different system prompts should produce different chains") + } + + // 但 user 部分的 hash 应该相同 + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + if parts1[1] != parts2[1] { + t.Error("Same user message should produce same hash regardless of system") + } +} + +func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) { + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "ORIGINAL reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "TAMPERED reply"}, + map[string]any{"role": "user", "content": "next"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed1) + chain2 := BuildAnthropicDigestChain(parsed2) + + if chain1 == chain2 { + t.Error("Different content should produce different chains") + } + + parts1 := splitChain(chain1) + parts2 := splitChain(chain2) + // 第一个 user message hash 应该相同 + if parts1[0] != parts2[0] { + t.Error("First user message hash should be the same") + } + // assistant reply hash 应该不同 + if parts1[1] == parts2[1] { + t.Error("Assistant reply hash should differ") + } +} + +func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) { + parsed := &ParsedRequest{ + System: "test system", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + + chain1 := BuildAnthropicDigestChain(parsed) + chain2 := BuildAnthropicDigestChain(parsed) + + if chain1 != chain2 { + t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2) + } +} + +func TestGenerateAnthropicDigestSessionKey(t *testing.T) { + tests := []struct { + name string + prefixHash string + uuid string + want string + }{ + { + name: "normal 16 char hash with uuid", + prefixHash: "abcdefgh12345678", + uuid: "550e8400-e29b-41d4-a716-446655440000", + want: "anthropic:digest:abcdefgh:550e8400", + }, + { + name: "exactly 8 chars", + prefixHash: "12345678", + uuid: "abcdefgh", + want: "anthropic:digest:12345678:abcdefgh", + }, + { + name: "short values", + prefixHash: "abc", + uuid: "xyz", + want: "anthropic:digest:abc:xyz", + }, + { + name: "empty values", + prefixHash: "", + uuid: "", + want: "anthropic:digest::", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid) + if got != tt.want { + t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want) + } + }) + } + + // 验证不同 uuid 产生不同 sessionKey + t.Run("different uuid different key", func(t *testing.T) { + hash := "sameprefix123456" + result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a") + result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b") + if result1 == result2 { + t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2) + } + }) +} + +func TestAnthropicSessionTTL(t *testing.T) { + ttl := AnthropicSessionTTL() + if ttl.Seconds() != 300 { + t.Errorf("expected 300 seconds, got: %v", ttl.Seconds()) + } +} + +func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) { + // 测试 content 为 content blocks 数组的情况 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "describe this image"}, + map[string]any{"type": "image", "source": map[string]any{"type": "base64"}}, + }, + }, + }, + } + result := BuildAnthropicDigestChain(parsed) + parts := splitChain(result) + if len(parts) != 1 { + t.Fatalf("expected 1 part, got %d: %s", len(parts), result) + } + if !strings.HasPrefix(parts[0], "u:") { + t.Errorf("expected prefix 'u:', got: %s", parts[0]) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index b49315ef..ea866b21 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "log" + "log/slog" mathrand "math/rand" "net" "net/http" @@ -35,7 +36,7 @@ const ( // - 预检查:剩余限流时间 < 此阈值时等待,>= 此阈值时切换账号 antigravityRateLimitThreshold = 7 * time.Second antigravitySmartRetryMinWait = 1 * time.Second // 智能重试最小等待时间 - antigravitySmartRetryMaxAttempts = 3 // 智能重试最大次数 + antigravitySmartRetryMaxAttempts = 1 // 智能重试最大次数(仅重试 1 次,防止重复限流/长期等待) antigravityDefaultRateLimitDuration = 30 * time.Second // 默认限流时间(无 retryDelay 时使用) // Google RPC 状态和类型常量 @@ -100,12 +101,11 @@ type antigravityRetryLoopParams struct { accessToken string action string body []byte - quotaScope AntigravityQuotaScope c *gin.Context httpUpstream HTTPUpstream settingService *SettingService accountRepo AccountRepository // 用于智能重试的模型级别限流 - handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult + handleError func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult requestedModel string // 用于限流检查的原始请求模型 isStickySession bool // 是否为粘性会话(用于账号切换时的缓存计费判断) groupID int64 // 用于模型级限流时清除粘性会话 @@ -148,13 +148,17 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 情况1: retryDelay >= 阈值,限流模型并切换账号 if shouldRateLimitModel { - log.Printf("%s status=%d oauth_long_delay model=%s account=%d (model rate limit, switch account)", - p.prefix, resp.StatusCode, modelName, p.account.ID) + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200)) - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + resetAt := time.Now().Add(rateLimitDuration) if !setModelRateLimitByModelName(p.ctx, p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited account=%d (no scope mapping)", p.prefix, resp.StatusCode, p.account.ID) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) } else { s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -190,7 +194,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) return &smartRetryResult{ action: smartRetryActionBreakWithResp, resp: &http.Response{ @@ -233,20 +237,33 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam } // 所有重试都失败,限流当前模型并切换账号 - log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d (switch account)", - p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID) + rateLimitDuration := waitDuration + if rateLimitDuration <= 0 { + rateLimitDuration = antigravityDefaultRateLimitDuration + } + retryBody := lastRetryBody + if retryBody == nil { + retryBody = respBody + } + log.Printf("%s status=%d smart_retry_exhausted attempts=%d model=%s account=%d upstream_retry_delay=%v body=%s (switch account)", + p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, rateLimitDuration, truncateForLog(retryBody, 200)) - resetAt := time.Now().Add(antigravityDefaultRateLimitDuration) + resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) } else { log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", - p.prefix, resp.StatusCode, modelName, p.account.ID, antigravityDefaultRateLimitDuration) + p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration) s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } } + // 清除粘性会话绑定,避免下次请求仍命中限流账号 + if s.cache != nil && p.sessionHash != "" { + _ = s.cache.DeleteSessionAccountID(p.ctx, p.groupID, p.sessionHash) + } + // 返回账号切换信号,让上层切换账号重试 return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -264,27 +281,15 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // antigravityRetryLoop 执行带 URL fallback 的重试循环 func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) { - // 预检查:如果账号已限流,根据剩余时间决定等待或切换 + // 预检查:如果账号已限流,直接返回切换信号 if p.requestedModel != "" { if remaining := p.account.GetRateLimitRemainingTimeWithContext(p.ctx, p.requestedModel); remaining > 0 { - if remaining < antigravityRateLimitThreshold { - // 限流剩余时间较短,等待后继续 - log.Printf("%s pre_check: rate_limit_wait remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) - select { - case <-p.ctx.Done(): - return nil, p.ctx.Err() - case <-time.After(remaining): - } - } else { - // 限流剩余时间较长,返回账号切换信号 - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", - p.prefix, remaining.Truncate(time.Second), p.requestedModel, p.account.ID) - return nil, &AntigravityAccountSwitchError{ - OriginalAccountID: p.account.ID, - RateLimitedModel: p.requestedModel, - IsStickySession: p.isStickySession, - } + log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) + return nil, &AntigravityAccountSwitchError{ + OriginalAccountID: p.account.ID, + RateLimitedModel: p.requestedModel, + IsStickySession: p.isStickySession, } } } @@ -360,87 +365,102 @@ urlFallbackLoop: return nil, fmt.Errorf("upstream request failed after retries: %w", err) } - // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 - if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 统一处理错误响应 + if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) _ = resp.Body.Close() - // 尝试智能重试处理(OAuth 账号专用) - smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) - switch smartResult.action { - case smartRetryActionContinueURL: - continue urlFallbackLoop - case smartRetryActionBreakWithResp: - if smartResult.err != nil { - return nil, smartResult.err + // ★ 统一入口:自定义错误码 + 临时不可调度 + if handled, policyErr := s.applyErrorPolicy(p, resp.StatusCode, resp.Header, respBody); handled { + if policyErr != nil { + return nil, policyErr } - // 模型限流时返回切换账号信号 - if smartResult.switchError != nil { - return nil, smartResult.switchError + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), } - resp = smartResult.resp break urlFallbackLoop } - // smartRetryActionContinue: 继续默认重试逻辑 - // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() + // 429/503 限流处理:区分 URL 级别限流、智能重试和账户配额限流 + if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { + // 尝试智能重试处理(OAuth 账号专用) + smartResult := s.handleSmartRetry(p, resp, respBody, baseURL, urlIdx, availableURLs) + switch smartResult.action { + case smartRetryActionContinueURL: + continue urlFallbackLoop + case smartRetryActionBreakWithResp: + if smartResult.err != nil { + return nil, smartResult.err + } + // 模型限流时返回切换账号信号 + if smartResult.switchError != nil { + return nil, smartResult.switchError + } + resp = smartResult.resp + break urlFallbackLoop } - continue + // smartRetryActionContinue: 继续默认重试逻辑 + + // 账户/模型配额限流,重试 3 次(指数退避)- 默认逻辑(非 OAuth 账号或解析失败) + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + + // 重试用尽,标记账户限流 + p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + resp = &http.Response{ + StatusCode: resp.StatusCode, + Header: resp.Header.Clone(), + Body: io.NopCloser(bytes.NewReader(respBody)), + } + break urlFallbackLoop } - // 重试用尽,标记账户限流 - p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - // 其他可重试错误(不包括 429 和 503,因为上面已处理) - if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ - Platform: p.account.Platform, - AccountID: p.account.ID, - AccountName: p.account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: getUpstreamDetail(respBody), - }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) - return nil, p.ctx.Err() - } - continue - } + // 其他可重试错误(500/502/504/529,不包括 429 和 503) + if shouldRetryAntigravityError(resp.StatusCode) { + if attempt < antigravityMaxRetries { + upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{ + Platform: p.account.Platform, + AccountID: p.account.ID, + AccountName: p.account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "retry", + Message: upstreamMsg, + Detail: getUpstreamDetail(respBody), + }) + log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { + log.Printf("%s status=context_canceled_during_backoff", p.prefix) + return nil, p.ctx.Err() + } + continue + } + } + + // 其他 4xx 错误或重试用尽,直接返回 resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -449,6 +469,7 @@ urlFallbackLoop: break urlFallbackLoop } + // 成功响应(< 400) break urlFallbackLoop } } @@ -581,6 +602,31 @@ func (s *AntigravityGatewayService) getUpstreamErrorDetail(body []byte) string { return truncateString(string(body), maxBytes) } +// checkErrorPolicy nil 安全的包装 +func (s *AntigravityGatewayService) checkErrorPolicy(ctx context.Context, account *Account, statusCode int, body []byte) ErrorPolicyResult { + if s.rateLimitService == nil { + return ErrorPolicyNone + } + return s.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, body) +} + +// applyErrorPolicy 应用错误策略结果,返回是否应终止当前循环 +func (s *AntigravityGatewayService) applyErrorPolicy(p antigravityRetryLoopParams, statusCode int, headers http.Header, respBody []byte) (handled bool, retErr error) { + switch s.checkErrorPolicy(p.ctx, p.account, statusCode, respBody) { + case ErrorPolicySkipped: + return true, nil + case ErrorPolicyMatched: + _ = p.handleError(p.ctx, p.prefix, p.account, statusCode, headers, respBody, + p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) + return true, nil + case ErrorPolicyTempUnscheduled: + slog.Info("temp_unschedulable_matched", + "prefix", p.prefix, "status_code", statusCode, "account_id", p.account.ID) + return true, &AntigravityAccountSwitchError{OriginalAccountID: p.account.ID, IsStickySession: p.isStickySession} + } + return false, nil +} + // mapAntigravityModel 获取映射后的模型名 // 完全依赖映射配置:账户映射(通配符)→ 默认映射兜底(DefaultAntigravityModelMapping) // 注意:返回空字符串表示模型不被支持,调度时会过滤掉该账号 @@ -650,6 +696,7 @@ type TestConnectionResult struct { // TestConnection 测试 Antigravity 账号连接(非流式,无重试、无计费) // 支持 Claude 和 Gemini 两种协议,根据 modelID 前缀自动选择 func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account *Account, modelID string) (*TestConnectionResult, error) { + // 获取 token if s.tokenProvider == nil { return nil, errors.New("antigravity token provider not configured") @@ -964,8 +1011,24 @@ func isModelNotFoundError(statusCode int, body []byte) bool { } // Forward 转发 Claude 协议请求(Claude → Gemini 转换) +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, isStickySession bool) (*ForwardResult, error) { + // 上游透传账号直接转发,不走 OAuth token 刷新 + if account.Type == AccountTypeUpstream { + return s.ForwardUpstream(ctx, c, account, body) + } + startTime := time.Now() + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -983,11 +1046,9 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if mappedModel == "" { return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model)) } - loadModel := mappedModel // 应用 thinking 模式自动后缀:如果 thinking 开启且目标是 claude-sonnet-4-5,自动改为 thinking 版本 thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled) - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 获取 access_token if s.tokenProvider == nil { @@ -1022,11 +1083,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后转换返回 action := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1036,7 +1092,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: geminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1117,7 +1172,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, accessToken: accessToken, action: action, body: retryGeminiBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1228,7 +1282,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) if s.shouldFailoverUpstreamError(resp.StatusCode) { upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) @@ -1258,6 +1312,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if claudeReq.Stream { // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) @@ -1267,6 +1322,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) @@ -1279,12 +1335,13 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, // 使用原始模型用于计费和日志 - Stream: claudeReq.Stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: requestID, + Usage: *usage, + Model: originalModel, // 使用原始模型用于计费和日志 + Stream: claudeReq.Stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, }, nil } @@ -1582,211 +1639,20 @@ func stripSignatureSensitiveBlocksFromClaudeRequest(req *antigravity.ClaudeReque return changed, nil } -// ForwardUpstream 透传请求到上游 Antigravity 服务 -// 用于 upstream 类型账号,直接使用 base_url + api_key 转发,不走 OAuth token -func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { - startTime := time.Now() - sessionID := getSessionID(c) - prefix := logPrefix(sessionID, account.Name) - - // 获取上游配置 - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - apiKey := strings.TrimSpace(account.GetCredential("api_key")) - if baseURL == "" || apiKey == "" { - return nil, fmt.Errorf("upstream account missing base_url or api_key") - } - baseURL = strings.TrimSuffix(baseURL, "/") - - // 解析请求获取模型信息 - var claudeReq antigravity.ClaudeRequest - if err := json.Unmarshal(body, &claudeReq); err != nil { - return nil, fmt.Errorf("parse claude request: %w", err) - } - if strings.TrimSpace(claudeReq.Model) == "" { - return nil, fmt.Errorf("missing model") - } - originalModel := claudeReq.Model - billingModel := originalModel - - // 构建上游请求 URL - upstreamURL := baseURL + "/v1/messages" - - // 创建请求 - req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) - if err != nil { - return nil, fmt.Errorf("create upstream request: %w", err) - } - - // 设置请求头 - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+apiKey) - req.Header.Set("x-api-key", apiKey) // Claude API 兼容 - - // 透传 Claude 相关 headers - if v := c.GetHeader("anthropic-version"); v != "" { - req.Header.Set("anthropic-version", v) - } - if v := c.GetHeader("anthropic-beta"); v != "" { - req.Header.Set("anthropic-beta", v) - } - - // 代理 URL - proxyURL := "" - if account.ProxyID != nil && account.Proxy != nil { - proxyURL = account.Proxy.URL() - } - - // 发送请求 - resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) - if err != nil { - log.Printf("%s upstream request failed: %v", prefix, err) - return nil, fmt.Errorf("upstream request failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - // 处理错误响应 - if resp.StatusCode >= 400 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - - // 429 错误时标记账号限流 - if resp.StatusCode == http.StatusTooManyRequests { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, AntigravityQuotaScopeClaude, 0, "", false) - } - - // 透传上游错误 - c.Header("Content-Type", resp.Header.Get("Content-Type")) - c.Status(resp.StatusCode) - _, _ = c.Writer.Write(respBody) - - return &ForwardResult{ - Model: billingModel, - }, nil - } - - // 处理成功响应(流式/非流式) - var usage *ClaudeUsage - var firstTokenMs *int - - if claudeReq.Stream { - // 流式响应:透传 - c.Header("Content-Type", "text/event-stream") - c.Header("Cache-Control", "no-cache") - c.Header("Connection", "keep-alive") - c.Header("X-Accel-Buffering", "no") - c.Status(http.StatusOK) - - usage, firstTokenMs = s.streamUpstreamResponse(c, resp, startTime) - } else { - // 非流式响应:直接透传 - respBody, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read upstream response: %w", err) - } - - // 提取 usage - usage = s.extractClaudeUsage(respBody) - - c.Header("Content-Type", resp.Header.Get("Content-Type")) - c.Status(http.StatusOK) - _, _ = c.Writer.Write(respBody) - } - - // 构建计费结果 - duration := time.Since(startTime) - log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) - - return &ForwardResult{ - Model: billingModel, - Stream: claudeReq.Stream, - Duration: duration, - FirstTokenMs: firstTokenMs, - Usage: ClaudeUsage{ - InputTokens: usage.InputTokens, - OutputTokens: usage.OutputTokens, - CacheReadInputTokens: usage.CacheReadInputTokens, - CacheCreationInputTokens: usage.CacheCreationInputTokens, - }, - }, nil -} - -// streamUpstreamResponse 透传上游流式响应并提取 usage -func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*ClaudeUsage, *int) { - usage := &ClaudeUsage{} - var firstTokenMs *int - var firstTokenRecorded bool - - scanner := bufio.NewScanner(resp.Body) - buf := make([]byte, 0, 64*1024) - scanner.Buffer(buf, 1024*1024) - - for scanner.Scan() { - line := scanner.Bytes() - - // 记录首 token 时间 - if !firstTokenRecorded && len(line) > 0 { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - firstTokenRecorded = true - } - - // 尝试从 message_delta 或 message_stop 事件提取 usage - if bytes.HasPrefix(line, []byte("data: ")) { - dataStr := bytes.TrimPrefix(line, []byte("data: ")) - var event map[string]any - if json.Unmarshal(dataStr, &event) == nil { - if u, ok := event["usage"].(map[string]any); ok { - if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { - usage.InputTokens = int(v) - } - if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { - usage.OutputTokens = int(v) - } - if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { - usage.CacheReadInputTokens = int(v) - } - if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { - usage.CacheCreationInputTokens = int(v) - } - } - } - } - - // 透传行 - _, _ = c.Writer.Write(line) - _, _ = c.Writer.Write([]byte("\n")) - c.Writer.Flush() - } - - return usage, firstTokenMs -} - -// extractClaudeUsage 从非流式 Claude 响应提取 usage -func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { - usage := &ClaudeUsage{} - var resp map[string]any - if json.Unmarshal(body, &resp) != nil { - return usage - } - if u, ok := resp["usage"].(map[string]any); ok { - if v, ok := u["input_tokens"].(float64); ok { - usage.InputTokens = int(v) - } - if v, ok := u["output_tokens"].(float64); ok { - usage.OutputTokens = int(v) - } - if v, ok := u["cache_read_input_tokens"].(float64); ok { - usage.CacheReadInputTokens = int(v) - } - if v, ok := u["cache_creation_input_tokens"].(float64); ok { - usage.CacheCreationInputTokens = int(v) - } - } - return usage -} - // ForwardGemini 转发 Gemini 协议请求 +// +// 限流处理流程: +// +// 请求 → antigravityRetryLoop → 预检查(remaining>0? → 切换账号) → 发送上游 +// ├─ 成功 → 正常返回 +// └─ 429/503 → handleSmartRetry +// ├─ retryDelay >= 7s → 设置模型限流 + 清除粘性绑定 → 切换账号 +// └─ retryDelay < 7s → 等待后重试 1 次 +// ├─ 成功 → 正常返回 +// └─ 失败 → 设置模型限流 + 清除粘性绑定 → 切换账号 func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte, isStickySession bool) (*ForwardResult, error) { startTime := time.Now() + sessionID := getSessionID(c) prefix := logPrefix(sessionID, account.Name) @@ -1799,7 +1665,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if len(body) == 0 { return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty") } - quotaScope, _ := resolveAntigravityQuotaScope(originalModel) // 解析请求以获取 image_size(用于图片计费) imageSize := s.extractImageSize(body) @@ -1869,11 +1734,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" - // 统计模型调用次数(包括粘性会话,用于负载均衡调度) - if s.cache != nil { - _, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel) - } - // 执行带重试的请求 result, err := s.antigravityRetryLoop(antigravityRetryLoopParams{ ctx: ctx, @@ -1883,7 +1743,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co accessToken: accessToken, action: upstreamAction, body: wrappedBody, - quotaScope: quotaScope, c: c, httpUpstream: s.httpUpstream, settingService: s.settingService, @@ -1957,7 +1816,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co if unwrapErr != nil || len(unwrappedForOps) == 0 { unwrappedForOps = respBody } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession) + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", isStickySession) upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) upstreamDetail := s.getUpstreamErrorDetail(unwrappedForOps) @@ -2004,6 +1863,7 @@ handleSuccess: var usage *ClaudeUsage var firstTokenMs *int + var clientDisconnect bool if stream { // 客户端要求流式,直接透传 @@ -2014,6 +1874,7 @@ handleSuccess: } usage = streamRes.usage firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect } else { // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) @@ -2037,14 +1898,15 @@ handleSuccess: } return &ForwardResult{ - RequestID: requestID, - Usage: *usage, - Model: originalModel, - Stream: stream, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: imageSize, + RequestID: requestID, + Usage: *usage, + Model: originalModel, + Stream: stream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + ImageCount: imageCount, + ImageSize: imageSize, }, nil } @@ -2253,9 +2115,9 @@ func shouldTriggerAntigravitySmartRetry(account *Account, respBody []byte) (shou } // retryDelay >= 阈值:直接限流模型,不重试 - // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 5 分钟 + // 注意:如果上游未提供 retryDelay,parseAntigravitySmartRetryInfo 已设置为默认 30s if info.RetryDelay >= antigravityRateLimitThreshold { - return false, true, 0, info.ModelName + return false, true, info.RetryDelay, info.ModelName } // retryDelay < 阈值:智能重试 @@ -2377,10 +2239,10 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte func (s *AntigravityGatewayService) handleUpstreamError( ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, - quotaScope AntigravityQuotaScope, + requestedModel string, groupID int64, sessionHash string, isStickySession bool, ) *handleModelRateLimitResult { - // ✨ 模型级限流处理(在原有逻辑之前) + // 模型级限流处理(优先) result := s.handleModelRateLimit(&handleModelRateLimitParams{ ctx: ctx, prefix: prefix, @@ -2402,52 +2264,35 @@ func (s *AntigravityGatewayService) handleUpstreamError( return nil } - // ========== 原有逻辑,保持不变 ========== - // 429 使用 Gemini 格式解析(从 body 解析重置时间) + // 429:尝试解析模型级限流,解析失败时兜底为账号级限流 if statusCode == 429 { - // 调试日志遵循统一日志开关与长度限制,避免无条件记录完整上游响应体。 if logBody, maxBytes := s.getLogConfig(); logBody { log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) } - useScopeLimit := quotaScope != "" resetAt := ParseGeminiRateLimitResetTime(body) - if resetAt == nil { - // 解析失败:使用默认限流时间(与临时限流保持一致) - // 可通过配置或环境变量覆盖 - defaultDur := antigravityDefaultRateLimitDuration - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { - defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute - } - // 秒级环境变量优先级最高 - if override, ok := antigravityFallbackCooldownSeconds(); ok { - defaultDur = override - } - ra := time.Now().Add(defaultDur) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } + defaultDur := s.getDefaultRateLimitDuration() + + // 尝试解析模型 key 并设置模型级限流 + modelKey := resolveAntigravityModelKey(requestedModel) + if modelKey != "" { + ra := s.resolveResetTime(resetAt, defaultDur) + if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { + log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) } else { - log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", + prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra) } return nil } - resetTime := time.Unix(*resetAt, 0) - if useScopeLimit { - log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) - } - } else { - log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) - } + + // 无法解析模型 key,兜底为账号级限流 + ra := s.resolveResetTime(resetAt, defaultDur) + log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", + prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { + log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } return nil } @@ -2462,9 +2307,90 @@ func (s *AntigravityGatewayService) handleUpstreamError( return nil } +// getDefaultRateLimitDuration 获取默认限流时间 +func (s *AntigravityGatewayService) getDefaultRateLimitDuration() time.Duration { + defaultDur := antigravityDefaultRateLimitDuration + if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 { + defaultDur = time.Duration(s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes) * time.Minute + } + if override, ok := antigravityFallbackCooldownSeconds(); ok { + defaultDur = override + } + return defaultDur +} + +// resolveResetTime 根据解析的重置时间或默认时长计算重置时间点 +func (s *AntigravityGatewayService) resolveResetTime(resetAt *int64, defaultDur time.Duration) time.Time { + if resetAt != nil { + return time.Unix(*resetAt, 0) + } + return time.Now().Add(defaultDur) +} + type antigravityStreamResult struct { - usage *ClaudeUsage - firstTokenMs *int + usage *ClaudeUsage + firstTokenMs *int + clientDisconnect bool // 客户端是否在流式传输过程中断开 +} + +// antigravityClientWriter 封装流式响应的客户端写入,自动检测断开并标记。 +// 断开后所有写入操作变为 no-op,调用方通过 Disconnected() 判断是否继续 drain 上游。 +type antigravityClientWriter struct { + w gin.ResponseWriter + flusher http.Flusher + disconnected bool + prefix string // 日志前缀,标识来源方法 +} + +func newAntigravityClientWriter(w gin.ResponseWriter, flusher http.Flusher, prefix string) *antigravityClientWriter { + return &antigravityClientWriter{w: w, flusher: flusher, prefix: prefix} +} + +// Write 写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Write(p []byte) bool { + if cw.disconnected { + return false + } + if _, err := cw.w.Write(p); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +// Fprintf 格式化写入数据到客户端,写入失败时标记断开并返回 false +func (cw *antigravityClientWriter) Fprintf(format string, args ...any) bool { + if cw.disconnected { + return false + } + if _, err := fmt.Fprintf(cw.w, format, args...); err != nil { + cw.markDisconnected() + return false + } + cw.flusher.Flush() + return true +} + +func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected } + +func (cw *antigravityClientWriter) markDisconnected() { + cw.disconnected = true + log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) +} + +// handleStreamReadError 处理上游读取错误的通用逻辑。 +// 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。 +func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + log.Printf("Context canceled during streaming (%s), returning collected usage", prefix) + return true, true + } + if clientDisconnected { + log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) + return true, true + } + return false, false } func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { @@ -2542,10 +2468,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity gemini") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -2557,9 +2485,12 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context select { case ev, ok := <-events: if !ok { - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity gemini"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -2574,11 +2505,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if strings.HasPrefix(trimmed, "data:") { payload := strings.TrimSpace(strings.TrimPrefix(trimmed, "data:")) if payload == "" || payload == "[DONE]" { - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) continue } @@ -2614,27 +2541,22 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context firstTokenMs = &ms } - if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("data: %s\n\n", payload) continue } - if _, err := fmt.Fprintf(c.Writer, "%s\n", line); err != nil { - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, err - } - flusher.Flush() + cw.Fprintf("%s\n", line) case <-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -3338,10 +3260,12 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context intervalCh = intervalTicker.C } + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity claude") + // 仅发送一次错误事件,避免多次写入导致协议混乱 errorEventSent := false sendErrorEvent := func(reason string) { - if errorEventSent { + if errorEventSent || cw.Disconnected() { return } errorEventSent = true @@ -3349,19 +3273,27 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context flusher.Flush() } + // finishUsage 是获取 processor 最终 usage 的辅助函数 + finishUsage := func() *ClaudeUsage { + _, agUsage := processor.Finish() + return convertUsage(agUsage) + } + for { select { case ev, ok := <-events: if !ok { - // 发送结束事件 + // 上游完成,发送结束事件 finalEvents, agUsage := processor.Finish() if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - flusher.Flush() + cw.Write(finalEvents) } - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, nil + return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()}, nil } if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity claude"); handled { + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil + } if errors.Is(ev.err, bufio.ErrTooLong) { log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") @@ -3371,25 +3303,14 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return nil, fmt.Errorf("stream read error: %w", ev.err) } - line := ev.line // 处理 SSE 行,转换为 Claude 格式 - claudeEvents := processor.ProcessLine(strings.TrimRight(line, "\r\n")) - + claudeEvents := processor.ProcessLine(strings.TrimRight(ev.line, "\r\n")) if len(claudeEvents) > 0 { if firstTokenMs == nil { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - - if _, writeErr := c.Writer.Write(claudeEvents); writeErr != nil { - finalEvents, agUsage := processor.Finish() - if len(finalEvents) > 0 { - _, _ = c.Writer.Write(finalEvents) - } - sendErrorEvent("write_failed") - return &antigravityStreamResult{usage: convertUsage(agUsage), firstTokenMs: firstTokenMs}, writeErr - } - flusher.Flush() + cw.Write(claudeEvents) } case <-intervalCh: @@ -3397,13 +3318,15 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if time.Since(lastRead) < streamInterval { continue } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage") + return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } log.Printf("Stream data interval timeout (antigravity)") - // 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } } - } // extractImageSize 从 Gemini 请求中提取 image_size 参数 @@ -3542,3 +3465,288 @@ func filterEmptyPartsFromGeminiRequest(body []byte) ([]byte, error) { payload["contents"] = filtered return json.Marshal(payload) } + +// ForwardUpstream 使用 base_url + /v1/messages + 双 header 认证透传上游 Claude 请求 +func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) { + startTime := time.Now() + sessionID := getSessionID(c) + prefix := logPrefix(sessionID, account.Name) + + // 获取上游配置 + baseURL := strings.TrimSpace(account.GetCredential("base_url")) + apiKey := strings.TrimSpace(account.GetCredential("api_key")) + if baseURL == "" || apiKey == "" { + return nil, fmt.Errorf("upstream account missing base_url or api_key") + } + baseURL = strings.TrimSuffix(baseURL, "/") + + // 解析请求获取模型信息 + var claudeReq antigravity.ClaudeRequest + if err := json.Unmarshal(body, &claudeReq); err != nil { + return nil, fmt.Errorf("parse claude request: %w", err) + } + if strings.TrimSpace(claudeReq.Model) == "" { + return nil, fmt.Errorf("missing model") + } + originalModel := claudeReq.Model + billingModel := originalModel + + // 构建上游请求 URL + upstreamURL := baseURL + "/v1/messages" + + // 创建请求 + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("create upstream request: %w", err) + } + + // 设置请求头 + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + req.Header.Set("x-api-key", apiKey) // Claude API 兼容 + + // 透传 Claude 相关 headers + if v := c.GetHeader("anthropic-version"); v != "" { + req.Header.Set("anthropic-version", v) + } + if v := c.GetHeader("anthropic-beta"); v != "" { + req.Header.Set("anthropic-beta", v) + } + + // 代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + log.Printf("%s upstream request failed: %v", prefix, err) + return nil, fmt.Errorf("upstream request failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + // 处理错误响应 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + // 429 错误时标记账号限流 + if resp.StatusCode == http.StatusTooManyRequests { + s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, originalModel, 0, "", false) + } + + // 透传上游错误 + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(resp.StatusCode) + _, _ = c.Writer.Write(respBody) + + return &ForwardResult{ + Model: billingModel, + }, nil + } + + // 处理成功响应(流式/非流式) + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + + if claudeReq.Stream { + // 流式响应:透传 + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + c.Status(http.StatusOK) + + streamRes := s.streamUpstreamResponse(c, resp, startTime) + usage = streamRes.usage + firstTokenMs = streamRes.firstTokenMs + clientDisconnect = streamRes.clientDisconnect + } else { + // 非流式响应:直接透传 + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read upstream response: %w", err) + } + + // 提取 usage + usage = s.extractClaudeUsage(respBody) + + c.Header("Content-Type", resp.Header.Get("Content-Type")) + c.Status(http.StatusOK) + _, _ = c.Writer.Write(respBody) + } + + // 构建计费结果 + duration := time.Since(startTime) + log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + + return &ForwardResult{ + Model: billingModel, + Stream: claudeReq.Stream, + Duration: duration, + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + Usage: ClaudeUsage{ + InputTokens: usage.InputTokens, + OutputTokens: usage.OutputTokens, + CacheReadInputTokens: usage.CacheReadInputTokens, + CacheCreationInputTokens: usage.CacheCreationInputTokens, + }, + }, nil +} + +// streamUpstreamResponse 透传上游 SSE 流并提取 Claude usage +func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp *http.Response, startTime time.Time) *antigravityStreamResult { + usage := &ClaudeUsage{} + var firstTokenMs *int + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.settingService.cfg.Gateway.MaxLineSize + } + scanner.Buffer(make([]byte, 64*1024), maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func() { + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }() + defer close(done) + + streamInterval := time.Duration(0) + if s.settingService.cfg != nil && s.settingService.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.settingService.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "antigravity upstream") + + for { + select { + case ev, ok := <-events: + if !ok { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: cw.Disconnected()} + } + if ev.err != nil { + if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} + } + log.Printf("Stream read error (antigravity upstream): %v", ev.err) + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + + line := ev.line + + // 记录首 token 时间 + if firstTokenMs == nil && len(line) > 0 { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + + // 尝试从 message_delta 或 message_stop 事件提取 usage + s.extractSSEUsage(line, usage) + + // 透传行 + cw.Fprintf("%s\n", line) + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if cw.Disconnected() { + log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true} + } + log.Printf("Stream data interval timeout (antigravity upstream)") + return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} + } + } +} + +// extractSSEUsage 从 SSE data 行中提取 Claude usage(用于流式透传场景) +func (s *AntigravityGatewayService) extractSSEUsage(line string, usage *ClaudeUsage) { + if !strings.HasPrefix(line, "data: ") { + return + } + dataStr := strings.TrimPrefix(line, "data: ") + var event map[string]any + if json.Unmarshal([]byte(dataStr), &event) != nil { + return + } + u, ok := event["usage"].(map[string]any) + if !ok { + return + } + if v, ok := u["input_tokens"].(float64); ok && int(v) > 0 { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok && int(v) > 0 { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok && int(v) > 0 { + usage.CacheCreationInputTokens = int(v) + } +} + +// extractClaudeUsage 从非流式 Claude 响应提取 usage +func (s *AntigravityGatewayService) extractClaudeUsage(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + var resp map[string]any + if json.Unmarshal(body, &resp) != nil { + return usage + } + if u, ok := resp["usage"].(map[string]any); ok { + if v, ok := u["input_tokens"].(float64); ok { + usage.InputTokens = int(v) + } + if v, ok := u["output_tokens"].(float64); ok { + usage.OutputTokens = int(v) + } + if v, ok := u["cache_read_input_tokens"].(float64); ok { + usage.CacheReadInputTokens = int(v) + } + if v, ok := u["cache_creation_input_tokens"].(float64); ok { + usage.CacheCreationInputTokens = int(v) + } + } + return usage +} diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index ecad4171..12f35add 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -4,18 +4,42 @@ import ( "bytes" "context" "encoding/json" + "errors" + "fmt" "io" "net/http" "net/http/httptest" - "strings" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) +// antigravityFailingWriter 模拟客户端断开连接的 gin.ResponseWriter +type antigravityFailingWriter struct { + gin.ResponseWriter + failAfter int // 允许成功写入的次数,之后所有写入返回错误 + writes int +} + +func (w *antigravityFailingWriter) Write(p []byte) (int, error) { + if w.writes >= w.failAfter { + return 0, errors.New("write failed: client disconnected") + } + w.writes++ + return w.ResponseWriter.Write(p) +} + +// newAntigravityTestService 创建用于流式测试的 AntigravityGatewayService +func newAntigravityTestService(cfg *config.Config) *AntigravityGatewayService { + return &AntigravityGatewayService{ + settingService: &SettingService{cfg: cfg}, + } +} + func TestStripSignatureSensitiveBlocksFromClaudeRequest(t *testing.T) { req := &antigravity.ClaudeRequest{ Model: "claude-sonnet-4-5", @@ -338,8 +362,8 @@ func TestAntigravityGatewayService_Forward_StickySessionForceCacheBilling(t *tes require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } -// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling -// 验证:ForwardGemini 粘性会话切换时,UpstreamFailoverError.ForceCacheBilling 应为 true +// TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling verifies +// that ForwardGemini sets ForceCacheBilling=true for sticky session switch. func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling(t *testing.T) { gin.SetMode(gin.TestMode) writer := httptest.NewRecorder() @@ -393,10 +417,16 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } -func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { +// TestStreamUpstreamResponse_UsageAndFirstToken +// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 +func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { gin.SetMode(gin.TestMode) - writer := httptest.NewRecorder() - c, _ := gin.CreateTestContext(writer) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) pr, pw := io.Pipe() @@ -404,25 +434,458 @@ func TestAntigravityStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { go func() { defer func() { _ = pw.Close() }() - _, _ = pw.Write([]byte("data: {\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"cache_read_input_tokens\":3,\"cache_creation_input_tokens\":4}}\n")) - _, _ = pw.Write([]byte("data: {\"usage\":{\"output_tokens\":5}}\n")) + fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`) + fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`) }() - svc := &AntigravityGatewayService{} start := time.Now().Add(-10 * time.Millisecond) - usage, firstTokenMs := svc.streamUpstreamResponse(c, resp, start) + result := svc.streamUpstreamResponse(c, resp, start) _ = pr.Close() - require.NotNil(t, usage) - require.Equal(t, 1, usage.InputTokens) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) // 第二次事件覆盖 output_tokens - require.Equal(t, 5, usage.OutputTokens) - require.Equal(t, 3, usage.CacheReadInputTokens) - require.Equal(t, 4, usage.CacheCreationInputTokens) + require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) + require.Equal(t, 4, result.usage.CacheCreationInputTokens) + require.NotNil(t, result.firstTokenMs) - if firstTokenMs == nil { - t.Fatalf("expected firstTokenMs to be set") - } // 确保有透传输出 - require.True(t, strings.Contains(writer.Body.String(), "data:")) + require.Contains(t, rec.Body.String(), "data:") +} + +// --- 流式 happy path 测试 --- + +// TestStreamUpstreamResponse_NormalComplete +// 验证:正常流式转发完成时,数据正确透传、usage 正确收集、clientDisconnect=false +func TestStreamUpstreamResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: content_block_delta`) + fmt.Fprintln(pw, `data: {"type":"content_block_delta","delta":{"text":"hello"}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":5}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.OutputTokens, "should collect output_tokens from message_delta") + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := rec.Body.String() + require.Contains(t, body, "event: message_start") + require.Contains(t, body, "content_block_delta") + require.Contains(t, body, "message_delta") +} + +// TestHandleGeminiStreamingResponse_NormalComplete +// 验证:正常 Gemini 流式转发,数据正确透传、usage 正确收集 +func TestHandleGeminiStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // 第一个 chunk(部分内容) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"Hello"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":3}}`) + fmt.Fprintln(pw, "") + // 第二个 chunk(最终内容+完整 usage) + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":" world"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":8,"cachedContentTokenCount":2}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini usage: promptTokenCount=10, candidatesTokenCount=8, cachedContentTokenCount=2 + // → InputTokens=10-2=8, OutputTokens=8, CacheReadInputTokens=2 + require.Equal(t, 8, result.usage.InputTokens) + require.Equal(t, 8, result.usage.OutputTokens) + require.Equal(t, 2, result.usage.CacheReadInputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证数据被透传到客户端 + body := rec.Body.String() + require.Contains(t, body, "Hello") + require.Contains(t, body, "world") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// TestHandleClaudeStreamingResponse_NormalComplete +// 验证:正常 Claude 流式转发(Gemini→Claude 转换),数据正确转换并输出 +func TestHandleClaudeStreamingResponse_NormalComplete(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal 包装格式:Gemini 数据嵌套在 "response" 字段下 + // ProcessLine 先尝试反序列化为 V1InternalResponse,裸格式会导致 Response.UsageMetadata 为空 + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"Hi there"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":3}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.clientDisconnect, "normal completion should not set clientDisconnect") + require.NotNil(t, result.usage) + // Gemini→Claude 转换的 usage:promptTokenCount=5→InputTokens=5, candidatesTokenCount=3→OutputTokens=3 + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + require.NotNil(t, result.firstTokenMs, "should record first token time") + + // 验证输出是 Claude SSE 格式(processor 会转换) + body := rec.Body.String() + require.Contains(t, body, "event: message_start", "should contain Claude message_start event") + require.Contains(t, body, "event: message_stop", "should contain Claude message_stop event") + // 不应包含错误事件 + require.NotContains(t, body, "event: error") +} + +// --- 流式客户端断开检测测试 --- + +// TestStreamUpstreamResponse_ClientDisconnectDrainsUsage +// 验证:客户端写入失败后,streamUpstreamResponse 继续读取上游以收集 usage +func TestStreamUpstreamResponse_ClientDisconnectDrainsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `event: message_start`) + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":10}}}`) + fmt.Fprintln(pw, "") + fmt.Fprintln(pw, `event: message_delta`) + fmt.Fprintln(pw, `data: {"type":"message_delta","usage":{"output_tokens":20}}`) + fmt.Fprintln(pw, "") + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotNil(t, result.usage) + require.Equal(t, 20, result.usage.OutputTokens) +} + +// TestStreamUpstreamResponse_ContextCanceled +// 验证:context 取消时返回 usage 且标记 clientDisconnect +func TestStreamUpstreamResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestStreamUpstreamResponse_Timeout +// 验证:上游超时时返回已收集的 usage +func TestStreamUpstreamResponse_Timeout(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +// TestStreamUpstreamResponse_TimeoutAfterClientDisconnect +// 验证:客户端断开后上游超时,返回 usage 并标记 clientDisconnect +func TestStreamUpstreamResponse_TimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{StreamDataIntervalTimeout: 1, MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + fmt.Fprintln(pw, `data: {"type":"message_start","message":{"usage":{"input_tokens":5}}}`) + fmt.Fprintln(pw, "") + // 不关闭 pw → 等待超时 + }() + + result := svc.streamUpstreamResponse(c, resp, time.Now()) + _ = pw.Close() + _ = pr.Close() + + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleGeminiStreamingResponse_ClientDisconnect +// 验证:Gemini 流式转发中客户端断开后继续 drain 上游 +func TestHandleGeminiStreamingResponse_ClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"candidates":[{"content":{"parts":[{"text":"hi"}]}}],"usageMetadata":{"promptTokenCount":5,"candidatesTokenCount":10}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "write_failed") +} + +// TestHandleGeminiStreamingResponse_ContextCanceled +// 验证:context 取消时不注入错误事件 +func TestHandleGeminiStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleGeminiStreamingResponse(c, resp, time.Now()) + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestHandleClaudeStreamingResponse_ClientDisconnect +// 验证:Claude 流式转发中客户端断开后继续 drain 上游 +func TestHandleClaudeStreamingResponse_ClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + c.Writer = &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Body: pr, Header: http.Header{}} + + go func() { + defer func() { _ = pw.Close() }() + // v1internal 包装格式 + fmt.Fprintln(pw, `data: {"response":{"candidates":[{"content":{"parts":[{"text":"hello"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":8,"candidatesTokenCount":15}}}`) + fmt.Fprintln(pw, "") + }() + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + _ = pr.Close() + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +// TestHandleClaudeStreamingResponse_ContextCanceled +// 验证:context 取消时不注入错误事件 +func TestHandleClaudeStreamingResponse_ContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + c.Request = httptest.NewRequest(http.MethodPost, "/", nil).WithContext(ctx) + + resp := &http.Response{StatusCode: http.StatusOK, Body: cancelReadCloser{}, Header: http.Header{}} + + result, err := svc.handleClaudeStreamingResponse(c, resp, time.Now(), "claude-sonnet-4-5") + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.NotContains(t, rec.Body.String(), "event: error") +} + +// TestExtractSSEUsage 验证 extractSSEUsage 从 SSE data 行正确提取 usage +func TestExtractSSEUsage(t *testing.T) { + svc := &AntigravityGatewayService{} + tests := []struct { + name string + line string + expected ClaudeUsage + }{ + { + name: "message_delta with output_tokens", + line: `data: {"type":"message_delta","usage":{"output_tokens":42}}`, + expected: ClaudeUsage{OutputTokens: 42}, + }, + { + name: "non-data line ignored", + line: `event: message_start`, + expected: ClaudeUsage{}, + }, + { + name: "top-level usage with all fields", + line: `data: {"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":5,"cache_creation_input_tokens":3}}`, + expected: ClaudeUsage{InputTokens: 10, OutputTokens: 20, CacheReadInputTokens: 5, CacheCreationInputTokens: 3}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + usage := &ClaudeUsage{} + svc.extractSSEUsage(tt.line, usage) + require.Equal(t, tt.expected, *usage) + }) + } +} + +// TestAntigravityClientWriter 验证 antigravityClientWriter 的断开检测 +func TestAntigravityClientWriter(t *testing.T) { + t.Run("normal write succeeds", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(c.Writer, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.True(t, ok) + require.False(t, cw.Disconnected()) + require.Contains(t, rec.Body.String(), "hello") + }) + + t.Run("write failure marks disconnected", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + ok := cw.Write([]byte("hello")) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) + + t.Run("subsequent writes are no-op", func(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + fw := &antigravityFailingWriter{ResponseWriter: c.Writer, failAfter: 0} + flusher, _ := c.Writer.(http.Flusher) + cw := newAntigravityClientWriter(fw, flusher, "test") + + cw.Write([]byte("first")) + ok := cw.Fprintf("second %d", 2) + require.False(t, ok) + require.True(t, cw.Disconnected()) + }) } diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go index 43ac6c2f..e181e7f8 100644 --- a/backend/internal/service/antigravity_quota_scope.go +++ b/backend/internal/service/antigravity_quota_scope.go @@ -2,63 +2,23 @@ package service import ( "context" - "slices" "strings" "time" ) -const antigravityQuotaScopesKey = "antigravity_quota_scopes" - -// AntigravityQuotaScope 表示 Antigravity 的配额域 -type AntigravityQuotaScope string - -const ( - AntigravityQuotaScopeClaude AntigravityQuotaScope = "claude" - AntigravityQuotaScopeGeminiText AntigravityQuotaScope = "gemini_text" - AntigravityQuotaScopeGeminiImage AntigravityQuotaScope = "gemini_image" -) - -// IsScopeSupported 检查给定的 scope 是否在分组支持的 scope 列表中 -func IsScopeSupported(supportedScopes []string, scope AntigravityQuotaScope) bool { - if len(supportedScopes) == 0 { - // 未配置时默认全部支持 - return true - } - supported := slices.Contains(supportedScopes, string(scope)) - return supported -} - -// ResolveAntigravityQuotaScope 根据模型名称解析配额域(导出版本) -func ResolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - return resolveAntigravityQuotaScope(requestedModel) -} - -// resolveAntigravityQuotaScope 根据模型名称解析配额域 -func resolveAntigravityQuotaScope(requestedModel string) (AntigravityQuotaScope, bool) { - model := normalizeAntigravityModelName(requestedModel) - if model == "" { - return "", false - } - switch { - case strings.HasPrefix(model, "claude-"): - return AntigravityQuotaScopeClaude, true - case strings.HasPrefix(model, "gemini-"): - if isImageGenerationModel(model) { - return AntigravityQuotaScopeGeminiImage, true - } - return AntigravityQuotaScopeGeminiText, true - default: - return "", false - } -} - func normalizeAntigravityModelName(model string) string { normalized := strings.ToLower(strings.TrimSpace(model)) normalized = strings.TrimPrefix(normalized, "models/") return normalized } -// IsSchedulableForModel 结合 Antigravity 配额域限流判断是否可调度。 +// resolveAntigravityModelKey 根据请求的模型名解析限流 key +// 返回空字符串表示无法解析 +func resolveAntigravityModelKey(requestedModel string) string { + return normalizeAntigravityModelName(requestedModel) +} + +// IsSchedulableForModel 结合模型级限流判断是否可调度。 // 保持旧签名以兼容既有调用方;默认使用 context.Background()。 func (a *Account) IsSchedulableForModel(requestedModel string) bool { return a.IsSchedulableForModelWithContext(context.Background(), requestedModel) @@ -74,107 +34,20 @@ func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requeste if a.isModelRateLimitedWithContext(ctx, requestedModel) { return false } - if a.Platform != PlatformAntigravity { - return true - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return true - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return true - } - now := time.Now() - return !now.Before(*resetAt) + return true } -func (a *Account) antigravityQuotaScopeResetAt(scope AntigravityQuotaScope) *time.Time { - if a == nil || a.Extra == nil || scope == "" { - return nil - } - rawScopes, ok := a.Extra[antigravityQuotaScopesKey].(map[string]any) - if !ok { - return nil - } - rawScope, ok := rawScopes[string(scope)].(map[string]any) - if !ok { - return nil - } - resetAtRaw, ok := rawScope["rate_limit_reset_at"].(string) - if !ok || strings.TrimSpace(resetAtRaw) == "" { - return nil - } - resetAt, err := time.Parse(time.RFC3339, resetAtRaw) - if err != nil { - return nil - } - return &resetAt -} - -var antigravityAllScopes = []AntigravityQuotaScope{ - AntigravityQuotaScopeClaude, - AntigravityQuotaScopeGeminiText, - AntigravityQuotaScopeGeminiImage, -} - -func (a *Account) GetAntigravityScopeRateLimits() map[string]int64 { - if a == nil || a.Platform != PlatformAntigravity { - return nil - } - now := time.Now() - result := make(map[string]int64) - for _, scope := range antigravityAllScopes { - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt != nil && now.Before(*resetAt) { - remainingSec := int64(time.Until(*resetAt).Seconds()) - if remainingSec > 0 { - result[string(scope)] = remainingSec - } - } - } - if len(result) == 0 { - return nil - } - return result -} - -// GetQuotaScopeRateLimitRemainingTime 获取模型域限流剩余时间 -// 返回 0 表示未限流或已过期 -func (a *Account) GetQuotaScopeRateLimitRemainingTime(requestedModel string) time.Duration { - if a == nil || a.Platform != PlatformAntigravity { - return 0 - } - scope, ok := resolveAntigravityQuotaScope(requestedModel) - if !ok { - return 0 - } - resetAt := a.antigravityQuotaScopeResetAt(scope) - if resetAt == nil { - return 0 - } - if remaining := time.Until(*resetAt); remaining > 0 { - return remaining - } - return 0 -} - -// GetRateLimitRemainingTime 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration { return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel) } -// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型限流和模型域限流取最大值) +// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流) // 返回 0 表示未限流或已过期 func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration { if a == nil { return 0 } - modelRemaining := a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) - scopeRemaining := a.GetQuotaScopeRateLimitRemainingTime(requestedModel) - if modelRemaining > scopeRemaining { - return modelRemaining - } - return scopeRemaining + return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel) } diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 2b4a5504..243bf90b 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -65,12 +65,6 @@ func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, return s.Do(req, proxyURL, accountID, accountConcurrency) } -type scopeLimitCall struct { - accountID int64 - scope AntigravityQuotaScope - resetAt time.Time -} - type rateLimitCall struct { accountID int64 resetAt time.Time @@ -84,16 +78,10 @@ type modelRateLimitCall struct { type stubAntigravityAccountRepo struct { AccountRepository - scopeCalls []scopeLimitCall rateCalls []rateLimitCall modelRateLimitCalls []modelRateLimitCall } -func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt}) - return nil -} - func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt}) return nil @@ -137,10 +125,9 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - quotaScope: AntigravityQuotaScopeClaude, httpUpstream: upstream, requestedModel: "claude-sonnet-4-5", - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { handleErrorCalled = true return nil }, @@ -161,23 +148,6 @@ func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { require.Equal(t, base2, available[0]) } -func TestAntigravityHandleUpstreamError_UsesScopeLimit(t *testing.T) { - // 分区限流始终开启,不再支持通过环境变量关闭 - repo := &stubAntigravityAccountRepo{} - svc := &AntigravityGatewayService{accountRepo: repo} - account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} - - body := buildGeminiRateLimitBody("3s") - svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) - - require.Len(t, repo.scopeCalls, 1) - require.Empty(t, repo.rateCalls) - call := repo.scopeCalls[0] - require.Equal(t, account.ID, call.accountID) - require.Equal(t, AntigravityQuotaScopeClaude, call.scope) - require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) -} - // TestHandleUpstreamError_429_ModelRateLimit 测试 429 模型限流场景 func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} @@ -195,7 +165,7 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) // 应该触发模型限流 require.NotNil(t, result) @@ -206,22 +176,22 @@ func TestHandleUpstreamError_429_ModelRateLimit(t *testing.T) { require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } -// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走 scope 限流) +// TestHandleUpstreamError_429_NonModelRateLimit 测试 429 非模型限流场景(走模型级限流兜底) func TestHandleUpstreamError_429_NonModelRateLimit(t *testing.T) { repo := &stubAntigravityAccountRepo{} svc := &AntigravityGatewayService{accountRepo: repo} account := &Account{ID: 2, Name: "acc-2", Platform: PlatformAntigravity} - // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ scope 限流 + // 429 + 普通限流响应(无 RATE_LIMIT_EXCEEDED reason)→ 走模型级限流兜底 body := buildGeminiRateLimitBody("5s") - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, "claude-sonnet-4-5", 0, "", false) - // 不应该触发模型限流,应该走 scope 限流 + // handleModelRateLimit 不会处理(因为没有 RATE_LIMIT_EXCEEDED), + // 但 429 兜底逻辑会使用 requestedModel 设置模型级限流 require.Nil(t, result) - require.Empty(t, repo.modelRateLimitCalls) - require.Len(t, repo.scopeCalls, 1) - require.Equal(t, AntigravityQuotaScopeClaude, repo.scopeCalls[0].scope) + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } // TestHandleUpstreamError_503_ModelRateLimit 测试 503 模型限流场景 @@ -241,7 +211,7 @@ func TestHandleUpstreamError_503_ModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 应该触发模型限流 require.NotNil(t, result) @@ -269,12 +239,11 @@ func TestHandleUpstreamError_503_NonModelRateLimit(t *testing.T) { } }`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 非模型限流不应该做任何处理 require.Nil(t, result) require.Empty(t, repo.modelRateLimitCalls, "503 non-model rate limit should not trigger model rate limit") - require.Empty(t, repo.scopeCalls, "503 non-model rate limit should not trigger scope rate limit") require.Empty(t, repo.rateCalls, "503 non-model rate limit should not trigger account rate limit") } @@ -287,12 +256,11 @@ func TestHandleUpstreamError_503_EmptyBody(t *testing.T) { // 503 + 空响应体 → 不做任何处理 body := []byte(`{}`) - result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, AntigravityQuotaScopeGeminiText, 0, "", false) + result := svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusServiceUnavailable, http.Header{}, body, "gemini-3-pro-high", 0, "", false) // 503 空响应不应该做任何处理 require.Nil(t, result) require.Empty(t, repo.modelRateLimitCalls) - require.Empty(t, repo.scopeCalls) require.Empty(t, repo.rateCalls) } @@ -313,15 +281,7 @@ func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { require.False(t, account.IsSchedulableForModel("gemini-3-flash")) account.RateLimitResetAt = nil - account.Extra = map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future.Format(time.RFC3339), - }, - }, - } - - require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.True(t, account.IsSchedulableForModel("claude-sonnet-4-5")) require.True(t, account.IsSchedulableForModel("gemini-3-flash")) } @@ -641,6 +601,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 7 * time.Second, modelName: "gemini-pro", }, { @@ -658,6 +619,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 39 * time.Second, modelName: "gemini-3-pro-high", }, { @@ -675,6 +637,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 30 * time.Second, modelName: "gemini-2.5-flash", }, { @@ -692,6 +655,7 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { }`, expectedShouldRetry: false, expectedShouldRateLimit: true, + minWait: 30 * time.Second, modelName: "claude-sonnet-4-5", }, } @@ -710,6 +674,11 @@ func TestShouldTriggerAntigravitySmartRetry(t *testing.T) { t.Errorf("wait = %v, want >= %v", wait, tt.minWait) } } + if shouldRateLimit && tt.minWait > 0 { + if wait < tt.minWait { + t.Errorf("rate limit wait = %v, want >= %v", wait, tt.minWait) + } + } if (shouldRetry || shouldRateLimit) && model != tt.modelName { t.Errorf("modelName = %q, want %q", model, tt.modelName) } @@ -809,7 +778,7 @@ func TestSetModelRateLimitByModelName_NotConvertToScope(t *testing.T) { require.NotEqual(t, "claude_sonnet", call.modelKey, "should NOT be scope") } -func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRateLimited(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 1, @@ -821,19 +790,15 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - // RFC3339 here is second-precision; keep it safely in the future. "rate_limit_reset_at": time.Now().Add(2 * time.Second).Format(time.RFC3339), }, }, }, } - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond) - defer cancel() - svc := &AntigravityGatewayService{} result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ - ctx: ctx, + ctx: context.Background(), prefix: "[test]", account: account, accessToken: "token", @@ -842,17 +807,21 @@ func TestAntigravityRetryLoop_PreCheck_WaitsWhenRemainingBelowThreshold(t *testi requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) - require.ErrorIs(t, err, context.DeadlineExceeded) require.Nil(t, result) - require.Equal(t, 0, upstream.calls, "should not call upstream while waiting on pre-check") + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession) + require.Equal(t, 0, upstream.calls, "should not call upstream when switching on pre-check") } -func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t *testing.T) { +func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingLong(t *testing.T) { upstream := &recordingOKUpstream{} account := &Account{ ID: 2, @@ -881,7 +850,7 @@ func TestAntigravityRetryLoop_PreCheck_SwitchesWhenRemainingAtOrAboveThreshold(t requestedModel: "claude-sonnet-4-5", httpUpstream: upstream, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) diff --git a/backend/internal/service/antigravity_smart_retry_test.go b/backend/internal/service/antigravity_smart_retry_test.go index 623dfec5..a7e0d296 100644 --- a/backend/internal/service/antigravity_smart_retry_test.go +++ b/backend/internal/service/antigravity_smart_retry_test.go @@ -13,6 +13,23 @@ import ( "github.com/stretchr/testify/require" ) +// stubSmartRetryCache 用于 handleSmartRetry 测试的 GatewayCache mock +// 仅关注 DeleteSessionAccountID 的调用记录 +type stubSmartRetryCache struct { + GatewayCache // 嵌入接口,未实现的方法 panic(确保只调用预期方法) + deleteCalls []deleteSessionCall +} + +type deleteSessionCall struct { + groupID int64 + sessionHash string +} + +func (c *stubSmartRetryCache) DeleteSessionAccountID(_ context.Context, groupID int64, sessionHash string) error { + c.deleteCalls = append(c.deleteCalls, deleteSessionCall{groupID: groupID, sessionHash: sessionHash}) + return nil +} + // mockSmartRetryUpstream 用于 handleSmartRetry 测试的 mock upstream type mockSmartRetryUpstream struct { responses []*http.Response @@ -58,7 +75,7 @@ func TestHandleSmartRetry_URLLevelRateLimit(t *testing.T) { accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -110,7 +127,7 @@ func TestHandleSmartRetry_LongDelay_ReturnsSwitchError(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -177,7 +194,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -198,7 +215,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetrySuccess(t *testing.T) { // TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError 测试智能重试失败后返回 switchError func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *testing.T) { - // 智能重试后仍然返回 429(需要提供 3 个响应,因为智能重试最多 3 次) + // 智能重试后仍然返回 429(需要提供 1 个响应,因为智能重试最多 1 次) failRespBody := `{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -213,19 +230,9 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Header: http.Header{}, Body: io.NopCloser(strings.NewReader(failRespBody)), } - failResp2 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } - failResp3 := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(failRespBody)), - } upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{failResp1, failResp2, failResp3}, - errors: []error{nil, nil, nil}, + responses: []*http.Response{failResp1}, + errors: []error{nil}, } repo := &stubAntigravityAccountRepo{} @@ -236,7 +243,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test Platform: PlatformAntigravity, } - // 3s < 7s 阈值,应该触发智能重试(最多 3 次) + // 3s < 7s 阈值,应该触发智能重试(最多 1 次) respBody := []byte(`{ "error": { "status": "RESOURCE_EXHAUSTED", @@ -262,7 +269,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test httpUpstream: upstream, accountRepo: repo, isStickySession: false, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -284,7 +291,7 @@ func TestHandleSmartRetry_ShortDelay_SmartRetryFailed_ReturnsSwitchError(t *test // 验证模型限流已设置 require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "gemini-3-flash", repo.modelRateLimitCalls[0].modelKey) - require.Len(t, upstream.calls, 3, "should have made three retry calls (max attempts)") + require.Len(t, upstream.calls, 1, "should have made one retry call (max attempts)") } // TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError 测试 503 MODEL_CAPACITY_EXHAUSTED 返回 switchError @@ -324,7 +331,7 @@ func TestHandleSmartRetry_503_ModelCapacityExhausted_ReturnsSwitchError(t *testi body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -380,7 +387,7 @@ func TestHandleSmartRetry_NonAntigravityAccount_ContinuesDefaultLogic(t *testing accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -429,7 +436,7 @@ func TestHandleSmartRetry_NonModelRateLimit_ContinuesDefaultLogic(t *testing.T) accessToken: "token", action: "generateContent", body: []byte(`{"input":"test"}`), - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -480,7 +487,7 @@ func TestHandleSmartRetry_ExactlyAtThreshold_ReturnsSwitchError(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), accountRepo: repo, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -541,7 +548,7 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing httpUpstream: upstream, accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, }) @@ -556,19 +563,15 @@ func TestAntigravityRetryLoop_HandleSmartRetry_SwitchError_Propagates(t *testing require.True(t, switchErr.IsStickySession) } -// TestHandleSmartRetry_NetworkError_ContinuesRetry 测试网络错误时继续重试 -func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { - // 第一次网络错误,第二次成功 - successResp := &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{}, - Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), - } +// TestHandleSmartRetry_NetworkError_ExhaustsRetry 测试网络错误时(maxAttempts=1)直接耗尽重试并切换账号 +func TestHandleSmartRetry_NetworkError_ExhaustsRetry(t *testing.T) { + // 唯一一次重试遇到网络错误(nil response) upstream := &mockSmartRetryUpstream{ - responses: []*http.Response{nil, successResp}, // 第一次返回 nil(模拟网络错误) - errors: []error{nil, nil}, // mock 不返回 error,靠 nil response 触发 + responses: []*http.Response{nil}, // 返回 nil(模拟网络错误) + errors: []error{nil}, // mock 不返回 error,靠 nil response 触发 } + repo := &stubAntigravityAccountRepo{} account := &Account{ ID: 8, Name: "acc-8", @@ -600,7 +603,8 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { action: "generateContent", body: []byte(`{"input":"test"}`), httpUpstream: upstream, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + accountRepo: repo, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -612,10 +616,15 @@ func TestHandleSmartRetry_NetworkError_ContinuesRetry(t *testing.T) { require.NotNil(t, result) require.Equal(t, smartRetryActionBreakWithResp, result.action) - require.NotNil(t, result.resp, "should return successful response after network error recovery") - require.Equal(t, http.StatusOK, result.resp.StatusCode) - require.Nil(t, result.switchError, "should not return switchError on success") - require.Len(t, upstream.calls, 2, "should have made two retry calls") + require.Nil(t, result.resp, "should not return resp when switchError is set") + require.NotNil(t, result.switchError, "should return switchError after network error exhausted retry") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + require.Equal(t, "claude-sonnet-4-5", result.switchError.RateLimitedModel) + require.Len(t, upstream.calls, 1, "should have made one retry call") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } // TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit 测试无 retryDelay 时使用默认 1 分钟限流 @@ -653,7 +662,7 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { body: []byte(`{"input":"test"}`), accountRepo: repo, isStickySession: true, - handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { return nil }, } @@ -674,3 +683,617 @@ func TestHandleSmartRetry_NoRetryDelay_UsesDefaultRateLimit(t *testing.T) { require.Len(t, repo.modelRateLimitCalls, 1) require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) } + +// --------------------------------------------------------------------------- +// 以下测试覆盖本次改动: +// 1. antigravitySmartRetryMaxAttempts = 1(仅重试 1 次) +// 2. 智能重试失败后清除粘性会话绑定(DeleteSessionAccountID) +// --------------------------------------------------------------------------- + +// TestSmartRetryMaxAttempts_VerifyConstant 验证常量值为 1 +func TestSmartRetryMaxAttempts_VerifyConstant(t *testing.T) { + require.Equal(t, 1, antigravitySmartRetryMaxAttempts, + "antigravitySmartRetryMaxAttempts should be 1 to prevent repeated rate limiting") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession +// 核心场景:粘性会话 + 短延迟重试失败 → 必须清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 10, + Name: "acc-10", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-abc", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + // 验证返回 switchError + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession, "switchError should carry IsStickySession=true") + require.Equal(t, account.ID, result.switchError.OriginalAccountID) + + // 核心断言:DeleteSessionAccountID 被调用,且参数正确 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID exactly once") + require.Equal(t, int64(42), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-hash-abc", cache.deleteCalls[0].sessionHash) + + // 验证仅重试 1 次 + require.Len(t, upstream.calls, 1, "should make exactly 1 retry call (maxAttempts=1)") + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "claude-sonnet-4-5", repo.modelRateLimitCalls[0].modelKey) +} + +// TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession +// 非粘性会话 + 短延迟重试失败 → 不应调用 DeleteSessionAccountID(sessionHash 为空) +func TestHandleSmartRetry_ShortDelay_NonStickySession_FailedRetry_NoDeleteSession(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 11, + Name: "acc-11", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: false, + groupID: 42, + sessionHash: "", // 非粘性会话,sessionHash 为空 + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.False(t, result.switchError.IsStickySession) + + // 核心断言:sessionHash 为空时不应调用 DeleteSessionAccountID + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID when sessionHash is empty") +} + +// TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic +// 边界:cache 为 nil 时不应 panic +func TestHandleSmartRetry_ShortDelay_StickySession_FailedRetry_NilCache_NoPanic(t *testing.T) { + failRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + account := &Account{ + ID: 12, + Name: "acc-12", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-nil-cache", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + // cache 为 nil,不应 panic + svc := &AntigravityGatewayService{cache: nil} + require.NotPanics(t, func() { + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + }) +} + +// TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession +// 重试成功时不应清除粘性会话(只有失败才清除) +func TestHandleSmartRetry_ShortDelay_StickySession_SuccessRetry_NoDeleteSession(t *testing.T) { + successResp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{successResp}, + errors: []error{nil}, + } + + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 13, + Name: "acc-13", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-success", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.resp, "should return successful response") + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.Nil(t, result.switchError, "should not return switchError on success") + + // 核心断言:重试成功时不应清除粘性会话 + require.Len(t, cache.deleteCalls, 0, "should NOT call DeleteSessionAccountID on successful retry") +} + +// TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry +// 长延迟路径(情况1)在 handleSmartRetry 中不直接调用 DeleteSessionAccountID +// (清除由 handler 层的 shouldClearStickySession 在下次请求时处理) +func TestHandleSmartRetry_LongDelay_StickySession_NoDeleteInHandleSmartRetry(t *testing.T) { + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 14, + Name: "acc-14", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + // 15s >= 7s 阈值 → 走长延迟路径 + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + accountRepo: repo, + isStickySession: true, + groupID: 42, + sessionHash: "sticky-hash-long-delay", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.Equal(t, smartRetryActionBreakWithResp, result.action) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 长延迟路径不在 handleSmartRetry 中调用 DeleteSessionAccountID + // (由上游 handler 的 shouldClearStickySession 处理) + require.Len(t, cache.deleteCalls, 0, + "long delay path should NOT call DeleteSessionAccountID in handleSmartRetry (handled by handler layer)") +} + +// TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession +// 网络错误耗尽重试 + 粘性会话 → 也应清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_NetworkError_StickySession_ClearsSession(t *testing.T) { + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{nil}, // 网络错误 + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 15, + Name: "acc-15", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 99, + sessionHash: "sticky-net-error", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 核心断言:网络错误耗尽重试后也应清除粘性绑定 + require.Len(t, cache.deleteCalls, 1, "should call DeleteSessionAccountID after network error exhausts retry") + require.Equal(t, int64(99), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-net-error", cache.deleteCalls[0].sessionHash) +} + +// TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession +// 503 + 短延迟 + 粘性会话 + 重试失败 → 清除粘性绑定 +func TestHandleSmartRetry_ShortDelay_503_StickySession_FailedRetry_ClearsSession(t *testing.T) { + failRespBody := `{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }` + failResp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(failRespBody)), + } + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{failResp}, + errors: []error{nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 16, + Name: "acc-16", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + } + + respBody := []byte(`{ + "error": { + "code": 503, + "status": "UNAVAILABLE", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"} + ] + } + }`) + resp := &http.Response{ + StatusCode: http.StatusServiceUnavailable, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(respBody)), + } + + params := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 77, + sessionHash: "sticky-503-short", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + } + + availableURLs := []string{"https://ag-1.test"} + + svc := &AntigravityGatewayService{cache: cache} + result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs) + + require.NotNil(t, result) + require.NotNil(t, result.switchError) + require.True(t, result.switchError.IsStickySession) + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1) + require.Equal(t, int64(77), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-503-short", cache.deleteCalls[0].sessionHash) + + // 验证模型限流已设置 + require.Len(t, repo.modelRateLimitCalls, 1) + require.Equal(t, "gemini-3-pro", repo.modelRateLimitCalls[0].modelKey) +} + +// TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates +// 集成测试:antigravityRetryLoop → handleSmartRetry → switchError 传播 +// 验证 IsStickySession 正确传递到上层,且粘性绑定被清除 +func TestAntigravityRetryLoop_SmartRetryFailed_StickySession_SwitchErrorPropagates(t *testing.T) { + // 初始 429 响应 + initialRespBody := []byte(`{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }`) + initialResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(initialRespBody)), + } + + // 智能重试也返回 429 + retryRespBody := `{ + "error": { + "status": "RESOURCE_EXHAUSTED", + "details": [ + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-opus-4-6"}, "reason": "RATE_LIMIT_EXCEEDED"}, + {"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"} + ] + } + }` + retryResp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(retryRespBody)), + } + + upstream := &mockSmartRetryUpstream{ + responses: []*http.Response{initialResp, retryResp}, + errors: []error{nil, nil}, + } + + repo := &stubAntigravityAccountRepo{} + cache := &stubSmartRetryCache{} + account := &Account{ + ID: 17, + Name: "acc-17", + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + svc := &AntigravityGatewayService{cache: cache} + result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + accountRepo: repo, + isStickySession: true, + groupID: 55, + sessionHash: "sticky-loop-test", + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + return nil + }, + }) + + require.Nil(t, result, "should not return result when switchError") + require.NotNil(t, err, "should return error") + + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr, "error should be AntigravityAccountSwitchError") + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, "claude-opus-4-6", switchErr.RateLimitedModel) + require.True(t, switchErr.IsStickySession, "IsStickySession must propagate through retryLoop") + + // 验证粘性绑定被清除 + require.Len(t, cache.deleteCalls, 1, "should clear sticky session in handleSmartRetry") + require.Equal(t, int64(55), cache.deleteCalls[0].groupID) + require.Equal(t, "sticky-loop-test", cache.deleteCalls[0].sessionHash) +} diff --git a/backend/internal/service/crs_sync_helpers_test.go b/backend/internal/service/crs_sync_helpers_test.go new file mode 100644 index 00000000..0dc05335 --- /dev/null +++ b/backend/internal/service/crs_sync_helpers_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "testing" +) + +func TestBuildSelectedSet(t *testing.T) { + tests := []struct { + name string + ids []string + wantNil bool + wantSize int + }{ + { + name: "nil input returns nil (backward compatible: create all)", + ids: nil, + wantNil: true, + }, + { + name: "empty slice returns empty map (create none)", + ids: []string{}, + wantNil: false, + wantSize: 0, + }, + { + name: "single ID", + ids: []string{"abc-123"}, + wantNil: false, + wantSize: 1, + }, + { + name: "multiple IDs", + ids: []string{"a", "b", "c"}, + wantNil: false, + wantSize: 3, + }, + { + name: "duplicate IDs are deduplicated", + ids: []string{"a", "a", "b"}, + wantNil: false, + wantSize: 2, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildSelectedSet(tt.ids) + if tt.wantNil { + if got != nil { + t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got) + } + return + } + if got == nil { + t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids) + } + if len(got) != tt.wantSize { + t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize) + } + // Verify all unique IDs are present + for _, id := range tt.ids { + if _, ok := got[id]; !ok { + t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id) + } + } + }) + } +} + +func TestShouldCreateAccount(t *testing.T) { + tests := []struct { + name string + crsID string + selectedSet map[string]struct{} + want bool + }{ + { + name: "nil set allows all (backward compatible)", + crsID: "any-id", + selectedSet: nil, + want: true, + }, + { + name: "empty set blocks all", + crsID: "any-id", + selectedSet: map[string]struct{}{}, + want: false, + }, + { + name: "ID in set is allowed", + crsID: "abc-123", + selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}}, + want: true, + }, + { + name: "ID not in set is blocked", + crsID: "xyz-789", + selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldCreateAccount(tt.crsID, tt.selectedSet) + if got != tt.want { + t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v", + tt.crsID, tt.selectedSet, got, tt.want) + } + }) + } +} diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index a6ccb967..040b2357 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -45,10 +45,11 @@ func NewCRSSyncService( } type SyncFromCRSInput struct { - BaseURL string - Username string - Password string - SyncProxies bool + BaseURL string + Username string + Password string + SyncProxies bool + SelectedAccountIDs []string // if non-empty, only create new accounts with these CRS IDs } type SyncFromCRSItemResult struct { @@ -190,25 +191,27 @@ type crsGeminiAPIKeyAccount struct { Extra map[string]any `json:"extra"` } -func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { +// fetchCRSExport validates the connection parameters, authenticates with CRS, +// and returns the exported accounts. Shared by SyncFromCRS and PreviewFromCRS. +func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, password string) (*crsExportResponse, error) { if s.cfg == nil { return nil, errors.New("config is not available") } - baseURL := strings.TrimSpace(input.BaseURL) + normalizedURL := strings.TrimSpace(baseURL) if s.cfg.Security.URLAllowlist.Enabled { - normalized, err := normalizeBaseURL(baseURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts) + normalized, err := normalizeBaseURL(normalizedURL, s.cfg.Security.URLAllowlist.CRSHosts, s.cfg.Security.URLAllowlist.AllowPrivateHosts) if err != nil { return nil, err } - baseURL = normalized + normalizedURL = normalized } else { - normalized, err := urlvalidator.ValidateURLFormat(baseURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) + normalized, err := urlvalidator.ValidateURLFormat(normalizedURL, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { return nil, fmt.Errorf("invalid base_url: %w", err) } - baseURL = normalized + normalizedURL = normalized } - if strings.TrimSpace(input.Username) == "" || strings.TrimSpace(input.Password) == "" { + if strings.TrimSpace(username) == "" || strings.TrimSpace(password) == "" { return nil, errors.New("username and password are required") } @@ -221,12 +224,16 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput client = &http.Client{Timeout: 20 * time.Second} } - adminToken, err := crsLogin(ctx, client, baseURL, input.Username, input.Password) + adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) if err != nil { return nil, err } - exported, err := crsExportAccounts(ctx, client, baseURL, adminToken) + return crsExportAccounts(ctx, client, normalizedURL, adminToken) +} + +func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput) (*SyncFromCRSResult, error) { + exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password) if err != nil { return nil, err } @@ -241,6 +248,8 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput ), } + selectedSet := buildSelectedSet(input.SelectedAccountIDs) + var proxies []Proxy if input.SyncProxies { proxies, _ = s.proxyRepo.ListActive(ctx) @@ -329,6 +338,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformAnthropic, @@ -446,6 +462,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformAnthropic, @@ -569,6 +592,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformOpenAI, @@ -690,6 +720,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformOpenAI, @@ -798,6 +835,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformGemini, @@ -909,6 +953,13 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput } if existing == nil { + if !shouldCreateAccount(src.ID, selectedSet) { + item.Action = "skipped" + item.Error = "not selected" + result.Skipped++ + result.Items = append(result.Items, item) + continue + } account := &Account{ Name: defaultName(src.Name, src.ID), Platform: PlatformGemini, @@ -1253,3 +1304,102 @@ func (s *CRSSyncService) refreshOAuthToken(ctx context.Context, account *Account return newCredentials } + +// buildSelectedSet converts a slice of selected CRS account IDs to a set for O(1) lookup. +// Returns nil if ids is nil (field not sent → backward compatible: create all). +// Returns an empty map if ids is non-nil but empty (user selected none → create none). +func buildSelectedSet(ids []string) map[string]struct{} { + if ids == nil { + return nil + } + set := make(map[string]struct{}, len(ids)) + for _, id := range ids { + set[id] = struct{}{} + } + return set +} + +// shouldCreateAccount checks if a new CRS account should be created based on user selection. +// Returns true if selectedSet is nil (backward compatible: create all) or if crsID is in the set. +func shouldCreateAccount(crsID string, selectedSet map[string]struct{}) bool { + if selectedSet == nil { + return true + } + _, ok := selectedSet[crsID] + return ok +} + +// PreviewFromCRSResult contains the preview of accounts from CRS before sync. +type PreviewFromCRSResult struct { + NewAccounts []CRSPreviewAccount `json:"new_accounts"` + ExistingAccounts []CRSPreviewAccount `json:"existing_accounts"` +} + +// CRSPreviewAccount represents a single account in the preview result. +type CRSPreviewAccount struct { + CRSAccountID string `json:"crs_account_id"` + Kind string `json:"kind"` + Name string `json:"name"` + Platform string `json:"platform"` + Type string `json:"type"` +} + +// PreviewFromCRS connects to CRS, fetches all accounts, and classifies them +// as new or existing by batch-querying local crs_account_id mappings. +func (s *CRSSyncService) PreviewFromCRS(ctx context.Context, input SyncFromCRSInput) (*PreviewFromCRSResult, error) { + exported, err := s.fetchCRSExport(ctx, input.BaseURL, input.Username, input.Password) + if err != nil { + return nil, err + } + + // Batch query all existing CRS account IDs + existingCRSIDs, err := s.accountRepo.ListCRSAccountIDs(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list existing CRS accounts: %w", err) + } + + result := &PreviewFromCRSResult{ + NewAccounts: make([]CRSPreviewAccount, 0), + ExistingAccounts: make([]CRSPreviewAccount, 0), + } + + classify := func(crsID, kind, name, platform, accountType string) { + preview := CRSPreviewAccount{ + CRSAccountID: crsID, + Kind: kind, + Name: defaultName(name, crsID), + Platform: platform, + Type: accountType, + } + if _, exists := existingCRSIDs[crsID]; exists { + result.ExistingAccounts = append(result.ExistingAccounts, preview) + } else { + result.NewAccounts = append(result.NewAccounts, preview) + } + } + + for _, src := range exported.Data.ClaudeAccounts { + authType := strings.TrimSpace(src.AuthType) + if authType == "" { + authType = AccountTypeOAuth + } + classify(src.ID, src.Kind, src.Name, PlatformAnthropic, authType) + } + for _, src := range exported.Data.ClaudeConsoleAccounts { + classify(src.ID, src.Kind, src.Name, PlatformAnthropic, AccountTypeAPIKey) + } + for _, src := range exported.Data.OpenAIOAuthAccounts { + classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeOAuth) + } + for _, src := range exported.Data.OpenAIResponsesAccounts { + classify(src.ID, src.Kind, src.Name, PlatformOpenAI, AccountTypeAPIKey) + } + for _, src := range exported.Data.GeminiOAuthAccounts { + classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeOAuth) + } + for _, src := range exported.Data.GeminiAPIKeyAccounts { + classify(src.ID, src.Kind, src.Name, PlatformGemini, AccountTypeAPIKey) + } + + return result, nil +} diff --git a/backend/internal/service/digest_session_store.go b/backend/internal/service/digest_session_store.go new file mode 100644 index 00000000..3ac08936 --- /dev/null +++ b/backend/internal/service/digest_session_store.go @@ -0,0 +1,69 @@ +package service + +import ( + "strconv" + "strings" + "time" + + gocache "github.com/patrickmn/go-cache" +) + +// digestSessionTTL 摘要会话默认 TTL +const digestSessionTTL = 5 * time.Minute + +// sessionEntry flat cache 条目 +type sessionEntry struct { + uuid string + accountID int64 +} + +// DigestSessionStore 内存摘要会话存储(flat cache 实现) +// key: "{groupID}:{prefixHash}|{digestChain}" → *sessionEntry +type DigestSessionStore struct { + cache *gocache.Cache +} + +// NewDigestSessionStore 创建内存摘要会话存储 +func NewDigestSessionStore() *DigestSessionStore { + return &DigestSessionStore{ + cache: gocache.New(digestSessionTTL, time.Minute), + } +} + +// Save 保存摘要会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *DigestSessionStore) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) { + if digestChain == "" { + return + } + ns := buildNS(groupID, prefixHash) + s.cache.Set(ns+digestChain, &sessionEntry{uuid: uuid, accountID: accountID}, gocache.DefaultExpiration) + if oldDigestChain != "" && oldDigestChain != digestChain { + s.cache.Delete(ns + oldDigestChain) + } +} + +// Find 查找摘要会话,从完整 chain 逐段截断,返回最长匹配及对应 matchedChain。 +func (s *DigestSessionStore) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" { + return "", 0, "", false + } + ns := buildNS(groupID, prefixHash) + chain := digestChain + for { + if val, ok := s.cache.Get(ns + chain); ok { + if e, ok := val.(*sessionEntry); ok { + return e.uuid, e.accountID, chain, true + } + } + i := strings.LastIndex(chain, "-") + if i < 0 { + return "", 0, "", false + } + chain = chain[:i] + } +} + +// buildNS 构建 namespace 前缀 +func buildNS(groupID int64, prefixHash string) string { + return strconv.FormatInt(groupID, 10) + ":" + prefixHash + "|" +} diff --git a/backend/internal/service/digest_session_store_test.go b/backend/internal/service/digest_session_store_test.go new file mode 100644 index 00000000..e505bf30 --- /dev/null +++ b/backend/internal/service/digest_session_store_test.go @@ -0,0 +1,312 @@ +//go:build unit + +package service + +import ( + "fmt" + "sync" + "testing" + "time" + + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDigestSessionStore_SaveAndFind(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "s:a1-u:b2-m:c3", "uuid-1", 100, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:a1-u:b2-m:c3") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_PrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + // 保存短链 + store.Save(1, "prefix", "u:a-m:b", "uuid-short", 10, "") + + // 用长链查找,应前缀匹配到短链 + uuid, accountID, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-short", uuid) + assert.Equal(t, int64(10), accountID) + assert.Equal(t, "u:a-m:b", matchedChain) +} + +func TestDigestSessionStore_LongestPrefixMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a", "uuid-1", 1, "") + store.Save(1, "prefix", "u:a-m:b", "uuid-2", 2, "") + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-3", 3, "") + + // 应匹配最深的 "u:a-m:b-u:c"(从完整 chain 逐段截断,先命中最长的) + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "uuid-3", uuid) + assert.Equal(t, int64(3), accountID) + + // 查找中等长度,应匹配到 "u:a-m:b" + uuid, accountID, _, found = store.Find(1, "prefix", "u:a-m:b-u:x") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_SaveDeletesOldChain(t *testing.T) { + store := NewDigestSessionStore() + + // 第一轮:保存 "u:a-m:b" + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 第二轮:同一 uuid 保存更长的链,传入旧 chain + store.Save(1, "prefix", "u:a-m:b-u:c-m:d", "uuid-1", 100, "u:a-m:b") + + // 旧链 "u:a-m:b" 应已被删除 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found, "old chain should be deleted") + + // 新链应能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b-u:c-m:d") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} + +func TestDigestSessionStore_DifferentSessionsNoInterference(t *testing.T) { + store := NewDigestSessionStore() + + // 相同系统提示词,不同用户提示词 + store.Save(1, "prefix", "s:sys-u:user1", "uuid-1", 100, "") + store.Save(1, "prefix", "s:sys-u:user2", "uuid-2", 200, "") + + uuid, accountID, _, found := store.Find(1, "prefix", "s:sys-u:user1-m:reply1") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) + + uuid, accountID, _, found = store.Find(1, "prefix", "s:sys-u:user2-m:reply2") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(200), accountID) +} + +func TestDigestSessionStore_NoMatch(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 完全不同的 chain + _, _, _, found := store.Find(1, "prefix", "u:x-m:y") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentPrefixHash(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix1", "u:a-m:b", "uuid-1", 100, "") + + // 不同 prefixHash 应隔离 + _, _, _, found := store.Find(1, "prefix2", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_DifferentGroupID(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 不同 groupID 应隔离 + _, _, _, found := store.Find(2, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_EmptyDigestChain(t *testing.T) { + store := NewDigestSessionStore() + + // 空链不应保存 + store.Save(1, "prefix", "", "uuid-1", 100, "") + _, _, _, found := store.Find(1, "prefix", "") + assert.False(t, found) +} + +func TestDigestSessionStore_TTLExpiration(t *testing.T) { + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 立即应该能找到 + _, _, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + + // 等待过期 + 清理周期 + time.Sleep(300 * time.Millisecond) + + // 过期后应找不到 + _, _, _, found = store.Find(1, "prefix", "u:a-m:b") + assert.False(t, found) +} + +func TestDigestSessionStore_ConcurrentSafety(t *testing.T) { + store := NewDigestSessionStore() + + var wg sync.WaitGroup + const goroutines = 50 + const operations = 100 + + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(id int) { + defer wg.Done() + prefix := fmt.Sprintf("prefix-%d", id%5) + for i := 0; i < operations; i++ { + chain := fmt.Sprintf("u:%d-m:%d", id, i) + uuid := fmt.Sprintf("uuid-%d-%d", id, i) + store.Save(1, prefix, chain, uuid, int64(id), "") + store.Find(1, prefix, chain) + } + }(g) + } + wg.Wait() +} + +func TestDigestSessionStore_MultipleSessions(t *testing.T) { + store := NewDigestSessionStore() + + sessions := []struct { + chain string + uuid string + accountID int64 + }{ + {"u:session1", "uuid-1", 1}, + {"u:session2-m:reply2", "uuid-2", 2}, + {"u:session3-m:reply3-u:msg3", "uuid-3", 3}, + } + + for _, sess := range sessions { + store.Save(1, "prefix", sess.chain, sess.uuid, sess.accountID, "") + } + + // 验证每个会话都能正确查找 + for _, sess := range sessions { + uuid, accountID, _, found := store.Find(1, "prefix", sess.chain) + require.True(t, found, "should find session: %s", sess.chain) + assert.Equal(t, sess.uuid, uuid) + assert.Equal(t, sess.accountID, accountID) + } + + // 验证继续对话的场景 + uuid, accountID, _, found := store.Find(1, "prefix", "u:session2-m:reply2-u:newmsg") + require.True(t, found) + assert.Equal(t, "uuid-2", uuid) + assert.Equal(t, int64(2), accountID) +} + +func TestDigestSessionStore_Performance1000Sessions(t *testing.T) { + store := NewDigestSessionStore() + + // 插入 1000 个会话 + for i := 0; i < 1000; i++ { + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d", i, i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + // 查找性能测试 + start := time.Now() + const lookups = 10000 + for i := 0; i < lookups; i++ { + idx := i % 1000 + chain := fmt.Sprintf("s:sys-u:user%d-m:reply%d-u:newmsg", idx, idx) + _, _, _, found := store.Find(1, "prefix", chain) + assert.True(t, found) + } + elapsed := time.Since(start) + t.Logf("%d lookups in %v (%.0f ns/op)", lookups, elapsed, float64(elapsed.Nanoseconds())/lookups) +} + +func TestDigestSessionStore_FindReturnsMatchedChain(t *testing.T) { + store := NewDigestSessionStore() + + store.Save(1, "prefix", "u:a-m:b-u:c", "uuid-1", 100, "") + + // 精确匹配 + _, _, matchedChain, found := store.Find(1, "prefix", "u:a-m:b-u:c") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) + + // 前缀匹配(截断后命中) + _, _, matchedChain, found = store.Find(1, "prefix", "u:a-m:b-u:c-m:d-u:e") + require.True(t, found) + assert.Equal(t, "u:a-m:b-u:c", matchedChain) +} + +func TestDigestSessionStore_CacheItemCountStable(t *testing.T) { + store := NewDigestSessionStore() + + // 模拟 100 个独立会话,每个进行 10 轮对话 + // 正确传递 oldDigestChain 时,每个会话始终只保留 1 个 key + for conv := 0; conv < 100; conv++ { + var prevMatchedChain string + for round := 0; round < 10; round++ { + chain := fmt.Sprintf("s:sys-u:user%d", conv) + for r := 0; r < round; r++ { + chain += fmt.Sprintf("-m:a%d-u:q%d", r, r+1) + } + uuid := fmt.Sprintf("uuid-conv%d", conv) + + _, _, matched, _ := store.Find(1, "prefix", chain) + store.Save(1, "prefix", chain, uuid, int64(conv), matched) + prevMatchedChain = matched + _ = prevMatchedChain + } + } + + // 100 个会话 × 1 key/会话 = 应该 ≤ 100 个 key + // 允许少量并发残留,但绝不能接近 100×10=1000 + itemCount := store.cache.ItemCount() + assert.LessOrEqual(t, itemCount, 100, "cache should have at most 100 items (1 per conversation), got %d", itemCount) + t.Logf("Cache item count after 100 conversations × 10 rounds: %d", itemCount) +} + +func TestDigestSessionStore_TTLPreventsUnboundedGrowth(t *testing.T) { + // 使用极短 TTL 验证大量写入后 cache 能被清理 + store := &DigestSessionStore{ + cache: gocache.New(100*time.Millisecond, 50*time.Millisecond), + } + + // 插入 500 个不同的 key(无 oldDigestChain,模拟最坏场景:全是新会话首轮) + for i := 0; i < 500; i++ { + chain := fmt.Sprintf("u:user%d", i) + store.Save(1, "prefix", chain, fmt.Sprintf("uuid-%d", i), int64(i), "") + } + + assert.Equal(t, 500, store.cache.ItemCount()) + + // 等待 TTL + 清理周期 + time.Sleep(300 * time.Millisecond) + + assert.Equal(t, 0, store.cache.ItemCount(), "all items should be expired and cleaned up") +} + +func TestDigestSessionStore_SaveSameChainNoDelete(t *testing.T) { + store := NewDigestSessionStore() + + // 保存 chain + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "") + + // 用户重发相同消息:oldDigestChain == digestChain,不应删掉刚设置的 key + store.Save(1, "prefix", "u:a-m:b", "uuid-1", 100, "u:a-m:b") + + // 仍然能找到 + uuid, accountID, _, found := store.Find(1, "prefix", "u:a-m:b") + require.True(t, found) + assert.Equal(t, "uuid-1", uuid) + assert.Equal(t, int64(100), accountID) +} diff --git a/backend/internal/service/error_policy_integration_test.go b/backend/internal/service/error_policy_integration_test.go new file mode 100644 index 00000000..9f8ad938 --- /dev/null +++ b/backend/internal/service/error_policy_integration_test.go @@ -0,0 +1,366 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mocks (scoped to this file by naming convention) +// --------------------------------------------------------------------------- + +// epFixedUpstream returns a fixed response for every request. +type epFixedUpstream struct { + statusCode int + body string + calls int +} + +func (u *epFixedUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.calls++ + return &http.Response{ + StatusCode: u.statusCode, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(u.body)), + }, nil +} + +func (u *epFixedUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +// epAccountRepo records SetTempUnschedulable / SetError calls. +type epAccountRepo struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int +} + +func (r *epAccountRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.tempCalls++ + return nil +} + +func (r *epAccountRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrCalls++ + return nil +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func saveAndSetBaseURLs(t *testing.T) { + t.Helper() + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) + oldAvail := antigravity.DefaultURLAvailability + antigravity.BaseURLs = []string{"https://ep-test.example"} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + t.Cleanup(func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvail + }) +} + +func newRetryParams(account *Account, upstream HTTPUpstream, handleError func(context.Context, string, *Account, int, http.Header, []byte, string, int64, string, bool) *handleModelRateLimitResult) antigravityRetryLoopParams { + return antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[ep-test]", + account: account, + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + httpUpstream: upstream, + requestedModel: "claude-sonnet-4-5", + handleError: handleError, + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_CustomErrorCodes +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_CustomErrorCodes(t *testing.T) { + tests := []struct { + name string + upstreamStatus int + upstreamBody string + customCodes []any + expectHandleError int + expectUpstream int + expectStatusCode int + }{ + { + name: "429_in_custom_codes_matched", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(429)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "429_not_in_custom_codes_skipped", + upstreamStatus: 429, + upstreamBody: `{"error":"rate limited"}`, + customCodes: []any{float64(500)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 429, + }, + { + name: "500_in_custom_codes_matched", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(500)}, + expectHandleError: 1, + expectUpstream: 1, + expectStatusCode: 500, + }, + { + name: "500_not_in_custom_codes_skipped", + upstreamStatus: 500, + upstreamBody: `{"error":"internal"}`, + customCodes: []any{float64(429)}, + expectHandleError: 0, + expectUpstream: 1, + expectStatusCode: 500, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: tt.upstreamStatus, body: tt.upstreamBody} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + account := &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": tt.customCodes, + }, + } + + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, tt.expectStatusCode, result.resp.StatusCode) + require.Equal(t, tt.expectHandleError, handleErrorCount, "handleError call count") + require.Equal(t, tt.expectUpstream, upstream.calls, "upstream call count") + }) + } +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_TempUnschedulable +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_TempUnschedulable(t *testing.T) { + tempRulesAccount := func(rules []any) *Account { + return &Account{ + ID: 200, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": rules, + }, + } + } + + overloadedRule := map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + } + + rateLimitRule := map[string]any{ + "error_code": float64(429), + "keywords": []any{"rate limited keyword"}, + "duration_minutes": float64(5), + } + + t.Run("503_overloaded_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `overloaded`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("429_rate_limited_keyword_matches_rule", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `rate limited keyword`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{rateLimitRule}) + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + t.Error("handleError should not be called for temp unschedulable") + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, err, &switchErr) + require.Equal(t, account.ID, switchErr.OriginalAccountID) + require.Equal(t, 1, upstream.calls, "should not retry") + }) + + t.Run("503_body_no_match_continues_default_retry", func(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 503, body: `random`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + account := tempRulesAccount([]any{overloadedRule}) + + // Use a short-lived context: the backoff sleep (~1s) will be + // interrupted, proving the code entered the default retry path + // instead of breaking early via error policy. + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + result, err := svc.antigravityRetryLoop(p) + + // Context cancellation during backoff proves default retry was entered + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1, "should have called upstream at least once") + }) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NilRateLimitService +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NilRateLimitService(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + // rateLimitService is nil — must not panic + svc := &AntigravityGatewayService{rateLimitService: nil} + + account := &Account{ + ID: 300, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + return nil + }) + p.ctx = ctx + + // Should not panic; enters the default retry path (eventually times out) + result, err := svc.antigravityRetryLoop(p) + + require.Nil(t, result) + require.ErrorIs(t, err, context.DeadlineExceeded) + require.GreaterOrEqual(t, upstream.calls, 1) +} + +// --------------------------------------------------------------------------- +// TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior +// --------------------------------------------------------------------------- + +func TestRetryLoop_ErrorPolicy_NoPolicy_OriginalBehavior(t *testing.T) { + saveAndSetBaseURLs(t) + + upstream := &epFixedUpstream{statusCode: 429, body: `{"error":"rate limited"}`} + repo := &epAccountRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{rateLimitService: rlSvc} + + // Plain OAuth account with no error policy configured + account := &Account{ + ID: 400, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCount int + p := newRetryParams(account, upstream, func(_ context.Context, _ string, _ *Account, _ int, _ http.Header, _ []byte, _ string, _ int64, _ string, _ bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }) + + result, err := svc.antigravityRetryLoop(p) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + + require.Equal(t, http.StatusTooManyRequests, result.resp.StatusCode) + require.Equal(t, antigravityMaxRetries, upstream.calls, "should exhaust all retries") + require.Equal(t, 1, handleErrorCount, "handleError should be called once after retries exhausted") +} diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go new file mode 100644 index 00000000..a8b69c22 --- /dev/null +++ b/backend/internal/service/error_policy_test.go @@ -0,0 +1,289 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy — 6 table-driven cases for the pure logic function +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "no_policy_oauth_returns_none", + account: &Account{ + ID: 1, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + // no custom error codes, no temp rules + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_hit_returns_matched", + account: &Account{ + ID: 2, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expected: ErrorPolicyMatched, + }, + { + name: "custom_error_codes_miss_returns_skipped", + account: &Account{ + ID: 3, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 503, + body: []byte(`"error"`), + expected: ErrorPolicySkipped, + }, + { + name: "temp_unschedulable_hit_returns_temp_unscheduled", + account: &Account{ + ID: 4, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "temp_unschedulable_body_miss_returns_none", + account: &Account{ + ID: 5, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`random msg`), + expected: ErrorPolicyNone, + }, + { + name: "custom_error_codes_override_temp_unschedulable", + account: &Account{ + ID: 6, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + "description": "overloaded rule", + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result, "unexpected ErrorPolicyResult") + }) + } +} + +// --------------------------------------------------------------------------- +// TestApplyErrorPolicy — 4 table-driven cases for the wrapper method +// --------------------------------------------------------------------------- + +func TestApplyErrorPolicy(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expectedHandled bool + expectedSwitchErr bool // expect *AntigravityAccountSwitchError + handleErrorCalls int + }{ + { + name: "none_not_handled", + account: &Account{ + ID: 10, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: false, + handleErrorCalls: 0, + }, + { + name: "skipped_handled_no_handleError", + account: &Account{ + ID: 11, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, // not in custom codes + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 0, + }, + { + name: "matched_handled_calls_handleError", + account: &Account{ + ID: 12, + Type: AccountTypeAPIKey, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(500)}, + }, + }, + statusCode: 500, + body: []byte(`"error"`), + expectedHandled: true, + handleErrorCalls: 1, + }, + { + name: "temp_unscheduled_returns_switch_error", + account: &Account{ + ID: 13, + Type: AccountTypeOAuth, + Platform: PlatformAntigravity, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expectedHandled: true, + expectedSwitchErr: true, + handleErrorCalls: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &AntigravityGatewayService{ + rateLimitService: rlSvc, + } + + var handleErrorCount int + p := antigravityRetryLoopParams{ + ctx: context.Background(), + prefix: "[test]", + account: tt.account, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult { + handleErrorCount++ + return nil + }, + isStickySession: true, + } + + handled, retErr := svc.applyErrorPolicy(p, tt.statusCode, http.Header{}, tt.body) + + require.Equal(t, tt.expectedHandled, handled, "handled mismatch") + require.Equal(t, tt.handleErrorCalls, handleErrorCount, "handleError call count mismatch") + + if tt.expectedSwitchErr { + var switchErr *AntigravityAccountSwitchError + require.ErrorAs(t, retErr, &switchErr) + require.Equal(t, tt.account.ID, switchErr.OriginalAccountID) + } else { + require.NoError(t, retErr) + } + }) + } +} + +// --------------------------------------------------------------------------- +// errorPolicyRepoStub — minimal AccountRepository stub for error policy tests +// --------------------------------------------------------------------------- + +type errorPolicyRepoStub struct { + mockAccountRepoForGemini + tempCalls int + setErrCalls int + lastErrorMsg string +} + +func (r *errorPolicyRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { + r.tempCalls++ + return nil +} + +func (r *errorPolicyRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { + r.setErrCalls++ + r.lastErrorMsg = errorMsg + return nil +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 8551e7d2..c7104fde 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -77,7 +77,12 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } -func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { + +func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + +func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } func (m *mockAccountRepoForPlatform) Update(ctx context.Context, account *Account) error { @@ -145,9 +150,6 @@ func (m *mockAccountRepoForPlatform) ListSchedulableByGroupIDAndPlatforms(ctx co func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -219,22 +221,6 @@ func (m *mockGatewayCacheForPlatform) DeleteSessionAccountID(ctx context.Context return nil } -func (m *mockGatewayCacheForPlatform) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForPlatform) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForPlatform) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForPlatform) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - type mockGroupRepoForGateway struct { groups map[int64]*Group getByIDCalls int @@ -293,6 +279,10 @@ func (m *mockGroupRepoForGateway) GetAccountIDsByGroupIDs(ctx context.Context, g return nil, nil } +func (m *mockGroupRepoForGateway) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + func ptr[T any](v T) *T { return &v } diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 0ecd18aa..c039f030 100644 --- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -6,9 +6,19 @@ import ( "fmt" "math" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) +// SessionContext 粘性会话上下文,用于区分不同来源的请求。 +// 仅在 GenerateSessionHash 第 3 级 fallback(消息内容 hash)时混入, +// 避免不同用户发送相同消息产生相同 hash 导致账号集中。 +type SessionContext struct { + ClientIP string + UserAgent string + APIKeyID int64 +} + // ParsedRequest 保存网关请求的预解析结果 // // 性能优化说明: @@ -22,20 +32,22 @@ import ( // 2. 将解析结果 ParsedRequest 传递给 Service 层 // 3. 避免重复 json.Unmarshal,减少 CPU 和内存开销 type ParsedRequest struct { - Body []byte // 原始请求体(保留用于转发) - Model string // 请求的模型名称 - Stream bool // 是否为流式请求 - MetadataUserID string // metadata.user_id(用于会话亲和) - System any // system 字段内容 - Messages []any // messages 数组 - HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) - ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) - MaxTokens int // max_tokens 值(用于探测请求拦截) + Body []byte // 原始请求体(保留用于转发) + Model string // 请求的模型名称 + Stream bool // 是否为流式请求 + MetadataUserID string // metadata.user_id(用于会话亲和) + System any // system 字段内容 + Messages []any // messages 数组 + HasSystem bool // 是否包含 system 字段(包含 null 也视为显式传入) + ThinkingEnabled bool // 是否开启 thinking(部分平台会影响最终模型名) + MaxTokens int // max_tokens 值(用于探测请求拦截) + SessionContext *SessionContext // 可选:请求上下文区分因子(nil 时行为不变) } -// ParseGatewayRequest 解析网关请求体并返回结构化结果 -// 性能优化:一次解析提取所有需要的字段,避免重复 Unmarshal -func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { +// ParseGatewayRequest 解析网关请求体并返回结构化结果。 +// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), +// 不同协议使用不同的 system/messages 字段名。 +func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { var req map[string]any if err := json.Unmarshal(body, &req); err != nil { return nil, err @@ -64,14 +76,29 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) { parsed.MetadataUserID = userID } } - // system 字段只要存在就视为显式提供(即使为 null), - // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { - parsed.HasSystem = true - parsed.System = system - } - if messages, ok := req["messages"].([]any); ok { - parsed.Messages = messages + + switch protocol { + case domain.PlatformGemini: + // Gemini 原生格式: systemInstruction.parts / contents + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + // Anthropic / OpenAI 格式: system / messages + // system 字段只要存在就视为显式提供(即使为 null), + // 以避免客户端传 null 时被默认 system 误注入。 + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } } // thinking: {type: "enabled"} diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 4e390b0a..cef41c91 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -4,12 +4,13 @@ import ( "encoding/json" "testing" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/stretchr/testify/require" ) func TestParseGatewayRequest(t *testing.T) { body := []byte(`{"model":"claude-3-7-sonnet","stream":true,"metadata":{"user_id":"session_123e4567-e89b-12d3-a456-426614174000"},"system":[{"type":"text","text":"hello","cache_control":{"type":"ephemeral"}}],"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-3-7-sonnet", parsed.Model) require.True(t, parsed.Stream) @@ -22,7 +23,7 @@ func TestParseGatewayRequest(t *testing.T) { func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { body := []byte(`{"model":"claude-sonnet-4-5","thinking":{"type":"enabled"},"messages":[{"content":"hi"}]}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, "claude-sonnet-4-5", parsed.Model) require.True(t, parsed.ThinkingEnabled) @@ -30,21 +31,21 @@ func TestParseGatewayRequest_ThinkingEnabled(t *testing.T) { func TestParseGatewayRequest_MaxTokens(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 1, parsed.MaxTokens) } func TestParseGatewayRequest_MaxTokensNonIntegralIgnored(t *testing.T) { body := []byte(`{"model":"claude-haiku-4-5","max_tokens":1.5}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) require.Equal(t, 0, parsed.MaxTokens) } func TestParseGatewayRequest_SystemNull(t *testing.T) { body := []byte(`{"model":"claude-3","system":null}`) - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") require.NoError(t, err) // 显式传入 system:null 也应视为“字段已存在”,避免默认 system 被注入。 require.True(t, parsed.HasSystem) @@ -53,16 +54,112 @@ func TestParseGatewayRequest_SystemNull(t *testing.T) { func TestParseGatewayRequest_InvalidModelType(t *testing.T) { body := []byte(`{"model":123}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } func TestParseGatewayRequest_InvalidStreamType(t *testing.T) { body := []byte(`{"stream":"true"}`) - _, err := ParseGatewayRequest(body) + _, err := ParseGatewayRequest(body, "") require.Error(t, err) } +// ============ Gemini 原生格式解析测试 ============ + +func TestParseGatewayRequest_GeminiContents(t *testing.T) { + body := []byte(`{ + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]}, + {"role": "model", "parts": [{"text": "Hi there"}]}, + {"role": "user", "parts": [{"text": "How are you?"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Len(t, parsed.Messages, 3, "should parse contents as Messages") + require.False(t, parsed.HasSystem, "Gemini format should not set HasSystem") + require.Nil(t, parsed.System, "no systemInstruction means nil System") +} + +func TestParseGatewayRequest_GeminiSystemInstruction(t *testing.T) { + body := []byte(`{ + "systemInstruction": { + "parts": [{"text": "You are a helpful assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Hello"}]} + ] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.NotNil(t, parsed.System, "should parse systemInstruction.parts as System") + parts, ok := parsed.System.([]any) + require.True(t, ok) + require.Len(t, parts, 1) + partMap, ok := parts[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "You are a helpful assistant.", partMap["text"]) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiWithModel(t *testing.T) { + body := []byte(`{ + "model": "gemini-2.5-pro", + "contents": [{"role": "user", "parts": [{"text": "test"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Equal(t, "gemini-2.5-pro", parsed.Model) + require.Len(t, parsed.Messages, 1) +} + +func TestParseGatewayRequest_GeminiIgnoresAnthropicFields(t *testing.T) { + // Gemini 格式下 system/messages 字段应被忽略 + body := []byte(`{ + "system": "should be ignored", + "messages": [{"role": "user", "content": "ignored"}], + "contents": [{"role": "user", "parts": [{"text": "real content"}]}] + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.False(t, parsed.HasSystem, "Gemini protocol should not parse Anthropic system field") + require.Nil(t, parsed.System, "no systemInstruction = nil System") + require.Len(t, parsed.Messages, 1, "should use contents, not messages") +} + +func TestParseGatewayRequest_GeminiEmptyContents(t *testing.T) { + body := []byte(`{"contents": []}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Empty(t, parsed.Messages) +} + +func TestParseGatewayRequest_GeminiNoContents(t *testing.T) { + body := []byte(`{"model": "gemini-2.5-flash"}`) + parsed, err := ParseGatewayRequest(body, domain.PlatformGemini) + require.NoError(t, err) + require.Nil(t, parsed.Messages) + require.Equal(t, "gemini-2.5-flash", parsed.Model) +} + +func TestParseGatewayRequest_AnthropicIgnoresGeminiFields(t *testing.T) { + // Anthropic 格式下 contents/systemInstruction 字段应被忽略 + body := []byte(`{ + "system": "real system", + "messages": [{"role": "user", "content": "real content"}], + "contents": [{"role": "user", "parts": [{"text": "ignored"}]}], + "systemInstruction": {"parts": [{"text": "ignored"}]} + }`) + parsed, err := ParseGatewayRequest(body, domain.PlatformAnthropic) + require.NoError(t, err) + require.True(t, parsed.HasSystem) + require.Equal(t, "real system", parsed.System) + require.Len(t, parsed.Messages, 1) + msg, ok := parsed.Messages[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "real content", msg["content"]) +} + func TestFilterThinkingBlocks(t *testing.T) { containsThinkingBlock := func(body []byte) bool { var req map[string]any diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 5df5ecba..040745a8 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/sha256" - "encoding/hex" "encoding/json" "errors" "fmt" @@ -17,6 +16,7 @@ import ( "os" "regexp" "sort" + "strconv" "strings" "sync/atomic" "time" @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "github.com/cespare/xxhash/v2" "github.com/google/uuid" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -245,9 +246,6 @@ var ( // ErrClaudeCodeOnly 表示分组仅允许 Claude Code 客户端访问 var ErrClaudeCodeOnly = errors.New("this group only allows Claude Code clients") -// ErrModelScopeNotSupported 表示请求的模型系列不在分组支持的范围内 -var ErrModelScopeNotSupported = errors.New("model scope not supported by this group") - // allowedHeaders 白名单headers(参考CRS项目) var allowedHeaders = map[string]bool{ "accept": true, @@ -273,13 +271,6 @@ var allowedHeaders = map[string]bool{ // GatewayCache 定义网关服务的缓存操作接口。 // 提供粘性会话(Sticky Session)的存储、查询、刷新和删除功能。 // -// ModelLoadInfo 模型负载信息(用于 Antigravity 调度) -// Model load info for Antigravity scheduling -type ModelLoadInfo struct { - CallCount int64 // 当前分钟调用次数 / Call count in current minute - LastUsedAt time.Time // 最后调度时间(零值表示未调度过)/ Last scheduling time (zero means never scheduled) -} - // GatewayCache defines cache operations for gateway service. // Provides sticky session storage, retrieval, refresh and deletion capabilities. type GatewayCache interface { @@ -295,24 +286,6 @@ type GatewayCache interface { // DeleteSessionAccountID 删除粘性会话绑定,用于账号不可用时主动清理 // Delete sticky session binding, used to proactively clean up when account becomes unavailable DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error - - // IncrModelCallCount 增加模型调用次数并更新最后调度时间(Antigravity 专用) - // Increment model call count and update last scheduling time (Antigravity only) - // 返回更新后的调用次数 - IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) - - // GetModelLoadBatch 批量获取账号的模型负载信息(Antigravity 专用) - // Batch get model load info for accounts (Antigravity only) - GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) - - // FindGeminiSession 查找 Gemini 会话(MGET 倒序匹配) - // Find Gemini session using MGET reverse order matching - // 返回最长匹配的会话信息(uuid, accountID) - FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) - - // SaveGeminiSession 保存 Gemini 会话 - // Save Gemini session binding - SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error } // derefGroupID safely dereferences *int64 to int64, returning 0 if nil @@ -323,21 +296,15 @@ func derefGroupID(groupID *int64) int64 { return *groupID } -// stickySessionRateLimitThreshold 定义清除粘性会话的限流时间阈值。 -// 当账号限流剩余时间超过此阈值时,清除粘性会话以便切换到其他账号。 -// 低于此阈值时保持粘性会话,等待短暂限流结束。 -const stickySessionRateLimitThreshold = 10 * time.Second - // shouldClearStickySession 检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, -// 或模型限流剩余时间超过 stickySessionRateLimitThreshold 时,返回 true。 +// 或请求的模型处于限流状态时,返回 true。 // 这确保后续请求不会继续使用不可用的账号。 // // shouldClearStickySession checks if an account is in an unschedulable state // and the sticky session binding should be cleared. // Returns true when account status is error/disabled, schedulable is false, -// within temporary unschedulable period, or model rate limit remaining time -// exceeds stickySessionRateLimitThreshold. +// within temporary unschedulable period, or the requested model is rate-limited. // This ensures subsequent requests won't continue using unavailable accounts. func shouldClearStickySession(account *Account, requestedModel string) bool { if account == nil { @@ -349,8 +316,8 @@ func shouldClearStickySession(account *Account, requestedModel string) bool { if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { return true } - // 检查模型限流和 scope 限流,只在超过阈值时清除粘性会话 - if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > stickySessionRateLimitThreshold { + // 检查模型限流和 scope 限流,有限流即清除粘性会话 + if remaining := account.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel); remaining > 0 { return true } return false @@ -417,6 +384,7 @@ type GatewayService struct { userSubRepo UserSubscriptionRepository userGroupRateRepo UserGroupRateRepository cache GatewayCache + digestStore *DigestSessionStore cfg *config.Config schedulerSnapshot *SchedulerSnapshotService billingService *BillingService @@ -450,6 +418,7 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, + digestStore *DigestSessionStore, ) *GatewayService { return &GatewayService{ accountRepo: accountRepo, @@ -459,6 +428,7 @@ func NewGatewayService( userSubRepo: userSubRepo, userGroupRateRepo: userGroupRateRepo, cache: cache, + digestStore: digestStore, cfg: cfg, schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, @@ -492,23 +462,45 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { return s.hashContent(cacheableContent) } - // 3. Fallback: 使用 system 内容 + // 3. 最后 fallback: 使用 session上下文 + system + 所有消息的完整摘要串 + var combined strings.Builder + // 混入请求上下文区分因子,避免不同用户相同消息产生相同 hash + if parsed.SessionContext != nil { + _, _ = combined.WriteString(parsed.SessionContext.ClientIP) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(parsed.SessionContext.UserAgent) + _, _ = combined.WriteString(":") + _, _ = combined.WriteString(strconv.FormatInt(parsed.SessionContext.APIKeyID, 10)) + _, _ = combined.WriteString("|") + } if parsed.System != nil { systemText := s.extractTextFromSystem(parsed.System) if systemText != "" { - return s.hashContent(systemText) + _, _ = combined.WriteString(systemText) } } - - // 4. 最后 fallback: 使用第一条消息 - if len(parsed.Messages) > 0 { - if firstMsg, ok := parsed.Messages[0].(map[string]any); ok { - msgText := s.extractTextFromContent(firstMsg["content"]) - if msgText != "" { - return s.hashContent(msgText) + for _, msg := range parsed.Messages { + if m, ok := msg.(map[string]any); ok { + if content, exists := m["content"]; exists { + // Anthropic: messages[].content + if msgText := s.extractTextFromContent(content); msgText != "" { + _, _ = combined.WriteString(msgText) + } + } else if parts, ok := m["parts"].([]any); ok { + // Gemini: contents[].parts[].text + for _, part := range parts { + if partMap, ok := part.(map[string]any); ok { + if text, ok := partMap["text"].(string); ok { + _, _ = combined.WriteString(text) + } + } + } } } } + if combined.Len() > 0 { + return s.hashContent(combined.String()) + } return "" } @@ -536,19 +528,37 @@ func (s *GatewayService) GetCachedSessionAccountID(ctx context.Context, groupID // FindGeminiSession 查找 Gemini 会话(基于内容摘要链的 Fallback 匹配) // 返回最长匹配的会话信息(uuid, accountID) -func (s *GatewayService) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - if digestChain == "" || s.cache == nil { - return "", 0, false +func (s *GatewayService) FindGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false } - return s.cache.FindGeminiSession(ctx, groupID, prefixHash, digestChain) + return s.digestStore.Find(groupID, prefixHash, digestChain) } -// SaveGeminiSession 保存 Gemini 会话 -func (s *GatewayService) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - if digestChain == "" || s.cache == nil { +// SaveGeminiSession 保存 Gemini 会话。oldDigestChain 为 Find 返回的 matchedChain,用于删旧 key。 +func (s *GatewayService) SaveGeminiSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { return nil } - return s.cache.SaveGeminiSession(ctx, groupID, prefixHash, digestChain, uuid, accountID) + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil +} + +// FindAnthropicSession 查找 Anthropic 会话(基于内容摘要链的 Fallback 匹配) +func (s *GatewayService) FindAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, matchedChain string, found bool) { + if digestChain == "" || s.digestStore == nil { + return "", 0, "", false + } + return s.digestStore.Find(groupID, prefixHash, digestChain) +} + +// SaveAnthropicSession 保存 Anthropic 会话 +func (s *GatewayService) SaveAnthropicSession(_ context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64, oldDigestChain string) error { + if digestChain == "" || s.digestStore == nil { + return nil + } + s.digestStore.Save(groupID, prefixHash, digestChain, uuid, accountID, oldDigestChain) + return nil } func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string { @@ -633,8 +643,8 @@ func (s *GatewayService) extractTextFromContent(content any) string { } func (s *GatewayService) hashContent(content string) string { - hash := sha256.Sum256([]byte(content)) - return hex.EncodeToString(hash[:16]) // 32字符 + h := xxhash.Sum64String(content) + return strconv.FormatUint(h, 36) } // replaceModelInBody 替换请求体中的model字段 @@ -993,13 +1003,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } - // Antigravity 模型系列检查(在账号选择前检查,确保所有代码路径都经过此检查) - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } - } - accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err != nil { return nil, err @@ -1114,7 +1117,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } @@ -1194,6 +1196,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(routingAvailable) // 4. 尝试获取槽位 for _, item := range routingAvailable { @@ -1268,7 +1271,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -1348,10 +1350,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return result, nil } } else { - // Antigravity 平台:获取模型负载信息 - var modelLoadMap map[int64]*ModelLoadInfo - isAntigravity := platform == PlatformAntigravity - var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -1366,109 +1364,44 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // Antigravity 平台:按账号实际映射后的模型名获取模型负载(与 Forward 的统计保持一致) - if isAntigravity && requestedModel != "" && s.cache != nil && len(available) > 0 { - modelLoadMap = make(map[int64]*ModelLoadInfo, len(available)) - modelToAccountIDs := make(map[string][]int64) - for _, item := range available { - mappedModel := mapAntigravityModel(item.account, requestedModel) - if mappedModel == "" { - continue - } - modelToAccountIDs[mappedModel] = append(modelToAccountIDs[mappedModel], item.account.ID) + // 分层过滤选择:优先级 → 负载率 → LRU + for len(available) > 0 { + // 1. 取优先级最小的集合 + candidates := filterByMinPriority(available) + // 2. 取负载率最低的集合 + candidates = filterByMinLoadRate(candidates) + // 3. LRU 选择最久未用的账号 + selected := selectByLRU(candidates, preferOAuth) + if selected == nil { + break } - for model, ids := range modelToAccountIDs { - batch, err := s.cache.GetModelLoadBatch(ctx, ids, model) - if err != nil { - continue - } - for id, info := range batch { - modelLoadMap[id] = info - } - } - if len(modelLoadMap) == 0 { - modelLoadMap = nil - } - } - // Antigravity 平台:优先级硬过滤 →(同优先级内)按调用次数选择(最少优先,新账号用平均值) - // 其他平台:分层过滤选择:优先级 → 负载率 → LRU - if isAntigravity { - for len(available) > 0 { - // 1. 取优先级最小的集合(硬过滤) - candidates := filterByMinPriority(available) - // 2. 同优先级内按调用次数选择(调用次数最少优先,新账号使用平均值) - selected := selectByCallCount(candidates, modelLoadMap, preferOAuth) - if selected == nil { - break - } - - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil + result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) + if err == nil && result.Acquired { + // 会话数量限制检查 + if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { + result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 + } else { + if sessionHash != "" && s.cache != nil { + _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) } + return &AccountSelectionResult{ + Account: selected.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil } - - // 移除已尝试的账号,重新选择 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } - } else { - for len(available) > 0 { - // 1. 取优先级最小的集合 - candidates := filterByMinPriority(available) - // 2. 取负载率最低的集合 - candidates = filterByMinLoadRate(candidates) - // 3. LRU 选择最久未用的账号 - selected := selectByLRU(candidates, preferOAuth) - if selected == nil { - break - } - result, err := s.tryAcquireAccountSlot(ctx, selected.account.ID, selected.account.Concurrency) - if err == nil && result.Acquired { - // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, selected.account, sessionHash) { - result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 - } else { - if sessionHash != "" && s.cache != nil { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL) - } - return &AccountSelectionResult{ - Account: selected.account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } + // 移除已尝试的账号,重新进行分层过滤 + selectedID := selected.account.ID + newAvailable := make([]accountWithLoad, 0, len(available)-1) + for _, acc := range available { + if acc.account.ID != selectedID { + newAvailable = append(newAvailable, acc) } - - // 移除已尝试的账号,重新进行分层过滤 - selectedID := selected.account.ID - newAvailable := make([]accountWithLoad, 0, len(available)-1) - for _, acc := range available { - if acc.account.ID != selectedID { - newAvailable = append(newAvailable, acc) - } - } - available = newAvailable } + available = newAvailable } } @@ -2004,87 +1937,79 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) + shuffleWithinPriorityAndLastUsed(accounts) } -// selectByCallCount 从候选账号中选择调用次数最少的账号(Antigravity 专用) -// 新账号(CallCount=0)使用平均调用次数作为虚拟值,避免冷启动被猛调 -// 如果有多个账号具有相同的最小调用次数,则随机选择一个 -func selectByCallCount(accounts []accountWithLoad, modelLoadMap map[int64]*ModelLoadInfo, preferOAuth bool) *accountWithLoad { - if len(accounts) == 0 { - return nil +// shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 +// 防止并发请求读取同一快照时,确定性排序导致所有请求命中相同账号。 +func shuffleWithinSortGroups(accounts []accountWithLoad) { + if len(accounts) <= 1 { + return } - if len(accounts) == 1 { - return &accounts[0] - } - - // 如果没有负载信息,回退到 LRU - if modelLoadMap == nil { - return selectByLRU(accounts, preferOAuth) - } - - // 1. 计算平均调用次数(用于新账号冷启动) - var totalCallCount int64 - var countWithCalls int - for _, acc := range accounts { - if info := modelLoadMap[acc.account.ID]; info != nil && info.CallCount > 0 { - totalCallCount += info.CallCount - countWithCalls++ + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountWithLoadGroup(accounts[i], accounts[j]) { + j++ } - } - - var avgCallCount int64 - if countWithCalls > 0 { - avgCallCount = totalCallCount / int64(countWithCalls) - } - - // 2. 获取每个账号的有效调用次数 - getEffectiveCallCount := func(acc accountWithLoad) int64 { - if acc.account == nil { - return 0 + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } - info := modelLoadMap[acc.account.ID] - if info == nil || info.CallCount == 0 { - return avgCallCount // 新账号使用平均值 - } - return info.CallCount + i = j } +} - // 3. 找到最小调用次数 - minCount := getEffectiveCallCount(accounts[0]) - for _, acc := range accounts[1:] { - if c := getEffectiveCallCount(acc); c < minCount { - minCount = c - } +// sameAccountWithLoadGroup 判断两个 accountWithLoad 是否属于同一排序组 +func sameAccountWithLoadGroup(a, b accountWithLoad) bool { + if a.account.Priority != b.account.Priority { + return false } - - // 4. 收集所有具有最小调用次数的账号 - var candidateIdxs []int - for i, acc := range accounts { - if getEffectiveCallCount(acc) == minCount { - candidateIdxs = append(candidateIdxs, i) - } + if a.loadInfo.LoadRate != b.loadInfo.LoadRate { + return false } + return sameLastUsedAt(a.account.LastUsedAt, b.account.LastUsedAt) +} - // 5. 如果只有一个候选,直接返回 - if len(candidateIdxs) == 1 { - return &accounts[candidateIdxs[0]] +// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account) { + if len(accounts) <= 1 { + return } - - // 6. preferOAuth 处理 - if preferOAuth { - var oauthIdxs []int - for _, idx := range candidateIdxs { - if accounts[idx].account.Type == AccountTypeOAuth { - oauthIdxs = append(oauthIdxs, idx) - } + i := 0 + for i < len(accounts) { + j := i + 1 + for j < len(accounts) && sameAccountGroup(accounts[i], accounts[j]) { + j++ } - if len(oauthIdxs) > 0 { - candidateIdxs = oauthIdxs + if j-i > 1 { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) } + i = j } +} - // 7. 随机选择 - return &accounts[candidateIdxs[mathrand.Intn(len(candidateIdxs))]] +// sameAccountGroup 判断两个 Account 是否属于同一排序组(Priority + LastUsedAt) +func sameAccountGroup(a, b *Account) bool { + if a.Priority != b.Priority { + return false + } + return sameLastUsedAt(a.LastUsedAt, b.LastUsedAt) +} + +// sameLastUsedAt 判断两个 LastUsedAt 是否相同(精度到秒) +func sameLastUsedAt(a, b *time.Time) bool { + switch { + case a == nil && b == nil: + return true + case a == nil || b == nil: + return false + default: + return a.Unix() == b.Unix() + } } // sortCandidatesForFallback 根据配置选择排序策略 @@ -2139,13 +2064,6 @@ func shuffleWithinPriority(accounts []*Account) { // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { - // 对 Antigravity 平台,检查请求的模型系列是否在分组支持范围内 - if platform == PlatformAntigravity && groupID != nil && requestedModel != "" { - if err := s.checkAntigravityModelScope(ctx, *groupID, requestedModel); err != nil { - return nil, err - } - } - preferOAuth := platform == PlatformGemini routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform) @@ -2173,9 +2091,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2276,9 +2191,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } return account, nil } } @@ -2387,9 +2299,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } if s.debugModelRoutingEnabled() { log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2492,9 +2401,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { - if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil { - log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err) - } return account, nil } } @@ -5185,27 +5091,6 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { return normalized, nil } -// checkAntigravityModelScope 检查 Antigravity 平台的模型系列是否在分组支持范围内 -func (s *GatewayService) checkAntigravityModelScope(ctx context.Context, groupID int64, requestedModel string) error { - scope, ok := ResolveAntigravityQuotaScope(requestedModel) - if !ok { - return nil // 无法解析 scope,跳过检查 - } - - group, err := s.resolveGroupByID(ctx, groupID) - if err != nil { - return nil // 查询失败时放行 - } - if group == nil { - return nil // 分组不存在时放行 - } - - if !IsScopeSupported(group.SupportedModelScopes, scope) { - return ErrModelScopeNotSupported - } - return nil -} - // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { diff --git a/backend/internal/service/gateway_service_benchmark_test.go b/backend/internal/service/gateway_service_benchmark_test.go index f15a85d6..c9c4d3dd 100644 --- a/backend/internal/service/gateway_service_benchmark_test.go +++ b/backend/internal/service/gateway_service_benchmark_test.go @@ -14,7 +14,7 @@ func BenchmarkGenerateSessionHash_Metadata(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - parsed, err := ParseGatewayRequest(body) + parsed, err := ParseGatewayRequest(body, "") if err != nil { b.Fatalf("解析请求失败: %v", err) } diff --git a/backend/internal/service/gemini_error_policy_test.go b/backend/internal/service/gemini_error_policy_test.go new file mode 100644 index 00000000..2ce8793a --- /dev/null +++ b/backend/internal/service/gemini_error_policy_test.go @@ -0,0 +1,384 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// TestShouldFailoverGeminiUpstreamError — verifies the failover decision +// for the ErrorPolicyNone path (original logic preserved). +// --------------------------------------------------------------------------- + +func TestShouldFailoverGeminiUpstreamError(t *testing.T) { + svc := &GeminiMessagesCompatService{} + + tests := []struct { + name string + statusCode int + expected bool + }{ + {"401_failover", 401, true}, + {"403_failover", 403, true}, + {"429_failover", 429, true}, + {"529_failover", 529, true}, + {"500_failover", 500, true}, + {"502_failover", 502, true}, + {"503_failover", 503, true}, + {"400_no_failover", 400, false}, + {"404_no_failover", 404, false}, + {"422_no_failover", 422, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.shouldFailoverGeminiUpstreamError(tt.statusCode) + require.Equal(t, tt.expected, got) + }) + } +} + +// --------------------------------------------------------------------------- +// TestCheckErrorPolicy_GeminiAccounts — verifies CheckErrorPolicy works +// correctly for Gemini platform accounts (API Key type). +// --------------------------------------------------------------------------- + +func TestCheckErrorPolicy_GeminiAccounts(t *testing.T) { + tests := []struct { + name string + account *Account + statusCode int + body []byte + expected ErrorPolicyResult + }{ + { + name: "gemini_apikey_custom_codes_hit", + account: &Account{ + ID: 100, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429), float64(500)}, + }, + }, + statusCode: 429, + body: []byte(`{"error":"rate limited"}`), + expected: ErrorPolicyMatched, + }, + { + name: "gemini_apikey_custom_codes_miss", + account: &Account{ + ID: 101, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicySkipped, + }, + { + name: "gemini_apikey_no_custom_codes_returns_none", + account: &Account{ + ID: 102, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 500, + body: []byte(`{"error":"internal"}`), + expected: ErrorPolicyNone, + }, + { + name: "gemini_apikey_temp_unschedulable_hit", + account: &Account{ + ID: 103, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded service`), + expected: ErrorPolicyTempUnscheduled, + }, + { + name: "gemini_custom_codes_override_temp_unschedulable", + account: &Account{ + ID: 104, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(503)}, + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + body: []byte(`overloaded`), + expected: ErrorPolicyMatched, // custom codes take precedence + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &errorPolicyRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + result := svc.CheckErrorPolicy(context.Background(), tt.account, tt.statusCode, tt.body) + require.Equal(t, tt.expected, result) + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicyIntegration — verifies the Gemini error handling +// paths produce the correct behavior for each ErrorPolicyResult. +// +// These tests simulate the inline error policy switch in handleClaudeCompat +// and forwardNativeGemini by calling the same methods in the same order. +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicyIntegration(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + account *Account + statusCode int + respBody []byte + expectFailover bool // expect UpstreamFailoverError + expectHandleError bool // expect handleGeminiUpstreamError to be called + expectShouldFailover bool // for None path, whether shouldFailover triggers + }{ + { + name: "custom_codes_matched_429_failover", + account: &Account{ + ID: 200, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "custom_codes_skipped_500_no_failover", + account: &Account{ + ID: 201, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + }, + statusCode: 500, + respBody: []byte(`{"error":"internal"}`), + expectFailover: false, + expectHandleError: false, + }, + { + name: "temp_unschedulable_matched_failover", + account: &Account{ + ID: 202, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "temp_unschedulable_enabled": true, + "temp_unschedulable_rules": []any{ + map[string]any{ + "error_code": float64(503), + "keywords": []any{"overloaded"}, + "duration_minutes": float64(10), + }, + }, + }, + }, + statusCode: 503, + respBody: []byte(`overloaded`), + expectFailover: true, + expectHandleError: true, + }, + { + name: "no_policy_429_failover_via_shouldFailover", + account: &Account{ + ID: 203, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 429, + respBody: []byte(`{"error":"rate limited"}`), + expectFailover: true, + expectHandleError: true, + expectShouldFailover: true, + }, + { + name: "no_policy_400_no_failover", + account: &Account{ + ID: 204, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + }, + statusCode: 400, + respBody: []byte(`{"error":"bad request"}`), + expectFailover: false, + expectHandleError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + repo := &geminiErrorPolicyRepo{} + rlSvc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + svc := &GeminiMessagesCompatService{ + accountRepo: repo, + rateLimitService: rlSvc, + } + + writer := httptest.NewRecorder() + c, _ := gin.CreateTestContext(writer) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // Simulate the Claude compat error handling path (same logic as native). + // This mirrors the inline switch in handleClaudeCompat. + var handleErrorCalled bool + var gotFailover bool + + ctx := context.Background() + statusCode := tt.statusCode + respBody := tt.respBody + account := tt.account + headers := http.Header{} + + if svc.rateLimitService != nil { + switch svc.rateLimitService.CheckErrorPolicy(ctx, account, statusCode, respBody) { + case ErrorPolicySkipped: + // Skipped → return error directly (no handleGeminiUpstreamError, no failover) + gotFailover = false + handleErrorCalled = false + goto verify + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + gotFailover = true + goto verify + } + } + + // ErrorPolicyNone → original logic + svc.handleGeminiUpstreamError(ctx, account, statusCode, headers, respBody) + handleErrorCalled = true + if svc.shouldFailoverGeminiUpstreamError(statusCode) { + gotFailover = true + } + + verify: + require.Equal(t, tt.expectFailover, gotFailover, "failover mismatch") + require.Equal(t, tt.expectHandleError, handleErrorCalled, "handleGeminiUpstreamError call mismatch") + + if tt.expectShouldFailover { + require.True(t, svc.shouldFailoverGeminiUpstreamError(statusCode), + "shouldFailoverGeminiUpstreamError should return true for status %d", statusCode) + } + }) + } +} + +// --------------------------------------------------------------------------- +// TestGeminiErrorPolicy_NilRateLimitService — verifies nil safety +// --------------------------------------------------------------------------- + +func TestGeminiErrorPolicy_NilRateLimitService(t *testing.T) { + svc := &GeminiMessagesCompatService{ + rateLimitService: nil, + } + + // When rateLimitService is nil, error policy is skipped → falls through to + // shouldFailoverGeminiUpstreamError (original logic). + // Verify this doesn't panic and follows expected behavior. + + ctx := context.Background() + account := &Account{ + ID: 300, + Type: AccountTypeAPIKey, + Platform: PlatformGemini, + Credentials: map[string]any{ + "custom_error_codes_enabled": true, + "custom_error_codes": []any{float64(429)}, + }, + } + + // The nil check should prevent CheckErrorPolicy from being called + if svc.rateLimitService != nil { + t.Fatal("rateLimitService should be nil for this test") + } + + // shouldFailoverGeminiUpstreamError still works + require.True(t, svc.shouldFailoverGeminiUpstreamError(429)) + require.False(t, svc.shouldFailoverGeminiUpstreamError(400)) + + // handleGeminiUpstreamError should not panic with nil rateLimitService + require.NotPanics(t, func() { + svc.handleGeminiUpstreamError(ctx, account, 500, http.Header{}, []byte(`error`)) + }) +} + +// --------------------------------------------------------------------------- +// geminiErrorPolicyRepo — minimal AccountRepository stub for Gemini error +// policy tests. Embeds mockAccountRepoForGemini and adds tracking. +// --------------------------------------------------------------------------- + +type geminiErrorPolicyRepo struct { + mockAccountRepoForGemini + setErrorCalls int + setRateLimitedCalls int + setTempCalls int +} + +func (r *geminiErrorPolicyRepo) SetError(_ context.Context, _ int64, _ string) error { + r.setErrorCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetRateLimited(_ context.Context, _ int64, _ time.Time) error { + r.setRateLimitedCalls++ + return nil +} + +func (r *geminiErrorPolicyRepo) SetTempUnschedulable(_ context.Context, _ int64, _ time.Time, _ string) error { + r.setTempCalls++ + return nil +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 0f156c2e..d77f6f92 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -560,10 +560,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -640,10 +637,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -837,38 +831,47 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false + // 统一错误策略:自定义错误码 + 临时不可调度 if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - if tempMatched { - upstreamReqID := resp.Header.Get(requestIDHeader) - if upstreamReqID == "" { - upstreamReqID = resp.Header.Get("x-goog-request-id") - } - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") } - upstreamDetail = truncateString(string(respBody), maxBytes) + return nil, s.writeGeminiMappedError(c, account, resp.StatusCode, upstreamReqID, respBody) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + upstreamReqID := resp.Header.Get(requestIDHeader) + if upstreamReqID == "" { + upstreamReqID = resp.Header.Get("x-goog-request-id") + } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: upstreamReqID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: upstreamReqID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1026,10 +1029,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return nil, "", errors.New("gemini api_key not configured") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1097,10 +1097,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. return upstreamReq, "x-request-id", nil } else { // Mode 2: AI Studio API with OAuth (like API key mode, but using Bearer token) - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, "", err @@ -1261,14 +1258,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - tempMatched := false - if s.rateLimitService != nil { - tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody) - } - s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) - // Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens. // This avoids Gemini SDKs failing hard during preflight token counting. + // Checked before error policy so it always works regardless of custom error codes. if action == "countTokens" && isOAuth && isGeminiInsufficientScope(resp.Header, respBody) { estimated := estimateGeminiCountTokens(body) c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated}) @@ -1282,30 +1274,46 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. }, nil } - if tempMatched { - evBody := unwrapIfNeeded(isOAuth, respBody) - upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - upstreamDetail := "" - if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - if maxBytes <= 0 { - maxBytes = 2048 + // 统一错误策略:自定义错误码 + 临时不可调度 + if s.rateLimitService != nil { + switch s.rateLimitService.CheckErrorPolicy(ctx, account, resp.StatusCode, respBody) { + case ErrorPolicySkipped: + respBody = unwrapIfNeeded(isOAuth, respBody) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" } - upstreamDetail = truncateString(string(evBody), maxBytes) + c.Data(resp.StatusCode, contentType, respBody) + return nil, fmt.Errorf("gemini upstream error: %d (skipped by error policy)", resp.StatusCode) + case ErrorPolicyMatched, ErrorPolicyTempUnscheduled: + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + evBody := unwrapIfNeeded(isOAuth, respBody) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(evBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: requestID, + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: requestID, - Kind: "failover", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } + + // ErrorPolicyNone → 原有逻辑 + s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(evBody)) @@ -2420,10 +2428,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac return nil, errors.New("invalid path") } - baseURL := strings.TrimSpace(account.GetCredential("base_url")) - if baseURL == "" { - baseURL = geminicli.AIStudioBaseURL - } + baseURL := account.GetGeminiBaseURL(geminicli.AIStudioBaseURL) normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) if err != nil { return nil, err diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 9acf08f6..2d596f33 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -66,7 +66,12 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } -func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key string, value interface{}) ([]Account, error) { + +func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + +func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } func (m *mockAccountRepoForGemini) Update(ctx context.Context, account *Account) error { return nil } @@ -136,9 +141,6 @@ func (m *mockAccountRepoForGemini) ListSchedulableByGroupIDAndPlatforms(ctx cont func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { return nil } -func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { - return nil -} func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error { return nil } @@ -229,6 +231,10 @@ func (m *mockGroupRepoForGemini) GetAccountIDsByGroupIDs(ctx context.Context, gr return nil, nil } +func (m *mockGroupRepoForGemini) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { + return nil +} + var _ GroupRepository = (*mockGroupRepoForGemini)(nil) // mockGatewayCacheForGemini Gemini 测试用的 cache mock @@ -268,22 +274,6 @@ func (m *mockGatewayCacheForGemini) DeleteSessionAccountID(ctx context.Context, return nil } -func (m *mockGatewayCacheForGemini) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (m *mockGatewayCacheForGemini) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (m *mockGatewayCacheForGemini) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (m *mockGatewayCacheForGemini) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - // TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform 测试 Gemini 单平台选择 func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_GeminiPlatform(t *testing.T) { ctx := context.Background() diff --git a/backend/internal/service/gemini_session.go b/backend/internal/service/gemini_session.go index 859ae9f3..1780d1da 100644 --- a/backend/internal/service/gemini_session.go +++ b/backend/internal/service/gemini_session.go @@ -6,26 +6,11 @@ import ( "encoding/json" "strconv" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/cespare/xxhash/v2" ) -// Gemini 会话 ID Fallback 相关常量 -const ( - // geminiSessionTTLSeconds Gemini 会话缓存 TTL(5 分钟) - geminiSessionTTLSeconds = 300 - - // geminiSessionKeyPrefix Gemini 会话 Redis key 前缀 - geminiSessionKeyPrefix = "gemini:sess:" -) - -// GeminiSessionTTL 返回 Gemini 会话缓存 TTL -func GeminiSessionTTL() time.Duration { - return geminiSessionTTLSeconds * time.Second -} - // shortHash 使用 XXHash64 + Base36 生成短 hash(16 字符) // XXHash64 比 SHA256 快约 10 倍,Base36 比 Hex 短约 20% func shortHash(data []byte) string { @@ -79,35 +64,6 @@ func GenerateGeminiPrefixHash(userID, apiKeyID int64, ip, userAgent, platform, m return base64.RawURLEncoding.EncodeToString(hash[:12]) } -// BuildGeminiSessionKey 构建 Gemini 会话 Redis key -// 格式: gemini:sess:{groupID}:{prefixHash}:{digestChain} -func BuildGeminiSessionKey(groupID int64, prefixHash, digestChain string) string { - return geminiSessionKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash + ":" + digestChain -} - -// GenerateDigestChainPrefixes 生成摘要链的所有前缀(从长到短) -// 用于 MGET 批量查询最长匹配 -func GenerateDigestChainPrefixes(chain string) []string { - if chain == "" { - return nil - } - - var prefixes []string - c := chain - - for c != "" { - prefixes = append(prefixes, c) - // 找到最后一个 "-" 的位置 - if i := strings.LastIndex(c, "-"); i > 0 { - c = c[:i] - } else { - break - } - } - - return prefixes -} - // ParseGeminiSessionValue 解析 Gemini 会话缓存值 // 格式: {uuid}:{accountID} func ParseGeminiSessionValue(value string) (uuid string, accountID int64, ok bool) { @@ -139,15 +95,6 @@ func FormatGeminiSessionValue(uuid string, accountID int64) string { // geminiDigestSessionKeyPrefix Gemini 摘要 fallback 会话 key 前缀 const geminiDigestSessionKeyPrefix = "gemini:digest:" -// geminiTrieKeyPrefix Gemini Trie 会话 key 前缀 -const geminiTrieKeyPrefix = "gemini:trie:" - -// BuildGeminiTrieKey 构建 Gemini Trie Redis key -// 格式: gemini:trie:{groupID}:{prefixHash} -func BuildGeminiTrieKey(groupID int64, prefixHash string) string { - return geminiTrieKeyPrefix + strconv.FormatInt(groupID, 10) + ":" + prefixHash -} - // GenerateGeminiDigestSessionKey 生成 Gemini 摘要 fallback 的 sessionKey // 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey // 用于在 SelectAccountWithLoadAwareness 中保持粘性会话 diff --git a/backend/internal/service/gemini_session_integration_test.go b/backend/internal/service/gemini_session_integration_test.go index 928c62cf..95b5f594 100644 --- a/backend/internal/service/gemini_session_integration_test.go +++ b/backend/internal/service/gemini_session_integration_test.go @@ -1,41 +1,14 @@ package service import ( - "context" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" ) -// mockGeminiSessionCache 模拟 Redis 缓存 -type mockGeminiSessionCache struct { - sessions map[string]string // key -> value -} - -func newMockGeminiSessionCache() *mockGeminiSessionCache { - return &mockGeminiSessionCache{sessions: make(map[string]string)} -} - -func (m *mockGeminiSessionCache) Save(groupID int64, prefixHash, digestChain, uuid string, accountID int64) { - key := BuildGeminiSessionKey(groupID, prefixHash, digestChain) - value := FormatGeminiSessionValue(uuid, accountID) - m.sessions[key] = value -} - -func (m *mockGeminiSessionCache) Find(groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - prefixes := GenerateDigestChainPrefixes(digestChain) - for _, p := range prefixes { - key := BuildGeminiSessionKey(groupID, prefixHash, p) - if val, ok := m.sessions[key]; ok { - return ParseGeminiSessionValue(val) - } - } - return "", 0, false -} - // TestGeminiSessionContinuousConversation 测试连续会话的摘要链匹配 func TestGeminiSessionContinuousConversation(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" sessionUUID := "session-uuid-12345" @@ -54,13 +27,13 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 1 chain: %s", chain1) // 第一轮:没有找到会话,创建新会话 - _, _, found := cache.Find(groupID, prefixHash, chain1) + _, _, _, found := store.Find(groupID, prefixHash, chain1) if found { t.Error("Round 1: should not find existing session") } - // 保存第一轮会话 - cache.Save(groupID, prefixHash, chain1, sessionUUID, accountID) + // 保存第一轮会话(首轮无旧 chain) + store.Save(groupID, prefixHash, chain1, sessionUUID, accountID, "") // 模拟第二轮对话(用户继续对话) req2 := &antigravity.GeminiRequest{ @@ -77,7 +50,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 2 chain: %s", chain2) // 第二轮:应该能找到会话(通过前缀匹配) - foundUUID, foundAccID, found := cache.Find(groupID, prefixHash, chain2) + foundUUID, foundAccID, matchedChain, found := store.Find(groupID, prefixHash, chain2) if !found { t.Error("Round 2: should find session via prefix matching") } @@ -88,8 +61,8 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Errorf("Round 2: expected accountID %d, got %d", accountID, foundAccID) } - // 保存第二轮会话 - cache.Save(groupID, prefixHash, chain2, sessionUUID, accountID) + // 保存第二轮会话,传入 Find 返回的 matchedChain 以删旧 key + store.Save(groupID, prefixHash, chain2, sessionUUID, accountID, matchedChain) // 模拟第三轮对话 req3 := &antigravity.GeminiRequest{ @@ -108,7 +81,7 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { t.Logf("Round 3 chain: %s", chain3) // 第三轮:应该能找到会话(通过第二轮的前缀匹配) - foundUUID, foundAccID, found = cache.Find(groupID, prefixHash, chain3) + foundUUID, foundAccID, _, found = store.Find(groupID, prefixHash, chain3) if !found { t.Error("Round 3: should find session via prefix matching") } @@ -118,13 +91,11 @@ func TestGeminiSessionContinuousConversation(t *testing.T) { if foundAccID != accountID { t.Errorf("Round 3: expected accountID %d, got %d", accountID, foundAccID) } - - t.Log("✓ Continuous conversation session matching works correctly!") } // TestGeminiSessionDifferentConversations 测试不同会话不会错误匹配 func TestGeminiSessionDifferentConversations(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" @@ -135,7 +106,7 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { }, } chain1 := BuildGeminiDigestChain(req1) - cache.Save(groupID, prefixHash, chain1, "session-1", 100) + store.Save(groupID, prefixHash, chain1, "session-1", 100, "") // 第二个完全不同的会话 req2 := &antigravity.GeminiRequest{ @@ -146,61 +117,29 @@ func TestGeminiSessionDifferentConversations(t *testing.T) { chain2 := BuildGeminiDigestChain(req2) // 不同会话不应该匹配 - _, _, found := cache.Find(groupID, prefixHash, chain2) + _, _, _, found := store.Find(groupID, prefixHash, chain2) if found { t.Error("Different conversations should not match") } - - t.Log("✓ Different conversations are correctly isolated!") } // TestGeminiSessionPrefixMatchingOrder 测试前缀匹配的优先级(最长匹配优先) func TestGeminiSessionPrefixMatchingOrder(t *testing.T) { - cache := newMockGeminiSessionCache() + store := NewDigestSessionStore() groupID := int64(1) prefixHash := "test_prefix_hash" - // 创建一个三轮对话 - req := &antigravity.GeminiRequest{ - SystemInstruction: &antigravity.GeminiContent{ - Parts: []antigravity.GeminiPart{{Text: "System prompt"}}, - }, - Contents: []antigravity.GeminiContent{ - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q1"}}}, - {Role: "model", Parts: []antigravity.GeminiPart{{Text: "A1"}}}, - {Role: "user", Parts: []antigravity.GeminiPart{{Text: "Q2"}}}, - }, - } - fullChain := BuildGeminiDigestChain(req) - prefixes := GenerateDigestChainPrefixes(fullChain) - - t.Logf("Full chain: %s", fullChain) - t.Logf("Prefixes (longest first): %v", prefixes) - - // 验证前缀生成顺序(从长到短) - if len(prefixes) != 4 { - t.Errorf("Expected 4 prefixes, got %d", len(prefixes)) - } - // 保存不同轮次的会话到不同账号 - // 第一轮(最短前缀)-> 账号 1 - cache.Save(groupID, prefixHash, prefixes[3], "session-round1", 1) - // 第二轮 -> 账号 2 - cache.Save(groupID, prefixHash, prefixes[2], "session-round2", 2) - // 第三轮(最长前缀,完整链)-> 账号 3 - cache.Save(groupID, prefixHash, prefixes[0], "session-round3", 3) + store.Save(groupID, prefixHash, "s:sys-u:q1", "session-round1", 1, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1", "session-round2", 2, "") + store.Save(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2", "session-round3", 3, "") - // 查找应该返回最长匹配(账号 3) - _, accID, found := cache.Find(groupID, prefixHash, fullChain) + // 查找更长的链,应该返回最长匹配(账号 3) + _, accID, _, found := store.Find(groupID, prefixHash, "s:sys-u:q1-m:a1-u:q2-m:a2") if !found { t.Error("Should find session") } if accID != 3 { t.Errorf("Should match longest prefix (account 3), got account %d", accID) } - - t.Log("✓ Longest prefix matching works correctly!") } - -// 确保 context 包被使用(避免未使用的导入警告) -var _ = context.Background diff --git a/backend/internal/service/gemini_session_test.go b/backend/internal/service/gemini_session_test.go index 8c1908f7..a034cddd 100644 --- a/backend/internal/service/gemini_session_test.go +++ b/backend/internal/service/gemini_session_test.go @@ -152,61 +152,6 @@ func TestGenerateGeminiPrefixHash(t *testing.T) { } } -func TestGenerateDigestChainPrefixes(t *testing.T) { - tests := []struct { - name string - chain string - want []string - wantLen int - }{ - { - name: "empty", - chain: "", - wantLen: 0, - }, - { - name: "single part", - chain: "u:abc123", - want: []string{"u:abc123"}, - wantLen: 1, - }, - { - name: "two parts", - chain: "s:xyz-u:abc", - want: []string{"s:xyz-u:abc", "s:xyz"}, - wantLen: 2, - }, - { - name: "four parts", - chain: "s:a-u:b-m:c-u:d", - want: []string{"s:a-u:b-m:c-u:d", "s:a-u:b-m:c", "s:a-u:b", "s:a"}, - wantLen: 4, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := GenerateDigestChainPrefixes(tt.chain) - - if len(result) != tt.wantLen { - t.Errorf("expected %d prefixes, got %d: %v", tt.wantLen, len(result), result) - } - - if tt.want != nil { - for i, want := range tt.want { - if i >= len(result) { - t.Errorf("missing prefix at index %d", i) - continue - } - if result[i] != want { - t.Errorf("prefix[%d]: expected %s, got %s", i, want, result[i]) - } - } - } - }) - } -} - func TestParseGeminiSessionValue(t *testing.T) { tests := []struct { name string @@ -442,40 +387,3 @@ func TestGenerateGeminiDigestSessionKey(t *testing.T) { } }) } - -func TestBuildGeminiTrieKey(t *testing.T) { - tests := []struct { - name string - groupID int64 - prefixHash string - want string - }{ - { - name: "normal", - groupID: 123, - prefixHash: "abcdef12", - want: "gemini:trie:123:abcdef12", - }, - { - name: "zero group", - groupID: 0, - prefixHash: "xyz", - want: "gemini:trie:0:xyz", - }, - { - name: "empty prefix", - groupID: 1, - prefixHash: "", - want: "gemini:trie:1:", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := BuildGeminiTrieKey(tt.groupID, tt.prefixHash) - if got != tt.want { - t.Errorf("BuildGeminiTrieKey(%d, %q) = %q, want %q", tt.groupID, tt.prefixHash, got, tt.want) - } - }) - } -} diff --git a/backend/internal/service/generate_session_hash_test.go b/backend/internal/service/generate_session_hash_test.go new file mode 100644 index 00000000..8aa358a5 --- /dev/null +++ b/backend/internal/service/generate_session_hash_test.go @@ -0,0 +1,1213 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// ============ 基础优先级测试 ============ + +func TestGenerateSessionHash_NilParsedRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(nil)) +} + +func TestGenerateSessionHash_EmptyRequest(t *testing.T) { + svc := &GatewayService{} + require.Empty(t, svc.GenerateSessionHash(&ParsedRequest{})) +} + +func TestGenerateSessionHash_MetadataHasHighestPriority(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, "metadata session_id should have highest priority") +} + +// ============ System + Messages 基础测试 ============ + +func TestGenerateSessionHash_SystemPlusMessages(t *testing.T) { + svc := &GatewayService{} + + withSystem := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + withoutSystem := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withSystem) + h2 := svc.GenerateSessionHash(withoutSystem) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "system prompt should be part of digest, producing different hash") +} + +func TestGenerateSessionHash_SystemOnlyProducesHash(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + } + hash := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hash, "system prompt alone should produce a hash as part of full digest") +} + +func TestGenerateSessionHash_DifferentSystemsSameMessages(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are assistant A.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are assistant B.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different system prompts with same messages should produce different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessages(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "hi"}, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same system + same messages should produce identical hash") +} + +func TestGenerateSessionHash_DifferentMessagesProduceDifferentHash(t *testing.T) { + svc := &GatewayService{} + + parsed1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Go"}, + }, + } + parsed2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "help me with Python"}, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "same system but different messages should produce different hashes") +} + +// ============ SessionContext 核心测试 ============ + +func TestGenerateSessionHash_DifferentSessionContextProducesDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 相同消息 + 不同 SessionContext → 不同 hash(解决碰撞问题的核心场景) + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "curl/7.0", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "same messages but different SessionContext should produce different hashes") +} + +func TestGenerateSessionHash_SameSessionContextProducesSameHash(t *testing.T) { + svc := &GatewayService{} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same messages + same SessionContext should produce identical hash") +} + +func TestGenerateSessionHash_MetadataOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + MetadataUserID: "session_123e4567-e89b-12d3-a456-426614174000", + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 100, + }, + } + + hash := svc.GenerateSessionHash(parsed) + require.Equal(t, "123e4567-e89b-12d3-a456-426614174000", hash, + "metadata session_id should take priority over SessionContext") +} + +func TestGenerateSessionHash_NilSessionContextBackwardCompatible(t *testing.T) { + svc := &GatewayService{} + + withCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: nil, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + } + + h1 := svc.GenerateSessionHash(withCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + require.Equal(t, h1, h2, "nil SessionContext should produce same hash as no SessionContext") +} + +// ============ 多轮连续会话测试 ============ + +func TestGenerateSessionHash_ContinuousConversation_HashChangesWithMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟连续会话:每增加一轮对话,hash 应该不同(内容累积变化) + round1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + + round3 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + map[string]any{"role": "assistant", "content": "I'm doing well!"}, + map[string]any{"role": "user", "content": "Tell me a joke"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + h3 := svc.GenerateSessionHash(round3) + + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEmpty(t, h3) + require.NotEqual(t, h1, h2, "different conversation rounds should produce different hashes") + require.NotEqual(t, h2, h3, "each new round should produce a different hash") + require.NotEqual(t, h1, h3, "round 1 and round 3 should differ") +} + +func TestGenerateSessionHash_ContinuousConversation_SameRoundSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 同一轮对话重复请求(如重试)应产生相同 hash + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + map[string]any{"role": "assistant", "content": "Hi there!"}, + map[string]any{"role": "user", "content": "How are you?"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same conversation state should produce identical hash on retry") +} + +// ============ 消息回退测试 ============ + +func TestGenerateSessionHash_MessageRollback(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 模拟消息回退:用户删掉最后一轮再重发 + original := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "msg3"}, + }, + SessionContext: ctx, + } + + // 回退到 msg2 后,用新的 msg3 替代 + rollback := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + map[string]any{"role": "assistant", "content": "reply2"}, + map[string]any{"role": "user", "content": "different msg3"}, + }, + SessionContext: ctx, + } + + hOrig := svc.GenerateSessionHash(original) + hRollback := svc.GenerateSessionHash(rollback) + require.NotEqual(t, hOrig, hRollback, "rollback with different last message should produce different hash") +} + +func TestGenerateSessionHash_MessageRollbackSameContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 回退后重新发送相同内容 → 相同 hash(合理的粘性恢复) + mk := func() *ParsedRequest { + return &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "msg1"}, + map[string]any{"role": "assistant", "content": "reply1"}, + map[string]any{"role": "user", "content": "msg2"}, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "rollback and resend same content should produce same hash") +} + +// ============ 相同 System、不同用户消息 ============ + +func TestGenerateSessionHash_SameSystemDifferentUsers(t *testing.T) { + svc := &GatewayService{} + + // 两个不同用户使用相同 system prompt 但发送不同消息 + user1 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Go code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "vscode", + APIKeyID: 1, + }, + } + user2 := &ParsedRequest{ + System: "You are a code reviewer.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "Review this Python code"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "vscode", + APIKeyID: 2, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "different users with different messages should get different hashes") +} + +func TestGenerateSessionHash_SameSystemSameMessageDifferentContext(t *testing.T) { + svc := &GatewayService{} + + // 这是修复的核心场景:两个不同用户发送完全相同的 system + messages(如 "hello") + // 有了 SessionContext 后应该产生不同 hash + user1 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "Mozilla/5.0", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + System: "You are a helpful assistant.", + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "Mozilla/5.0", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: same system+messages but different users should get different hashes") +} + +// ============ SessionContext 各字段独立影响测试 ============ + +func TestGenerateSessionHash_SessionContext_IPDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ip string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: ip, + UserAgent: "same-ua", + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("1.1.1.1")) + h2 := svc.GenerateSessionHash(base("2.2.2.2")) + require.NotEqual(t, h1, h2, "different IP should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_UADifference(t *testing.T) { + svc := &GatewayService{} + + base := func(ua string) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: ua, + APIKeyID: 1, + }, + } + } + + h1 := svc.GenerateSessionHash(base("Mozilla/5.0")) + h2 := svc.GenerateSessionHash(base("curl/7.0")) + require.NotEqual(t, h1, h2, "different User-Agent should produce different hash") +} + +func TestGenerateSessionHash_SessionContext_APIKeyIDDifference(t *testing.T) { + svc := &GatewayService{} + + base := func(keyID int64) *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "same-ua", + APIKeyID: keyID, + }, + } + } + + h1 := svc.GenerateSessionHash(base(1)) + h2 := svc.GenerateSessionHash(base(2)) + require.NotEqual(t, h1, h2, "different APIKeyID should produce different hash") +} + +// ============ 多用户并发相同消息场景 ============ + +func TestGenerateSessionHash_MultipleUsersSameFirstMessage(t *testing.T) { + svc := &GatewayService{} + + // 模拟 5 个不同用户同时发送 "hello" → 应该产生 5 个不同的 hash + hashes := make(map[string]bool) + for i := 0; i < 5; i++ { + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "192.168.1." + string(rune('1'+i)), + UserAgent: "client-" + string(rune('A'+i)), + APIKeyID: int64(i + 1), + }, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + require.False(t, hashes[h], "hash collision detected for user %d", i) + hashes[h] = true + } + require.Len(t, hashes, 5, "5 different users should produce 5 unique hashes") +} + +// ============ 连续会话粘性:多轮对话同一用户 ============ + +func TestGenerateSessionHash_SameUserGrowingConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "browser", APIKeyID: 42} + + // 模拟同一用户的连续会话,每轮 hash 不同但同用户重试保持一致 + messages := []map[string]any{ + {"role": "user", "content": "msg1"}, + {"role": "assistant", "content": "reply1"}, + {"role": "user", "content": "msg2"}, + {"role": "assistant", "content": "reply2"}, + {"role": "user", "content": "msg3"}, + {"role": "assistant", "content": "reply3"}, + {"role": "user", "content": "msg4"}, + } + + prevHash := "" + for round := 1; round <= len(messages); round += 2 { + // 构建前 round 条消息 + msgs := make([]any, round) + for j := 0; j < round; j++ { + msgs[j] = messages[j] + } + parsed := &ParsedRequest{ + System: "System", + HasSystem: true, + Messages: msgs, + SessionContext: ctx, + } + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "round %d hash should not be empty", round) + + if prevHash != "" { + require.NotEqual(t, prevHash, h, "round %d hash should differ from previous round", round) + } + prevHash = h + + // 同一轮重试应该相同 + h2 := svc.GenerateSessionHash(parsed) + require.Equal(t, h, h2, "retry of round %d should produce same hash", round) + } +} + +// ============ 多轮消息内容结构化测试 ============ + +func TestGenerateSessionHash_MultipleUserMessages(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 5 条用户消息(无 assistant 回复) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "second"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 修改中间一条消息应该改变 hash + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "first"}, + map[string]any{"role": "user", "content": "CHANGED"}, + map[string]any{"role": "user", "content": "third"}, + map[string]any{"role": "user", "content": "fourth"}, + map[string]any{"role": "user", "content": "fifth"}, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing any message should change the hash") +} + +func TestGenerateSessionHash_MessageOrderMatters(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "alpha"}, + map[string]any{"role": "user", "content": "beta"}, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "beta"}, + map[string]any{"role": "user", "content": "alpha"}, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "message order should affect the hash") +} + +// ============ 复杂内容格式测试 ============ + +func TestGenerateSessionHash_StructuredContent(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 结构化 content(数组形式) + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "content": []any{ + map[string]any{"type": "text", "text": "Look at this"}, + map[string]any{"type": "text", "text": "And this too"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "structured content should produce a hash") +} + +func TestGenerateSessionHash_ArraySystemPrompt(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 数组格式的 system prompt + parsed := &ParsedRequest{ + System: []any{ + map[string]any{"type": "text", "text": "You are a helpful assistant."}, + map[string]any{"type": "text", "text": "Be concise."}, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "array system prompt should produce a hash") +} + +// ============ SessionContext 与 cache_control 优先级 ============ + +func TestGenerateSessionHash_CacheControlOverridesSessionContext(t *testing.T) { + svc := &GatewayService{} + + // 当有 cache_control: ephemeral 时,使用第 2 级优先级 + // SessionContext 不应影响结果 + parsed1 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "ua1", + APIKeyID: 100, + }, + } + parsed2 := &ParsedRequest{ + System: []any{ + map[string]any{ + "type": "text", + "text": "You are a tool-specific assistant.", + "cache_control": map[string]any{"type": "ephemeral"}, + }, + }, + HasSystem: true, + Messages: []any{ + map[string]any{"role": "user", "content": "hello"}, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "ua2", + APIKeyID: 200, + }, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h1, h2, "cache_control ephemeral has higher priority, SessionContext should not affect result") +} + +// ============ 边界情况 ============ + +func TestGenerateSessionHash_EmptyMessages(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "test", + APIKeyID: 1, + }, + } + + // 空 messages + 只有 SessionContext 时,combined.Len() > 0 因为有 context 写入 + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "empty messages with SessionContext should still produce a hash from context") +} + +func TestGenerateSessionHash_EmptyMessagesNoContext(t *testing.T) { + svc := &GatewayService{} + + parsed := &ParsedRequest{ + Messages: []any{}, + } + + h := svc.GenerateSessionHash(parsed) + require.Empty(t, h, "empty messages without SessionContext should produce empty hash") +} + +func TestGenerateSessionHash_SessionContextWithEmptyFields(t *testing.T) { + svc := &GatewayService{} + + // SessionContext 字段为空字符串和零值时仍应影响 hash + withEmptyCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + SessionContext: &SessionContext{ + ClientIP: "", + UserAgent: "", + APIKeyID: 0, + }, + } + withoutCtx := &ParsedRequest{ + Messages: []any{ + map[string]any{"role": "user", "content": "test"}, + }, + } + + h1 := svc.GenerateSessionHash(withEmptyCtx) + h2 := svc.GenerateSessionHash(withoutCtx) + // 有 SessionContext(即使字段为空)仍然会写入分隔符 "::" 等 + require.NotEqual(t, h1, h2, "empty-field SessionContext should still differ from nil SessionContext") +} + +// ============ 长对话历史测试 ============ + +func TestGenerateSessionHash_LongConversation(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "test", APIKeyID: 1} + + // 构建 20 轮对话 + messages := make([]any, 0, 40) + for i := 0; i < 20; i++ { + messages = append(messages, map[string]any{ + "role": "user", + "content": "user message " + string(rune('A'+i)), + }) + messages = append(messages, map[string]any{ + "role": "assistant", + "content": "assistant reply " + string(rune('A'+i)), + }) + } + + parsed := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: messages, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h) + + // 再加一轮应该不同 + moreMessages := make([]any, len(messages)+2) + copy(moreMessages, messages) + moreMessages[len(messages)] = map[string]any{"role": "user", "content": "one more"} + moreMessages[len(messages)+1] = map[string]any{"role": "assistant", "content": "ok"} + + parsed2 := &ParsedRequest{ + System: "System prompt", + HasSystem: true, + Messages: moreMessages, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "adding more messages to long conversation should change hash") +} + +// ============ Gemini 原生格式 session hash 测试 ============ + +func TestGenerateSessionHash_GeminiContentsProducesHash(t *testing.T) { + svc := &GatewayService{} + + // Gemini 格式: contents[].parts[].text + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello from Gemini"}, + }, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.2.3.4", + UserAgent: "gemini-cli", + APIKeyID: 1, + }, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini contents with parts should produce a non-empty hash") +} + +func TestGenerateSessionHash_GeminiDifferentContentsDifferentHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + parsed1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + }, + SessionContext: ctx, + } + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Goodbye"}, + }, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(parsed1) + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h1, h2, "different Gemini contents should produce different hashes") +} + +func TestGenerateSessionHash_GeminiSameContentsSameHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + mk := func() *ParsedRequest { + return &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Hello"}, + }, + }, + map[string]any{ + "role": "model", + "parts": []any{ + map[string]any{"text": "Hi there!"}, + }, + }, + }, + SessionContext: ctx, + } + } + + h1 := svc.GenerateSessionHash(mk()) + h2 := svc.GenerateSessionHash(mk()) + require.Equal(t, h1, h2, "same Gemini contents should produce identical hash") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashChanges(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + round1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + round2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + map[string]any{ + "role": "model", + "parts": []any{map[string]any{"text": "Hi!"}}, + }, + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "How are you?"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(round1) + h2 := svc.GenerateSessionHash(round2) + require.NotEmpty(t, h1) + require.NotEmpty(t, h2) + require.NotEqual(t, h1, h2, "Gemini multi-turn should produce different hashes per round") +} + +func TestGenerateSessionHash_GeminiDifferentUsersSameContentDifferentHash(t *testing.T) { + svc := &GatewayService{} + + // 核心场景:两个不同用户发送相同 Gemini 格式消息应得到不同 hash + user1 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "1.1.1.1", + UserAgent: "gemini-cli", + APIKeyID: 10, + }, + } + user2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: &SessionContext{ + ClientIP: "2.2.2.2", + UserAgent: "gemini-cli", + APIKeyID: 20, + }, + } + + h1 := svc.GenerateSessionHash(user1) + h2 := svc.GenerateSessionHash(user2) + require.NotEqual(t, h1, h2, "CRITICAL: different Gemini users with same content must get different hashes") +} + +func TestGenerateSessionHash_GeminiSystemInstructionAffectsHash(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // systemInstruction 经 ParseGatewayRequest 解析后存入 parsed.System + withSys := &ParsedRequest{ + System: []any{ + map[string]any{"text": "You are a coding assistant."}, + }, + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + withoutSys := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{map[string]any{"text": "hello"}}, + }, + }, + SessionContext: ctx, + } + + h1 := svc.GenerateSessionHash(withSys) + h2 := svc.GenerateSessionHash(withoutSys) + require.NotEqual(t, h1, h2, "systemInstruction should affect the hash") +} + +func TestGenerateSessionHash_GeminiMultiPartMessage(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 多 parts 的消息 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "Part 2"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "multi-part Gemini message should produce a hash") + + // 不同内容的多 parts + parsed2 := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Part 1"}, + map[string]any{"text": "CHANGED"}, + map[string]any{"text": "Part 3"}, + }, + }, + }, + SessionContext: ctx, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.NotEqual(t, h, h2, "changing a part should change the hash") +} + +func TestGenerateSessionHash_GeminiNonTextPartsIgnored(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "1.2.3.4", UserAgent: "gemini-cli", APIKeyID: 1} + + // 含非 text 类型 parts(如 inline_data),应被跳过但不报错 + parsed := &ParsedRequest{ + Messages: []any{ + map[string]any{ + "role": "user", + "parts": []any{ + map[string]any{"text": "Describe this image"}, + map[string]any{"inline_data": map[string]any{"mime_type": "image/png", "data": "base64..."}}, + }, + }, + }, + SessionContext: ctx, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "Gemini message with mixed parts should still produce a hash from text parts") +} + +func TestGenerateSessionHash_GeminiMultiTurnHashNotSticky(t *testing.T) { + svc := &GatewayService{} + + ctx := &SessionContext{ClientIP: "10.0.0.1", UserAgent: "gemini-cli", APIKeyID: 42} + + // 模拟同一 Gemini 会话的三轮请求,每轮 contents 累积增长。 + // 验证预期行为:每轮 hash 都不同,即 GenerateSessionHash 不具备跨轮粘性。 + // 这是 by-design 的——Gemini 的跨轮粘性由 Digest Fallback(BuildGeminiDigestChain)负责。 + round1Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]} + ] + }`) + round2Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]} + ] + }`) + round3Body := []byte(`{ + "systemInstruction": {"parts": [{"text": "You are a coding assistant."}]}, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "func hello() {}"}]}, + {"role": "user", "parts": [{"text": "Add error handling"}]}, + {"role": "model", "parts": [{"text": "func hello() error { return nil }"}]}, + {"role": "user", "parts": [{"text": "Now add tests"}]} + ] + }`) + + hashes := make([]string, 3) + for i, body := range [][]byte{round1Body, round2Body, round3Body} { + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = ctx + hashes[i] = svc.GenerateSessionHash(parsed) + require.NotEmpty(t, hashes[i], "round %d hash should not be empty", i+1) + } + + // 每轮 hash 都不同——这是预期行为 + require.NotEqual(t, hashes[0], hashes[1], "round 1 vs 2 hash should differ (contents grow)") + require.NotEqual(t, hashes[1], hashes[2], "round 2 vs 3 hash should differ (contents grow)") + require.NotEqual(t, hashes[0], hashes[2], "round 1 vs 3 hash should differ") + + // 同一轮重试应产生相同 hash + parsed1Again, err := ParseGatewayRequest(round2Body, "gemini") + require.NoError(t, err) + parsed1Again.SessionContext = ctx + h2Again := svc.GenerateSessionHash(parsed1Again) + require.Equal(t, hashes[1], h2Again, "retry of same round should produce same hash") +} + +func TestGenerateSessionHash_GeminiEndToEnd(t *testing.T) { + svc := &GatewayService{} + + // 端到端测试:模拟 ParseGatewayRequest + GenerateSessionHash 完整流程 + body := []byte(`{ + "model": "gemini-2.5-pro", + "systemInstruction": { + "parts": [{"text": "You are a coding assistant."}] + }, + "contents": [ + {"role": "user", "parts": [{"text": "Write a Go function"}]}, + {"role": "model", "parts": [{"text": "Here is a function..."}]}, + {"role": "user", "parts": [{"text": "Now add error handling"}]} + ] + }`) + + parsed, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h := svc.GenerateSessionHash(parsed) + require.NotEmpty(t, h, "end-to-end Gemini flow should produce a hash") + + // 同一请求再次解析应产生相同 hash + parsed2, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed2.SessionContext = &SessionContext{ + ClientIP: "10.0.0.1", + UserAgent: "gemini-cli/1.0", + APIKeyID: 42, + } + + h2 := svc.GenerateSessionHash(parsed2) + require.Equal(t, h, h2, "same request should produce same hash") + + // 不同用户发送相同请求应产生不同 hash + parsed3, err := ParseGatewayRequest(body, "gemini") + require.NoError(t, err) + parsed3.SessionContext = &SessionContext{ + ClientIP: "10.0.0.2", + UserAgent: "gemini-cli/1.0", + APIKeyID: 99, + } + + h3 := svc.GenerateSessionHash(parsed3) + require.NotEqual(t, h, h3, "different user with same Gemini request should get different hash") +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 23880b0b..86ece03f 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -51,6 +51,9 @@ type Group struct { // 可选值: claude, gemini_text, gemini_image SupportedModelScopes []string + // 分组排序 + SortOrder int + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/group_service.go b/backend/internal/service/group_service.go index a2bf2073..22a67eda 100644 --- a/backend/internal/service/group_service.go +++ b/backend/internal/service/group_service.go @@ -33,6 +33,14 @@ type GroupRepository interface { GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) // BindAccountsToGroup 将多个账号绑定到指定分组 BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error + // UpdateSortOrders 批量更新分组排序 + UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error +} + +// GroupSortOrderUpdate 分组排序更新 +type GroupSortOrderUpdate struct { + ID int64 `json:"id"` + SortOrder int `json:"sort_order"` } // CreateGroupRequest 创建分组请求 diff --git a/backend/internal/service/model_rate_limit_test.go b/backend/internal/service/model_rate_limit_test.go index a51e6909..b79b9688 100644 --- a/backend/internal/service/model_rate_limit_test.go +++ b/backend/internal/service/model_rate_limit_test.go @@ -318,110 +318,6 @@ func TestGetModelRateLimitRemainingTime(t *testing.T) { } } -func TestGetQuotaScopeRateLimitRemainingTime(t *testing.T) { - now := time.Now() - future10m := now.Add(10 * time.Minute).Format(time.RFC3339) - past := now.Add(-10 * time.Minute).Format(time.RFC3339) - - tests := []struct { - name string - account *Account - requestedModel string - minExpected time.Duration - maxExpected time.Duration - }{ - { - name: "nil account", - account: nil, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "non-antigravity platform", - account: &Account{ - Platform: PlatformAnthropic, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "claude scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "gemini_text scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "gemini_text": map[string]any{ - "rate_limit_reset_at": future10m, - }, - }, - }, - }, - requestedModel: "gemini-3-flash", - minExpected: 9 * time.Minute, - maxExpected: 11 * time.Minute, - }, - { - name: "expired scope rate limit", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": past, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 0, - maxExpected: 0, - }, - { - name: "unsupported model", - account: &Account{ - Platform: PlatformAntigravity, - }, - requestedModel: "gpt-4", - minExpected: 0, - maxExpected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.account.GetQuotaScopeRateLimitRemainingTime(tt.requestedModel) - if result < tt.minExpected || result > tt.maxExpected { - t.Errorf("GetQuotaScopeRateLimitRemainingTime() = %v, want between %v and %v", result, tt.minExpected, tt.maxExpected) - } - }) - } -} - func TestGetRateLimitRemainingTime(t *testing.T) { now := time.Now() future15m := now.Add(15 * time.Minute).Format(time.RFC3339) @@ -442,45 +338,19 @@ func TestGetRateLimitRemainingTime(t *testing.T) { maxExpected: 0, }, { - name: "model remaining > scope remaining - returns model", + name: "model rate limited - 15 minutes", account: &Account{ Platform: PlatformAntigravity, Extra: map[string]any{ modelRateLimitsKey: map[string]any{ "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 + "rate_limit_reset_at": future15m, }, }, }, }, requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 - maxExpected: 16 * time.Minute, - }, - { - name: "scope remaining > model remaining - returns scope", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - modelRateLimitsKey: map[string]any{ - "claude-sonnet-4-5": map[string]any{ - "rate_limit_reset_at": future5m, // 5 分钟 - }, - }, - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future15m, // 15 分钟 - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 14 * time.Minute, // 应返回较大的 15 分钟 + minExpected: 14 * time.Minute, maxExpected: 16 * time.Minute, }, { @@ -499,22 +369,6 @@ func TestGetRateLimitRemainingTime(t *testing.T) { minExpected: 4 * time.Minute, maxExpected: 6 * time.Minute, }, - { - name: "only scope rate limited", - account: &Account{ - Platform: PlatformAntigravity, - Extra: map[string]any{ - antigravityQuotaScopesKey: map[string]any{ - "claude": map[string]any{ - "rate_limit_reset_at": future5m, - }, - }, - }, - }, - requestedModel: "claude-sonnet-4-5", - minExpected: 4 * time.Minute, - maxExpected: 6 * time.Minute, - }, { name: "neither rate limited", account: &Account{ diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 450075fb..bc618046 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -582,10 +582,6 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex } } } else { - type accountWithLoad struct { - account *Account - loadInfo *AccountLoadInfo - } var available []accountWithLoad for _, acc := range candidates { loadInfo := loadMap[acc.ID] @@ -620,6 +616,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex return a.account.LastUsedAt.Before(*b.account.LastUsedAt) } }) + shuffleWithinSortGroups(available) for _, item := range available { result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index a6eeb3eb..006820ed 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -209,22 +209,6 @@ func (c *stubGatewayCache) DeleteSessionAccountID(ctx context.Context, groupID i return nil } -func (c *stubGatewayCache) IncrModelCallCount(ctx context.Context, accountID int64, model string) (int64, error) { - return 0, nil -} - -func (c *stubGatewayCache) GetModelLoadBatch(ctx context.Context, accountIDs []int64, model string) (map[int64]*ModelLoadInfo, error) { - return nil, nil -} - -func (c *stubGatewayCache) FindGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain string) (uuid string, accountID int64, found bool) { - return "", 0, false -} - -func (c *stubGatewayCache) SaveGeminiSession(ctx context.Context, groupID int64, prefixHash, digestChain, uuid string, accountID int64) error { - return nil -} - func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/ops_account_availability.go b/backend/internal/service/ops_account_availability.go index a649e7b5..da66ec4d 100644 --- a/backend/internal/service/ops_account_availability.go +++ b/backend/internal/service/ops_account_availability.go @@ -66,7 +66,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi } isAvailable := acc.Status == StatusActive && acc.Schedulable && !isRateLimited && !isOverloaded && !isTempUnsched - scopeRateLimits := acc.GetAntigravityScopeRateLimits() if acc.Platform != "" { if _, ok := platform[acc.Platform]; !ok { @@ -85,14 +84,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { p.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if p.ScopeRateLimitCount == nil { - p.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - p.ScopeRateLimitCount[scope]++ - } - } } for _, grp := range acc.Groups { @@ -117,14 +108,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi if hasError { g.ErrorCount++ } - if len(scopeRateLimits) > 0 { - if g.ScopeRateLimitCount == nil { - g.ScopeRateLimitCount = make(map[string]int64) - } - for scope := range scopeRateLimits { - g.ScopeRateLimitCount[scope]++ - } - } } displayGroupID := int64(0) @@ -157,9 +140,6 @@ func (s *OpsService) GetAccountAvailabilityStats(ctx context.Context, platformFi item.RateLimitRemainingSec = &remainingSec } } - if len(scopeRateLimits) > 0 { - item.ScopeRateLimits = scopeRateLimits - } if isOverloaded && acc.OverloadUntil != nil { item.OverloadUntil = acc.OverloadUntil remainingSec := int64(time.Until(*acc.OverloadUntil).Seconds()) diff --git a/backend/internal/service/ops_realtime_models.go b/backend/internal/service/ops_realtime_models.go index 33029f59..a19ab355 100644 --- a/backend/internal/service/ops_realtime_models.go +++ b/backend/internal/service/ops_realtime_models.go @@ -50,24 +50,22 @@ type UserConcurrencyInfo struct { // PlatformAvailability aggregates account availability by platform. type PlatformAvailability struct { - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // GroupAvailability aggregates account availability by group. type GroupAvailability struct { - GroupID int64 `json:"group_id"` - GroupName string `json:"group_name"` - Platform string `json:"platform"` - TotalAccounts int64 `json:"total_accounts"` - AvailableCount int64 `json:"available_count"` - RateLimitCount int64 `json:"rate_limit_count"` - ScopeRateLimitCount map[string]int64 `json:"scope_rate_limit_count,omitempty"` - ErrorCount int64 `json:"error_count"` + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Platform string `json:"platform"` + TotalAccounts int64 `json:"total_accounts"` + AvailableCount int64 `json:"available_count"` + RateLimitCount int64 `json:"rate_limit_count"` + ErrorCount int64 `json:"error_count"` } // AccountAvailability represents current availability for a single account. @@ -85,11 +83,10 @@ type AccountAvailability struct { IsOverloaded bool `json:"is_overloaded"` HasError bool `json:"has_error"` - RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` - RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` - ScopeRateLimits map[string]int64 `json:"scope_rate_limits,omitempty"` - OverloadUntil *time.Time `json:"overload_until"` - OverloadRemainingSec *int64 `json:"overload_remaining_sec"` - ErrorMessage string `json:"error_message"` - TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + RateLimitResetAt *time.Time `json:"rate_limit_reset_at"` + RateLimitRemainingSec *int64 `json:"rate_limit_remaining_sec"` + OverloadUntil *time.Time `json:"overload_until"` + OverloadRemainingSec *int64 `json:"overload_remaining_sec"` + ErrorMessage string `json:"error_message"` + TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` } diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index fbc800f2..23a524ad 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" @@ -528,7 +529,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry func extractRetryModelAndStream(reqType opsRetryRequestType, errorLog *OpsErrorLogDetail, body []byte) (model string, stream bool, err error) { switch reqType { case opsRetryTypeMessages: - parsed, parseErr := ParseGatewayRequest(body) + parsed, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return "", false, fmt.Errorf("failed to parse messages request body: %w", parseErr) } @@ -596,7 +597,7 @@ func (s *OpsService) executeWithAccount(ctx context.Context, reqType opsRetryReq if s.gatewayService == nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "gateway service not available"} } - parsedReq, parseErr := ParseGatewayRequest(body) + parsedReq, parseErr := ParseGatewayRequest(body, domain.PlatformAnthropic) if parseErr != nil { return &opsRetryExecution{status: opsRetryStatusFailed, errorMessage: "failed to parse request body"} } diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47286deb..63732dee 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -62,6 +62,32 @@ func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvali s.tokenCacheInvalidator = invalidator } +// ErrorPolicyResult 表示错误策略检查的结果 +type ErrorPolicyResult int + +const ( + ErrorPolicyNone ErrorPolicyResult = iota // 未命中任何策略,继续默认逻辑 + ErrorPolicySkipped // 自定义错误码开启但未命中,跳过处理 + ErrorPolicyMatched // 自定义错误码命中,应停止调度 + ErrorPolicyTempUnscheduled // 临时不可调度规则命中 +) + +// CheckErrorPolicy 检查自定义错误码和临时不可调度规则。 +// 自定义错误码开启时覆盖后续所有逻辑(包括临时不可调度)。 +func (s *RateLimitService) CheckErrorPolicy(ctx context.Context, account *Account, statusCode int, responseBody []byte) ErrorPolicyResult { + if account.IsCustomErrorCodesEnabled() { + if account.ShouldHandleErrorCode(statusCode) { + return ErrorPolicyMatched + } + slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode) + return ErrorPolicySkipped + } + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return ErrorPolicyTempUnscheduled + } + return ErrorPolicyNone +} + // HandleUpstreamError 处理上游错误响应,标记账号状态 // 返回是否应该停止该账号的调度 func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) { diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go new file mode 100644 index 00000000..78ac5f57 --- /dev/null +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -0,0 +1,318 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// ============ shuffleWithinSortGroups 测试 ============ + +func TestShuffleWithinSortGroups_Empty(t *testing.T) { + shuffleWithinSortGroups(nil) + shuffleWithinSortGroups([]accountWithLoad{}) +} + +func TestShuffleWithinSortGroups_SingleElement(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + shuffleWithinSortGroups(accounts) + require.Equal(t, int64(1), accounts[0].account.ID) +} + +func TestShuffleWithinSortGroups_DifferentGroups_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // 每个元素都属于不同组(Priority 或 LoadRate 或 LastUsedAt 不同),顺序不变 + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + require.Equal(t, int64(1), cpy[0].account.ID) + require.Equal(t, int64(2), cpy[1].account.ID) + require.Equal(t, int64(3), cpy[2].account.ID) + } +} + +func TestShuffleWithinSortGroups_SameGroup_Shuffled(t *testing.T) { + now := time.Now() + // 同一秒的时间戳视为同一组 + sameSecond := time.Unix(now.Unix(), 0) + sameSecond2 := time.Unix(now.Unix(), 500_000_000) // 同一秒但不同纳秒 + + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &sameSecond2}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + // 多次执行,验证所有 ID 都出现在第一个位置(说明确实被打乱了) + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + // 无论怎么打乱,所有 ID 都应在候选中 + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.account.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + // 至少 2 个不同的 ID 出现在首位(随机性验证) + require.GreaterOrEqual(t, len(seen), 2, "shuffle should produce different orderings") +} + +func TestShuffleWithinSortGroups_NilLastUsedAt_SameGroup(t *testing.T) { + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + seen[cpy[0].account.ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "nil LastUsedAt accounts should be shuffled") +} + +func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + sameAsNow := time.Unix(now.Unix(), 0) + + // 组1: Priority=1, LoadRate=10, LastUsedAt=earlier (ID 1) — 单元素组 + // 组2: Priority=1, LoadRate=20, LastUsedAt=now (ID 2, 3) — 双元素组 + // 组3: Priority=2, LoadRate=10, LastUsedAt=earlier (ID 4) — 单元素组 + accounts := []accountWithLoad{ + {account: &Account{ID: 1, Priority: 1, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + {account: &Account{ID: 2, Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 3, Priority: 1, LastUsedAt: &sameAsNow}, loadInfo: &AccountLoadInfo{LoadRate: 20}}, + {account: &Account{ID: 4, Priority: 2, LastUsedAt: &earlier}, loadInfo: &AccountLoadInfo{LoadRate: 10}}, + } + + for i := 0; i < 20; i++ { + cpy := make([]accountWithLoad, len(accounts)) + copy(cpy, accounts) + shuffleWithinSortGroups(cpy) + + // 组间顺序不变 + require.Equal(t, int64(1), cpy[0].account.ID, "group 1 position fixed") + require.Equal(t, int64(4), cpy[3].account.ID, "group 3 position fixed") + + // 组2 内部可以打乱,但仍在位置 1 和 2 + mid := map[int64]bool{cpy[1].account.ID: true, cpy[2].account.ID: true} + require.True(t, mid[2] && mid[3], "group 2 elements should stay in positions 1-2") + } +} + +// ============ shuffleWithinPriorityAndLastUsed 测试 ============ + +func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { + shuffleWithinPriorityAndLastUsed(nil) + shuffleWithinPriorityAndLastUsed([]*Account{}) +} + +func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { + accounts := []*Account{{ID: 1, Priority: 1}} + shuffleWithinPriorityAndLastUsed(accounts) + require.Equal(t, int64(1), accounts[0].ID) +} + +func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 2, LastUsedAt: nil}, + {ID: 3, Priority: 3, LastUsedAt: nil}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t *testing.T) { + now := time.Now() + earlier := now.Add(-1 * time.Hour) + + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: &earlier}, + {ID: 3, Priority: 1, LastUsedAt: &now}, + } + + for i := 0; i < 20; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + shuffleWithinPriorityAndLastUsed(cpy) + require.Equal(t, int64(1), cpy[0].ID) + require.Equal(t, int64(2), cpy[1].ID) + require.Equal(t, int64(3), cpy[2].ID) + } +} + +// ============ sameLastUsedAt 测试 ============ + +func TestSameLastUsedAt(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + sameSecondDiffNano := time.Unix(now.Unix(), 999_999_999) + differentSecond := now.Add(1 * time.Second) + + t.Run("both nil", func(t *testing.T) { + require.True(t, sameLastUsedAt(nil, nil)) + }) + + t.Run("one nil one not", func(t *testing.T) { + require.False(t, sameLastUsedAt(nil, &now)) + require.False(t, sameLastUsedAt(&now, nil)) + }) + + t.Run("same second different nanoseconds", func(t *testing.T) { + require.True(t, sameLastUsedAt(&sameSecond, &sameSecondDiffNano)) + }) + + t.Run("different seconds", func(t *testing.T) { + require.False(t, sameLastUsedAt(&now, &differentSecond)) + }) + + t.Run("exact same time", func(t *testing.T) { + require.True(t, sameLastUsedAt(&now, &now)) + }) +} + +// ============ sameAccountWithLoadGroup 测试 ============ + +func TestSameAccountWithLoadGroup(t *testing.T) { + now := time.Now() + sameSecond := time.Unix(now.Unix(), 0) + + t.Run("same group", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &sameSecond}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 2, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different load rate", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 20}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("different last used at", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &now}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: &later}, loadInfo: &AccountLoadInfo{LoadRate: 10}} + require.False(t, sameAccountWithLoadGroup(a, b)) + }) + + t.Run("both nil LastUsedAt", func(t *testing.T) { + a := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + b := accountWithLoad{account: &Account{Priority: 1, LastUsedAt: nil}, loadInfo: &AccountLoadInfo{LoadRate: 0}} + require.True(t, sameAccountWithLoadGroup(a, b)) + }) +} + +// ============ sameAccountGroup 测试 ============ + +func TestSameAccountGroup(t *testing.T) { + now := time.Now() + + t.Run("same group", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 1, LastUsedAt: nil} + require.True(t, sameAccountGroup(a, b)) + }) + + t.Run("different priority", func(t *testing.T) { + a := &Account{Priority: 1, LastUsedAt: nil} + b := &Account{Priority: 2, LastUsedAt: nil} + require.False(t, sameAccountGroup(a, b)) + }) + + t.Run("different LastUsedAt", func(t *testing.T) { + later := now.Add(1 * time.Second) + a := &Account{Priority: 1, LastUsedAt: &now} + b := &Account{Priority: 1, LastUsedAt: &later} + require.False(t, sameAccountGroup(a, b)) + }) +} + +// ============ sortAccountsByPriorityAndLastUsed 集成随机化测试 ============ + +func TestSortAccountsByPriorityAndLastUsed_WithShuffle(t *testing.T) { + t.Run("same priority and nil LastUsedAt are shuffled", func(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: nil}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + + seen := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seen[cpy[0].ID] = true + } + require.GreaterOrEqual(t, len(seen), 2, "identical sort keys should produce different orderings after shuffle") + }) + + t.Run("different priorities still sorted correctly", func(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 3, Priority: 3, LastUsedAt: &now}, + {ID: 1, Priority: 1, LastUsedAt: &now}, + {ID: 2, Priority: 2, LastUsedAt: &now}, + } + + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(1), accounts[0].ID) + require.Equal(t, int64(2), accounts[1].ID) + require.Equal(t, int64(3), accounts[2].ID) + }) +} diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index c70f12fe..e7ef8982 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -23,8 +23,7 @@ import ( // - 临时不可调度且未过期:清理 // - 临时不可调度已过期:不清理 // - 正常可调度状态:不清理 -// - 模型限流超过阈值:清理 -// - 模型限流未超过阈值:不清理 +// - 模型限流(任意时长):清理 // // TestShouldClearStickySession tests the sticky session clearing logic. // Verifies correct behavior for various account states including: @@ -35,9 +34,9 @@ func TestShouldClearStickySession(t *testing.T) { future := now.Add(1 * time.Hour) past := now.Add(-1 * time.Hour) - // 短限流时间(低于阈值,不应清除粘性会话) + // 短限流时间(有限流即清除粘性会话) shortRateLimitReset := now.Add(5 * time.Second).Format(time.RFC3339) - // 长限流时间(超过阈值,应清除粘性会话) + // 长限流时间(有限流即清除粘性会话) longRateLimitReset := now.Add(30 * time.Second).Format(time.RFC3339) tests := []struct { @@ -53,7 +52,7 @@ func TestShouldClearStickySession(t *testing.T) { {name: "temp unschedulable", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &future}, requestedModel: "", want: true}, {name: "temp unschedulable expired", account: &Account{Status: StatusActive, Schedulable: true, TempUnschedulableUntil: &past}, requestedModel: "", want: false}, {name: "active schedulable", account: &Account{Status: StatusActive, Schedulable: true}, requestedModel: "", want: false}, - // 模型限流测试 + // 模型限流测试:有限流即清除 { name: "model rate limited short duration", account: &Account{ @@ -68,7 +67,7 @@ func TestShouldClearStickySession(t *testing.T) { }, }, requestedModel: "claude-sonnet-4", - want: false, // 低于阈值,不清除 + want: true, // 有限流即清除 }, { name: "model rate limited long duration", @@ -84,7 +83,7 @@ func TestShouldClearStickySession(t *testing.T) { }, }, requestedModel: "claude-sonnet-4", - want: true, // 超过阈值,清除 + want: true, // 有限流即清除 }, { name: "model rate limited different model", diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index baceaaad..310fac1e 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -295,4 +295,5 @@ var ProviderSet = wire.NewSet( NewUsageCache, NewTotpService, NewErrorPassthroughService, + NewDigestSessionStore, ) diff --git a/backend/migrations/052_add_group_sort_order.sql b/backend/migrations/052_add_group_sort_order.sql new file mode 100644 index 00000000..ee687608 --- /dev/null +++ b/backend/migrations/052_add_group_sort_order.sql @@ -0,0 +1,8 @@ +-- Add sort_order field to groups table for custom ordering +ALTER TABLE groups ADD COLUMN IF NOT EXISTS sort_order INT NOT NULL DEFAULT 0; + +-- Initialize existing groups with sort_order based on their ID +UPDATE groups SET sort_order = id WHERE sort_order = 0; + +-- Create index for efficient sorting +CREATE INDEX IF NOT EXISTS idx_groups_sort_order ON groups(sort_order); diff --git a/backend/migrations/052_migrate_upstream_to_apikey.sql b/backend/migrations/052_migrate_upstream_to_apikey.sql new file mode 100644 index 00000000..974f3f3c --- /dev/null +++ b/backend/migrations/052_migrate_upstream_to_apikey.sql @@ -0,0 +1,11 @@ +-- Migrate upstream accounts to apikey type +-- Background: upstream type is no longer needed. Antigravity platform APIKey accounts +-- with base_url pointing to an upstream sub2api instance can reuse the standard +-- APIKey forwarding path. GetBaseURL()/GetGeminiBaseURL() automatically appends +-- /antigravity for Antigravity platform APIKey accounts. + +UPDATE accounts +SET type = 'apikey' +WHERE type = 'upstream' + AND platform = 'antigravity' + AND deleted_at IS NULL; diff --git a/frontend/package.json b/frontend/package.json index 38b92708..325eba60 100644 --- a/frontend/package.json +++ b/frontend/package.json @@ -27,6 +27,7 @@ "qrcode": "^1.5.4", "vue": "^3.4.0", "vue-chartjs": "^5.3.0", + "vue-draggable-plus": "^0.6.1", "vue-i18n": "^9.14.5", "vue-router": "^4.2.5", "xlsx": "^0.18.5" diff --git a/frontend/pnpm-lock.yaml b/frontend/pnpm-lock.yaml index 7dc73325..9af2d7af 100644 --- a/frontend/pnpm-lock.yaml +++ b/frontend/pnpm-lock.yaml @@ -44,6 +44,9 @@ importers: vue-chartjs: specifier: ^5.3.0 version: 5.3.3(chart.js@4.5.1)(vue@3.5.26(typescript@5.6.3)) + vue-draggable-plus: + specifier: ^0.6.1 + version: 0.6.1(@types/sortablejs@1.15.9) vue-i18n: specifier: ^9.14.5 version: 9.14.5(vue@3.5.26(typescript@5.6.3)) @@ -1254,67 +1257,56 @@ packages: resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} cpu: [arm] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.54.0': resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} cpu: [arm] os: [linux] - libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.54.0': resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} cpu: [arm64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.54.0': resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} cpu: [arm64] os: [linux] - libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.54.0': resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} cpu: [loong64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-ppc64-gnu@4.54.0': resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} cpu: [ppc64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-gnu@4.54.0': resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} cpu: [riscv64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.54.0': resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} cpu: [riscv64] os: [linux] - libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.54.0': resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} cpu: [s390x] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.54.0': resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} cpu: [x64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-musl@4.54.0': resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} cpu: [x64] os: [linux] - libc: [musl] '@rollup/rollup-openharmony-arm64@4.54.0': resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} @@ -1515,6 +1507,9 @@ packages: '@types/react@19.2.7': resolution: {integrity: sha512-MWtvHrGZLFttgeEj28VXHxpmwYbor/ATPYbBfSFZEIRK0ecCFLl2Qo55z52Hss+UV9CRN7trSeq1zbgx7YDWWg==} + '@types/sortablejs@1.15.9': + resolution: {integrity: sha512-7HP+rZGE2p886PKV9c9OJzLBI6BBJu1O7lJGYnPyG3fS4/duUCcngkNCjsLwIMV+WMqANe3tt4irrXHSIe68OQ==} + '@types/trusted-types@2.0.7': resolution: {integrity: sha512-ScaPdn1dQczgbl0QFTeTOmVHFULt394XJgOQNoyVhZ6r2vLnMLJfBPd53SB52T/3G36VI1/g2MZaX0cwDuXsfw==} @@ -4298,6 +4293,15 @@ packages: '@vue/composition-api': optional: true + vue-draggable-plus@0.6.1: + resolution: {integrity: sha512-FbtQ/fuoixiOfTZzG3yoPl4JAo9HJXRHmBQZFB9x2NYCh6pq0TomHf7g5MUmpaDYv+LU2n6BPq2YN9sBO+FbIg==} + peerDependencies: + '@types/sortablejs': ^1.15.0 + '@vue/composition-api': '*' + peerDependenciesMeta: + '@vue/composition-api': + optional: true + vue-eslint-parser@9.4.3: resolution: {integrity: sha512-2rYRLWlIpaiN8xbPiDyXZXRgLGOtWxERV7ND5fFAv5qo1D2N9Fu9MNajBNc6o13lZ+24DAWCkQCvj4klgmcITg==} engines: {node: ^14.17.0 || >=16.0.0} @@ -5958,6 +5962,8 @@ snapshots: dependencies: csstype: 3.2.3 + '@types/sortablejs@1.15.9': {} + '@types/trusted-types@2.0.7': {} '@types/unist@2.0.11': {} @@ -9401,6 +9407,10 @@ snapshots: dependencies: vue: 3.5.26(typescript@5.6.3) + vue-draggable-plus@0.6.1(@types/sortablejs@1.15.9): + dependencies: + '@types/sortablejs': 1.15.9 + vue-eslint-parser@9.4.3(eslint@8.57.1): dependencies: debug: 4.4.3 diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 6df93498..4cb1a6f2 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -327,11 +327,34 @@ export async function getAvailableModels(id: number): Promise { return data } +export interface CRSPreviewAccount { + crs_account_id: string + kind: string + name: string + platform: string + type: string +} + +export interface PreviewFromCRSResult { + new_accounts: CRSPreviewAccount[] + existing_accounts: CRSPreviewAccount[] +} + +export async function previewFromCrs(params: { + base_url: string + username: string + password: string +}): Promise { + const { data } = await apiClient.post('/admin/accounts/sync/crs/preview', params) + return data +} + export async function syncFromCrs(params: { base_url: string username: string password: string sync_proxies?: boolean + selected_account_ids?: string[] }): Promise<{ created: number updated: number @@ -345,7 +368,19 @@ export async function syncFromCrs(params: { error?: string }> }> { - const { data } = await apiClient.post('/admin/accounts/sync/crs', params) + const { data } = await apiClient.post<{ + created: number + updated: number + skipped: number + failed: number + items: Array<{ + crs_account_id: string + kind: string + name: string + action: string + error?: string + }> + }>('/admin/accounts/sync/crs', params) return data } @@ -398,6 +433,26 @@ export async function getAntigravityDefaultModelMapping(): Promise> { + const payload: { refresh_token: string; proxy_id?: number } = { + refresh_token: refreshToken + } + if (proxyId) { + payload.proxy_id = proxyId + } + const { data } = await apiClient.post>('/admin/openai/refresh-token', payload) + return data +} + export const accountsAPI = { list, getById, @@ -418,9 +473,11 @@ export const accountsAPI = { getAvailableModels, generateAuthUrl, exchangeCode, + refreshOpenAIToken, batchCreate, batchUpdateCredentials, bulkUpdate, + previewFromCrs, syncFromCrs, exportData, importData, diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 4d2b10ef..3d18ba87 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -153,6 +153,20 @@ export async function getGroupApiKeys( return data } +/** + * Update group sort orders + * @param updates - Array of { id, sort_order } objects + * @returns Success confirmation + */ +export async function updateSortOrder( + updates: Array<{ id: number; sort_order: number }> +): Promise<{ message: string }> { + const { data } = await apiClient.put<{ message: string }>('/admin/groups/sort-order', { + updates + }) + return data +} + export const groupsAPI = { list, getAll, @@ -163,7 +177,8 @@ export const groupsAPI = { delete: deleteGroup, toggleStatus, getStats, - getGroupApiKeys + getGroupApiKeys, + updateSortOrder } export default groupsAPI diff --git a/frontend/src/api/admin/ops.ts b/frontend/src/api/admin/ops.ts index 5b96feda..9f980a12 100644 --- a/frontend/src/api/admin/ops.ts +++ b/frontend/src/api/admin/ops.ts @@ -376,7 +376,6 @@ export interface PlatformAvailability { total_accounts: number available_count: number rate_limit_count: number - scope_rate_limit_count?: Record error_count: number } @@ -387,7 +386,6 @@ export interface GroupAvailability { total_accounts: number available_count: number rate_limit_count: number - scope_rate_limit_count?: Record error_count: number } @@ -402,7 +400,6 @@ export interface AccountAvailability { is_rate_limited: boolean rate_limit_reset_at?: string rate_limit_remaining_sec?: number - scope_rate_limits?: Record is_overloaded: boolean overload_until?: string overload_remaining_sec?: number diff --git a/frontend/src/components/account/AccountStatusIndicator.vue b/frontend/src/components/account/AccountStatusIndicator.vue index 3474da44..5fe96a1d 100644 --- a/frontend/src/components/account/AccountStatusIndicator.vue +++ b/frontend/src/components/account/AccountStatusIndicator.vue @@ -76,26 +76,6 @@ - - - + + diff --git a/frontend/src/composables/useAccountOAuth.ts b/frontend/src/composables/useAccountOAuth.ts index bdc6f0f1..ca200cb3 100644 --- a/frontend/src/composables/useAccountOAuth.ts +++ b/frontend/src/composables/useAccountOAuth.ts @@ -3,7 +3,7 @@ import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' export type AddMethod = 'oauth' | 'setup-token' -export type AuthInputMethod = 'manual' | 'cookie' +export type AuthInputMethod = 'manual' | 'cookie' | 'refresh_token' export interface OAuthState { authUrl: string diff --git a/frontend/src/composables/useOpenAIOAuth.ts b/frontend/src/composables/useOpenAIOAuth.ts index 4b5ffe31..82a77031 100644 --- a/frontend/src/composables/useOpenAIOAuth.ts +++ b/frontend/src/composables/useOpenAIOAuth.ts @@ -105,6 +105,32 @@ export function useOpenAIOAuth() { } } + // Validate refresh token and get full token info + const validateRefreshToken = async ( + refreshToken: string, + proxyId?: number | null + ): Promise => { + if (!refreshToken.trim()) { + error.value = 'Missing refresh token' + return null + } + + loading.value = true + error.value = '' + + try { + // Use dedicated refresh-token endpoint + const tokenInfo = await adminAPI.accounts.refreshOpenAIToken(refreshToken.trim(), proxyId) + return tokenInfo as OpenAITokenInfo + } catch (err: any) { + error.value = err.response?.data?.detail || 'Failed to validate refresh token' + appStore.showError(error.value) + return null + } finally { + loading.value = false + } + } + // Build credentials for OpenAI OAuth account const buildCredentials = (tokenInfo: OpenAITokenInfo): Record => { const creds: Record = { @@ -152,6 +178,7 @@ export function useOpenAIOAuth() { resetState, generateAuthUrl, exchangeAuthCode, + validateRefreshToken, buildCredentials, buildExtraInfo } diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 93e467f2..a04cbdf1 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1042,6 +1042,10 @@ export default { createGroup: 'Create Group', editGroup: 'Edit Group', deleteGroup: 'Delete Group', + sortOrder: 'Sort', + sortOrderHint: 'Drag groups to adjust display order, groups at the top will be displayed first', + sortOrderUpdated: 'Sort order updated', + failedToUpdateSortOrder: 'Failed to update sort order', allPlatforms: 'All Platforms', allStatus: 'All Status', allGroups: 'All Groups', @@ -1314,10 +1318,23 @@ export default { syncResult: 'Sync Result', syncResultSummary: 'Created {created}, updated {updated}, skipped {skipped}, failed {failed}', syncErrors: 'Errors / Skipped Details', - syncCompleted: 'Sync completed: created {created}, updated {updated}', + syncCompleted: 'Sync completed: created {created}, updated {updated}, skipped {skipped}', syncCompletedWithErrors: - 'Sync completed with errors: failed {failed} (created {created}, updated {updated})', + 'Sync completed with errors: failed {failed} (created {created}, updated {updated}, skipped {skipped})', syncFailed: 'Sync failed', + crsPreview: 'Preview', + crsPreviewing: 'Previewing...', + crsPreviewFailed: 'Preview failed', + crsExistingAccounts: 'Existing accounts (will be updated)', + crsNewAccounts: 'New accounts (select to sync)', + crsSelectAll: 'Select all', + crsSelectNone: 'Select none', + crsNoNewAccounts: 'All CRS accounts are already synced.', + crsWillUpdate: 'Will update {count} existing accounts.', + crsSelectedCount: '{count} new accounts selected', + crsUpdateBehaviorNote: + 'Existing accounts only sync fields returned by CRS; missing fields keep their current values. Credentials are merged by key — keys not returned by CRS are preserved. Proxies are kept when "Sync proxies" is unchecked.', + crsBack: 'Back', editAccount: 'Edit Account', deleteAccount: 'Delete Account', searchAccounts: 'Search accounts...', @@ -1366,7 +1383,6 @@ export default { overloaded: 'Overloaded', tempUnschedulable: 'Temp Unschedulable', rateLimitedUntil: 'Rate limited until {time}', - scopeRateLimitedUntil: '{scope} rate limited until {time}', modelRateLimitedUntil: '{model} rate limited until {time}', overloadedUntil: 'Overloaded until {time}', viewTempUnschedDetails: 'View temp unschedulable details' @@ -1679,6 +1695,9 @@ export default { cookieAuthFailed: 'Cookie authorization failed', keyAuthFailed: 'Key {index}: {error}', successCreated: 'Successfully created {count} account(s)', + batchSuccess: 'Successfully created {count} account(s)', + batchPartialSuccess: 'Partial success: {success} succeeded, {failed} failed', + batchFailed: 'Batch creation failed', // OpenAI specific openai: { title: 'OpenAI Account Authorization', @@ -1697,7 +1716,14 @@ export default { authCodePlaceholder: 'Option 1: Copy the complete URL\n(http://localhost:xxx/auth/callback?code=...)\nOption 2: Copy only the code parameter value', authCodeHint: - 'You can copy the entire URL or just the code parameter value, the system will auto-detect' + 'You can copy the entire URL or just the code parameter value, the system will auto-detect', + // Refresh Token auth + refreshTokenAuth: 'Manual RT Input', + refreshTokenDesc: 'Enter your existing OpenAI Refresh Token(s). Supports batch input (one per line). The system will automatically validate and create accounts.', + refreshTokenPlaceholder: 'Paste your OpenAI Refresh Token...\nSupports multiple, one per line', + validating: 'Validating...', + validateAndCreate: 'Validate & Create Account', + pleaseEnterRefreshToken: 'Please enter Refresh Token' }, // Gemini specific gemini: { @@ -3066,7 +3092,6 @@ export default { empty: 'No data', queued: 'Queue {count}', rateLimited: 'Rate-limited {count}', - scopeRateLimitedTooltip: '{scope} rate-limited ({count} accounts)', errorAccounts: 'Errors {count}', loadFailed: 'Failed to load concurrency data' }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index c280ed1c..22d90ee2 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1099,6 +1099,10 @@ export default { createGroup: '创建分组', editGroup: '编辑分组', deleteGroup: '删除分组', + sortOrder: '排序', + sortOrderHint: '拖拽分组调整显示顺序,排在前面的分组会优先显示', + sortOrderUpdated: '排序已更新', + failedToUpdateSortOrder: '更新排序失败', deleteConfirm: "确定要删除分组 '{name}' 吗?所有关联的 API 密钥将不再属于任何分组。", deleteConfirmSubscription: "确定要删除订阅分组 '{name}' 吗?此操作会让所有绑定此订阅的用户的 API Key 失效,并删除所有相关的订阅记录。此操作无法撤销。", @@ -1402,9 +1406,22 @@ export default { syncResult: '同步结果', syncResultSummary: '创建 {created},更新 {updated},跳过 {skipped},失败 {failed}', syncErrors: '错误/跳过详情', - syncCompleted: '同步完成:创建 {created},更新 {updated}', - syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated})', + syncCompleted: '同步完成:创建 {created},更新 {updated},跳过 {skipped}', + syncCompletedWithErrors: '同步完成但有错误:失败 {failed}(创建 {created},更新 {updated},跳过 {skipped})', syncFailed: '同步失败', + crsPreview: '预览', + crsPreviewing: '预览中...', + crsPreviewFailed: '预览失败', + crsExistingAccounts: '将自动更新的已有账号', + crsNewAccounts: '新账号(可选择)', + crsSelectAll: '全选', + crsSelectNone: '全不选', + crsNoNewAccounts: '所有 CRS 账号均已同步。', + crsWillUpdate: '将更新 {count} 个已有账号。', + crsSelectedCount: '已选择 {count} 个新账号', + crsUpdateBehaviorNote: + '已有账号仅同步 CRS 返回的字段,缺失字段保持原值;凭据按键合并,不会清空未下发的键;未勾选"同步代理"时保留原有代理。', + crsBack: '返回', editAccount: '编辑账号', deleteAccount: '删除账号', deleteConfirmMessage: "确定要删除账号 '{name}' 吗?", @@ -1502,7 +1519,6 @@ export default { overloaded: '过载中', tempUnschedulable: '临时不可调度', rateLimitedUntil: '限流中,重置时间:{time}', - scopeRateLimitedUntil: '{scope} 限流中,重置时间:{time}', modelRateLimitedUntil: '{model} 限流至 {time}', overloadedUntil: '负载过重,重置时间:{time}', viewTempUnschedDetails: '查看临时不可调度详情' @@ -1821,6 +1837,9 @@ export default { cookieAuthFailed: 'Cookie 授权失败', keyAuthFailed: '密钥 {index}: {error}', successCreated: '成功创建 {count} 个账号', + batchSuccess: '成功创建 {count} 个账号', + batchPartialSuccess: '部分成功:{success} 个成功,{failed} 个失败', + batchFailed: '批量创建失败', // OpenAI specific openai: { title: 'OpenAI 账户授权', @@ -1837,7 +1856,14 @@ export default { authCode: '授权链接或 Code', authCodePlaceholder: '方式1:复制完整的链接\n(http://localhost:xxx/auth/callback?code=...)\n方式2:仅复制 code 参数的值', - authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别' + authCodeHint: '您可以直接复制整个链接或仅复制 code 参数值,系统会自动识别', + // Refresh Token auth + refreshTokenAuth: '手动输入 RT', + refreshTokenDesc: '输入您已有的 OpenAI Refresh Token,支持批量输入(每行一个),系统将自动验证并创建账号。', + refreshTokenPlaceholder: '粘贴您的 OpenAI Refresh Token...\n支持多个,每行一个', + validating: '验证中...', + validateAndCreate: '验证并创建账号', + pleaseEnterRefreshToken: '请输入 Refresh Token' }, // Gemini specific gemini: { @@ -3239,7 +3265,6 @@ export default { empty: '暂无数据', queued: '队列 {count}', rateLimited: '限流 {count}', - scopeRateLimitedTooltip: '{scope} 限流中 ({count} 个账号)', errorAccounts: '异常 {count}', loadFailed: '加载并发数据失败' }, diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 1472dd2c..e5f71520 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -43,6 +43,8 @@ export interface AdminUser extends User { notes: string // 用户专属分组倍率配置 (group_id -> rate_multiplier) group_rates?: Record + // 当前并发数(仅管理员列表接口返回) + current_concurrency?: number } export interface LoginRequest { @@ -382,6 +384,9 @@ export interface AdminGroup extends Group { // 分组下账号数量(仅管理员可见) account_count?: number + + // 分组排序 + sort_order: number } export interface ApiKey { @@ -602,9 +607,6 @@ export interface Account { temp_unschedulable_until: string | null temp_unschedulable_reason: string | null - // Antigravity scope 级限流状态 - scope_rate_limits?: Record - // Session window fields (5-hour window) session_window_start: string | null session_window_end: string | null diff --git a/frontend/src/views/admin/AnnouncementsView.vue b/frontend/src/views/admin/AnnouncementsView.vue index 38574454..08d7b871 100644 --- a/frontend/src/views/admin/AnnouncementsView.vue +++ b/frontend/src/views/admin/AnnouncementsView.vue @@ -1,26 +1,10 @@ @@ -1592,6 +1686,7 @@ import EmptyState from '@/components/common/EmptyState.vue' import Select from '@/components/common/Select.vue' import PlatformIcon from '@/components/common/PlatformIcon.vue' import Icon from '@/components/icons/Icon.vue' +import { VueDraggable } from 'vue-draggable-plus' const { t } = useI18n() const appStore = useAppStore() @@ -1758,9 +1853,12 @@ let abortController: AbortController | null = null const showCreateModal = ref(false) const showEditModal = ref(false) const showDeleteDialog = ref(false) +const showSortModal = ref(false) const submitting = ref(false) +const sortSubmitting = ref(false) const editingGroup = ref(null) const deletingGroup = ref(null) +const sortableGroups = ref([]) const createForm = reactive({ name: '', @@ -2237,6 +2335,46 @@ const handleClickOutside = (event: MouseEvent) => { } } +// 打开排序弹窗 +const openSortModal = async () => { + try { + // 获取所有分组(不分页) + const allGroups = await adminAPI.groups.getAll() + // 按 sort_order 排序 + sortableGroups.value = [...allGroups].sort((a, b) => a.sort_order - b.sort_order) + showSortModal.value = true + } catch (error) { + appStore.showError(t('admin.groups.failedToLoad')) + console.error('Error loading groups for sorting:', error) + } +} + +// 关闭排序弹窗 +const closeSortModal = () => { + showSortModal.value = false + sortableGroups.value = [] +} + +// 保存排序 +const saveSortOrder = async () => { + sortSubmitting.value = true + try { + const updates = sortableGroups.value.map((g, index) => ({ + id: g.id, + sort_order: index * 10 + })) + await adminAPI.groups.updateSortOrder(updates) + appStore.showSuccess(t('admin.groups.sortOrderUpdated')) + closeSortModal() + loadGroups() + } catch (error: any) { + appStore.showError(error.response?.data?.detail || t('admin.groups.failedToUpdateSortOrder')) + console.error('Error updating sort order:', error) + } finally { + sortSubmitting.value = false + } +} + onMounted(() => { loadGroups() document.addEventListener('click', handleClickOutside) diff --git a/frontend/src/views/admin/PromoCodesView.vue b/frontend/src/views/admin/PromoCodesView.vue index 968728b2..73499f80 100644 --- a/frontend/src/views/admin/PromoCodesView.vue +++ b/frontend/src/views/admin/PromoCodesView.vue @@ -1,26 +1,10 @@