diff --git a/.gitignore b/.gitignore
index fe715240..48172982 100644
--- a/.gitignore
+++ b/.gitignore
@@ -83,6 +83,8 @@ temp/
*.log
*.bak
.cache/
+.dev/
+.serena/
# ===================
# 构建产物
@@ -127,3 +129,4 @@ deploy/docker-compose.override.yml
.gocache/
vite.config.js
docs/*
+.serena/
\ No newline at end of file
diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md
new file mode 100644
index 00000000..b240f45c
--- /dev/null
+++ b/PR_DESCRIPTION.md
@@ -0,0 +1,164 @@
+## 概述
+
+全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。
+
+## 主要改动
+
+### 1. 错误日志查询优化
+
+**功能特性:**
+- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情
+- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等)
+- 改进查询参数处理,简化代码结构
+- 增强错误分类和标准化处理
+- 支持错误解决状态追踪(resolved 字段)
+
+**技术实现:**
+- `ops_handler.go` - 新增单条错误日志查询接口
+- `ops_repo.go` - 优化数据查询和过滤条件构建
+- `ops_models.go` - 扩展错误日志数据模型
+- 前端 API 接口同步更新
+
+### 2. 告警静默功能
+
+**功能特性:**
+- 支持按规则、平台、分组、区域等维度静默告警
+- 可设置静默时长和原因说明
+- 静默记录可追溯,记录创建人和创建时间
+- 自动过期机制,避免永久静默
+
+**技术实现:**
+- `037_ops_alert_silences.sql` - 新增告警静默表
+- `ops_alerts.go` - 告警静默逻辑实现
+- `ops_alerts_handler.go` - 告警静默 API 接口
+- `OpsAlertEventsCard.vue` - 前端告警静默操作界面
+
+**数据库结构:**
+
+| 字段 | 类型 | 说明 |
+|------|------|------|
+| rule_id | BIGINT | 告警规则 ID |
+| platform | VARCHAR(64) | 平台标识 |
+| group_id | BIGINT | 分组 ID(可选) |
+| region | VARCHAR(64) | 区域(可选) |
+| until | TIMESTAMPTZ | 静默截止时间 |
+| reason | TEXT | 静默原因 |
+| created_by | BIGINT | 创建人 ID |
+
+### 3. 错误分类标准化
+
+**功能特性:**
+- 统一错误阶段分类(request|auth|routing|upstream|network|internal)
+- 规范错误归属分类(client|provider|platform)
+- 标准化错误来源分类(client_request|upstream_http|gateway)
+- 自动迁移历史数据到新分类体系
+
+**技术实现:**
+- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移
+- 自动映射历史遗留分类到新标准
+- 自动解决已恢复的上游错误(客户端状态码 < 400)
+
+### 4. Gateway 服务集成
+
+**功能特性:**
+- 完善各 Gateway 服务的 Ops 集成
+- 统一错误日志记录接口
+- 增强上游错误追踪能力
+
+**涉及服务:**
+- `antigravity_gateway_service.go` - Antigravity 网关集成
+- `gateway_service.go` - 通用网关集成
+- `gemini_messages_compat_service.go` - Gemini 兼容层集成
+- `openai_gateway_service.go` - OpenAI 网关集成
+
+### 5. 前端 UI 优化
+
+**代码重构:**
+- 大幅简化错误详情模态框代码(从 828 行优化到 450 行)
+- 优化错误日志表格组件,提升可读性
+- 清理未使用的 i18n 翻译,减少冗余
+- 统一组件代码风格和格式
+- 优化骨架屏组件,更好匹配实际看板布局
+
+**布局改进:**
+- 修复模态框内容溢出和滚动问题
+- 优化表格布局,使用 flex 布局确保正确显示
+- 改进看板头部布局和交互
+- 提升响应式体验
+- 骨架屏支持全屏模式适配
+
+**交互优化:**
+- 优化告警事件卡片功能和展示
+- 改进错误详情展示逻辑
+- 增强请求详情模态框
+- 完善运行时设置卡片
+- 改进加载动画效果
+
+### 6. 国际化完善
+
+**文案补充:**
+- 补充错误日志相关的英文翻译
+- 添加告警静默功能的中英文文案
+- 完善提示文本和错误信息
+- 统一术语翻译标准
+
+## 文件变更
+
+**后端(26 个文件):**
+- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强
+- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化
+- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强
+- `backend/internal/repository/ops_repo.go` - 数据访问层重构
+- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强
+- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件)
+- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件)
+- `backend/internal/server/routes/admin.go` - 路由配置更新
+- `backend/migrations/*.sql` - 数据库迁移(2 个文件)
+- 测试文件更新(5 个文件)
+
+**前端(13 个文件):**
+- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化
+- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件)
+- `frontend/src/api/admin/ops.ts` - API 接口扩展
+- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件)
+
+## 代码统计
+
+- 44 个文件修改
+- 3733 行新增
+- 995 行删除
+- 净增加 2738 行
+
+## 核心改进
+
+**可维护性提升:**
+- 重构核心服务层,职责更清晰
+- 简化前端组件代码,降低复杂度
+- 统一代码风格和命名规范
+- 清理冗余代码和未使用的翻译
+- 标准化错误分类体系
+
+**功能完善:**
+- 告警静默功能,减少告警噪音
+- 错误日志查询优化,提升运维效率
+- Gateway 服务集成完善,统一监控能力
+- 错误解决状态追踪,便于问题管理
+
+**用户体验优化:**
+- 修复多个 UI 布局问题
+- 优化交互流程
+- 完善国际化支持
+- 提升响应式体验
+- 改进加载状态展示
+
+## 测试验证
+
+- ✅ 错误日志查询和过滤功能
+- ✅ 告警静默创建和自动过期
+- ✅ 错误分类标准化迁移
+- ✅ Gateway 服务错误日志记录
+- ✅ 前端组件布局和交互
+- ✅ 骨架屏全屏模式适配
+- ✅ 国际化文本完整性
+- ✅ API 接口功能正确性
+- ✅ 数据库迁移执行成功
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 6b5ffad4..31e47332 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -67,7 +67,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userHandler := handler.NewUserHandler(userService)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageLogRepository := repository.NewUsageLogRepository(client, db)
- dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
usageService := service.NewUsageService(usageLogRepository, userRepository, client, apiKeyAuthCacheInvalidator)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
redeemCodeRepository := repository.NewRedeemCodeRepository(client)
@@ -76,15 +75,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
redeemHandler := handler.NewRedeemHandler(redeemService)
subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService)
+ dashboardAggregationRepository := repository.NewDashboardAggregationRepository(db)
dashboardStatsCache := repository.NewDashboardCache(redisClient, configConfig)
- timingWheelService := service.ProvideTimingWheelService()
- dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardService := service.NewDashboardService(usageLogRepository, dashboardAggregationRepository, dashboardStatsCache, configConfig)
+ timingWheelService, err := service.ProvideTimingWheelService()
+ if err != nil {
+ return nil, err
+ }
+ dashboardAggregationService := service.ProvideDashboardAggregationService(dashboardAggregationRepository, timingWheelService, configConfig)
dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService)
accountRepository := repository.NewAccountRepository(client, db)
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, apiKeyAuthCacheInvalidator)
+ proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
adminUserHandler := admin.NewUserHandler(adminService)
groupHandler := admin.NewGroupHandler(adminService)
claudeOAuthClient := repository.NewClaudeOAuthClient()
@@ -98,12 +102,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
- rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
+ geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
+ compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
+ rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
- geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)
antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
@@ -112,11 +117,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
- schedulerCache := repository.NewSchedulerCache(redisClient)
- schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
- schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
- accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService)
+ sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
+ accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
oAuthHandler := admin.NewOAuthHandler(oAuthService)
openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -125,6 +128,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminRedeemHandler := admin.NewRedeemHandler(adminService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
+ schedulerCache := repository.NewSchedulerCache(redisClient)
+ schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db)
+ schedulerSnapshotService := service.ProvideSchedulerSnapshotService(schedulerCache, schedulerOutboxRepository, accountRepository, groupRepository, configConfig)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
pricingService, err := service.ProvidePricingService(configConfig, pricingRemoteClient)
if err != nil {
@@ -134,8 +140,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
identityCache := repository.NewIdentityCache(redisClient)
identityService := service.NewIdentityService(identityCache)
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService)
- openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService)
+ claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache)
+ openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
+ openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService)
settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService)
@@ -166,7 +174,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig)
opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig)
opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
- tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, configConfig)
+ tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
application := &Application{
diff --git a/backend/ent/account.go b/backend/ent/account.go
index e960d324..038aa7e5 100644
--- a/backend/ent/account.go
+++ b/backend/ent/account.go
@@ -43,6 +43,8 @@ type Account struct {
Concurrency int `json:"concurrency,omitempty"`
// Priority holds the value of the "priority" field.
Priority int `json:"priority,omitempty"`
+ // RateMultiplier holds the value of the "rate_multiplier" field.
+ RateMultiplier float64 `json:"rate_multiplier,omitempty"`
// Status holds the value of the "status" field.
Status string `json:"status,omitempty"`
// ErrorMessage holds the value of the "error_message" field.
@@ -135,6 +137,8 @@ func (*Account) scanValues(columns []string) ([]any, error) {
values[i] = new([]byte)
case account.FieldAutoPauseOnExpired, account.FieldSchedulable:
values[i] = new(sql.NullBool)
+ case account.FieldRateMultiplier:
+ values[i] = new(sql.NullFloat64)
case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority:
values[i] = new(sql.NullInt64)
case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus:
@@ -241,6 +245,12 @@ func (_m *Account) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.Priority = int(value.Int64)
}
+ case account.FieldRateMultiplier:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field rate_multiplier", values[i])
+ } else if value.Valid {
+ _m.RateMultiplier = value.Float64
+ }
case account.FieldStatus:
if value, ok := values[i].(*sql.NullString); !ok {
return fmt.Errorf("unexpected type %T for field status", values[i])
@@ -420,6 +430,9 @@ func (_m *Account) String() string {
builder.WriteString("priority=")
builder.WriteString(fmt.Sprintf("%v", _m.Priority))
builder.WriteString(", ")
+ builder.WriteString("rate_multiplier=")
+ builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
+ builder.WriteString(", ")
builder.WriteString("status=")
builder.WriteString(_m.Status)
builder.WriteString(", ")
diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go
index 402e16ee..73c0e8c2 100644
--- a/backend/ent/account/account.go
+++ b/backend/ent/account/account.go
@@ -39,6 +39,8 @@ const (
FieldConcurrency = "concurrency"
// FieldPriority holds the string denoting the priority field in the database.
FieldPriority = "priority"
+ // FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
+ FieldRateMultiplier = "rate_multiplier"
// FieldStatus holds the string denoting the status field in the database.
FieldStatus = "status"
// FieldErrorMessage holds the string denoting the error_message field in the database.
@@ -116,6 +118,7 @@ var Columns = []string{
FieldProxyID,
FieldConcurrency,
FieldPriority,
+ FieldRateMultiplier,
FieldStatus,
FieldErrorMessage,
FieldLastUsedAt,
@@ -174,6 +177,8 @@ var (
DefaultConcurrency int
// DefaultPriority holds the default value on creation for the "priority" field.
DefaultPriority int
+ // DefaultRateMultiplier holds the default value on creation for the "rate_multiplier" field.
+ DefaultRateMultiplier float64
// DefaultStatus holds the default value on creation for the "status" field.
DefaultStatus string
// StatusValidator is a validator for the "status" field. It is called by the builders before save.
@@ -244,6 +249,11 @@ func ByPriority(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldPriority, opts...).ToFunc()
}
+// ByRateMultiplier orders the results by the rate_multiplier field.
+func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
+}
+
// ByStatus orders the results by the status field.
func ByStatus(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldStatus, opts...).ToFunc()
diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go
index 6c639fd1..dea1127a 100644
--- a/backend/ent/account/where.go
+++ b/backend/ent/account/where.go
@@ -105,6 +105,11 @@ func Priority(v int) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldPriority, v))
}
+// RateMultiplier applies equality check predicate on the "rate_multiplier" field. It's identical to RateMultiplierEQ.
+func RateMultiplier(v float64) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
+}
+
// Status applies equality check predicate on the "status" field. It's identical to StatusEQ.
func Status(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))
@@ -675,6 +680,46 @@ func PriorityLTE(v int) predicate.Account {
return predicate.Account(sql.FieldLTE(FieldPriority, v))
}
+// RateMultiplierEQ applies the EQ predicate on the "rate_multiplier" field.
+func RateMultiplierEQ(v float64) predicate.Account {
+ return predicate.Account(sql.FieldEQ(FieldRateMultiplier, v))
+}
+
+// RateMultiplierNEQ applies the NEQ predicate on the "rate_multiplier" field.
+func RateMultiplierNEQ(v float64) predicate.Account {
+ return predicate.Account(sql.FieldNEQ(FieldRateMultiplier, v))
+}
+
+// RateMultiplierIn applies the In predicate on the "rate_multiplier" field.
+func RateMultiplierIn(vs ...float64) predicate.Account {
+ return predicate.Account(sql.FieldIn(FieldRateMultiplier, vs...))
+}
+
+// RateMultiplierNotIn applies the NotIn predicate on the "rate_multiplier" field.
+func RateMultiplierNotIn(vs ...float64) predicate.Account {
+ return predicate.Account(sql.FieldNotIn(FieldRateMultiplier, vs...))
+}
+
+// RateMultiplierGT applies the GT predicate on the "rate_multiplier" field.
+func RateMultiplierGT(v float64) predicate.Account {
+ return predicate.Account(sql.FieldGT(FieldRateMultiplier, v))
+}
+
+// RateMultiplierGTE applies the GTE predicate on the "rate_multiplier" field.
+func RateMultiplierGTE(v float64) predicate.Account {
+ return predicate.Account(sql.FieldGTE(FieldRateMultiplier, v))
+}
+
+// RateMultiplierLT applies the LT predicate on the "rate_multiplier" field.
+func RateMultiplierLT(v float64) predicate.Account {
+ return predicate.Account(sql.FieldLT(FieldRateMultiplier, v))
+}
+
+// RateMultiplierLTE applies the LTE predicate on the "rate_multiplier" field.
+func RateMultiplierLTE(v float64) predicate.Account {
+ return predicate.Account(sql.FieldLTE(FieldRateMultiplier, v))
+}
+
// StatusEQ applies the EQ predicate on the "status" field.
func StatusEQ(v string) predicate.Account {
return predicate.Account(sql.FieldEQ(FieldStatus, v))
diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go
index 0725d43d..42a561cf 100644
--- a/backend/ent/account_create.go
+++ b/backend/ent/account_create.go
@@ -153,6 +153,20 @@ func (_c *AccountCreate) SetNillablePriority(v *int) *AccountCreate {
return _c
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (_c *AccountCreate) SetRateMultiplier(v float64) *AccountCreate {
+ _c.mutation.SetRateMultiplier(v)
+ return _c
+}
+
+// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
+func (_c *AccountCreate) SetNillableRateMultiplier(v *float64) *AccountCreate {
+ if v != nil {
+ _c.SetRateMultiplier(*v)
+ }
+ return _c
+}
+
// SetStatus sets the "status" field.
func (_c *AccountCreate) SetStatus(v string) *AccountCreate {
_c.mutation.SetStatus(v)
@@ -429,6 +443,10 @@ func (_c *AccountCreate) defaults() error {
v := account.DefaultPriority
_c.mutation.SetPriority(v)
}
+ if _, ok := _c.mutation.RateMultiplier(); !ok {
+ v := account.DefaultRateMultiplier
+ _c.mutation.SetRateMultiplier(v)
+ }
if _, ok := _c.mutation.Status(); !ok {
v := account.DefaultStatus
_c.mutation.SetStatus(v)
@@ -488,6 +506,9 @@ func (_c *AccountCreate) check() error {
if _, ok := _c.mutation.Priority(); !ok {
return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "Account.priority"`)}
}
+ if _, ok := _c.mutation.RateMultiplier(); !ok {
+ return &ValidationError{Name: "rate_multiplier", err: errors.New(`ent: missing required field "Account.rate_multiplier"`)}
+ }
if _, ok := _c.mutation.Status(); !ok {
return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "Account.status"`)}
}
@@ -578,6 +599,10 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) {
_spec.SetField(account.FieldPriority, field.TypeInt, value)
_node.Priority = value
}
+ if value, ok := _c.mutation.RateMultiplier(); ok {
+ _spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
+ _node.RateMultiplier = value
+ }
if value, ok := _c.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
_node.Status = value
@@ -893,6 +918,24 @@ func (u *AccountUpsert) AddPriority(v int) *AccountUpsert {
return u
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (u *AccountUpsert) SetRateMultiplier(v float64) *AccountUpsert {
+ u.Set(account.FieldRateMultiplier, v)
+ return u
+}
+
+// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
+func (u *AccountUpsert) UpdateRateMultiplier() *AccountUpsert {
+ u.SetExcluded(account.FieldRateMultiplier)
+ return u
+}
+
+// AddRateMultiplier adds v to the "rate_multiplier" field.
+func (u *AccountUpsert) AddRateMultiplier(v float64) *AccountUpsert {
+ u.Add(account.FieldRateMultiplier, v)
+ return u
+}
+
// SetStatus sets the "status" field.
func (u *AccountUpsert) SetStatus(v string) *AccountUpsert {
u.Set(account.FieldStatus, v)
@@ -1325,6 +1368,27 @@ func (u *AccountUpsertOne) UpdatePriority() *AccountUpsertOne {
})
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (u *AccountUpsertOne) SetRateMultiplier(v float64) *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetRateMultiplier(v)
+ })
+}
+
+// AddRateMultiplier adds v to the "rate_multiplier" field.
+func (u *AccountUpsertOne) AddRateMultiplier(v float64) *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.AddRateMultiplier(v)
+ })
+}
+
+// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
+func (u *AccountUpsertOne) UpdateRateMultiplier() *AccountUpsertOne {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateRateMultiplier()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *AccountUpsertOne) SetStatus(v string) *AccountUpsertOne {
return u.Update(func(s *AccountUpsert) {
@@ -1956,6 +2020,27 @@ func (u *AccountUpsertBulk) UpdatePriority() *AccountUpsertBulk {
})
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (u *AccountUpsertBulk) SetRateMultiplier(v float64) *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.SetRateMultiplier(v)
+ })
+}
+
+// AddRateMultiplier adds v to the "rate_multiplier" field.
+func (u *AccountUpsertBulk) AddRateMultiplier(v float64) *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.AddRateMultiplier(v)
+ })
+}
+
+// UpdateRateMultiplier sets the "rate_multiplier" field to the value that was provided on create.
+func (u *AccountUpsertBulk) UpdateRateMultiplier() *AccountUpsertBulk {
+ return u.Update(func(s *AccountUpsert) {
+ s.UpdateRateMultiplier()
+ })
+}
+
// SetStatus sets the "status" field.
func (u *AccountUpsertBulk) SetStatus(v string) *AccountUpsertBulk {
return u.Update(func(s *AccountUpsert) {
diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go
index dcc3212d..63fab096 100644
--- a/backend/ent/account_update.go
+++ b/backend/ent/account_update.go
@@ -193,6 +193,27 @@ func (_u *AccountUpdate) AddPriority(v int) *AccountUpdate {
return _u
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (_u *AccountUpdate) SetRateMultiplier(v float64) *AccountUpdate {
+ _u.mutation.ResetRateMultiplier()
+ _u.mutation.SetRateMultiplier(v)
+ return _u
+}
+
+// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
+func (_u *AccountUpdate) SetNillableRateMultiplier(v *float64) *AccountUpdate {
+ if v != nil {
+ _u.SetRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddRateMultiplier adds value to the "rate_multiplier" field.
+func (_u *AccountUpdate) AddRateMultiplier(v float64) *AccountUpdate {
+ _u.mutation.AddRateMultiplier(v)
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *AccountUpdate) SetStatus(v string) *AccountUpdate {
_u.mutation.SetStatus(v)
@@ -629,6 +650,12 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
+ if value, ok := _u.mutation.RateMultiplier(); ok {
+ _spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedRateMultiplier(); ok {
+ _spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}
@@ -1005,6 +1032,27 @@ func (_u *AccountUpdateOne) AddPriority(v int) *AccountUpdateOne {
return _u
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (_u *AccountUpdateOne) SetRateMultiplier(v float64) *AccountUpdateOne {
+ _u.mutation.ResetRateMultiplier()
+ _u.mutation.SetRateMultiplier(v)
+ return _u
+}
+
+// SetNillableRateMultiplier sets the "rate_multiplier" field if the given value is not nil.
+func (_u *AccountUpdateOne) SetNillableRateMultiplier(v *float64) *AccountUpdateOne {
+ if v != nil {
+ _u.SetRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddRateMultiplier adds value to the "rate_multiplier" field.
+func (_u *AccountUpdateOne) AddRateMultiplier(v float64) *AccountUpdateOne {
+ _u.mutation.AddRateMultiplier(v)
+ return _u
+}
+
// SetStatus sets the "status" field.
func (_u *AccountUpdateOne) SetStatus(v string) *AccountUpdateOne {
_u.mutation.SetStatus(v)
@@ -1471,6 +1519,12 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er
if value, ok := _u.mutation.AddedPriority(); ok {
_spec.AddField(account.FieldPriority, field.TypeInt, value)
}
+ if value, ok := _u.mutation.RateMultiplier(); ok {
+ _spec.SetField(account.FieldRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedRateMultiplier(); ok {
+ _spec.AddField(account.FieldRateMultiplier, field.TypeFloat64, value)
+ }
if value, ok := _u.mutation.Status(); ok {
_spec.SetField(account.FieldStatus, field.TypeString, value)
}
diff --git a/backend/ent/group.go b/backend/ent/group.go
index 4a31442a..0d0c0538 100644
--- a/backend/ent/group.go
+++ b/backend/ent/group.go
@@ -3,6 +3,7 @@
package ent
import (
+ "encoding/json"
"fmt"
"strings"
"time"
@@ -55,6 +56,10 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only,omitempty"`
// 非 Claude Code 请求降级使用的分组 ID
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
+ // 模型路由配置:模型模式 -> 优先账号ID列表
+ ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
+ // 是否启用模型路由配置
+ ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"`
// Edges holds the relations/edges for other nodes in the graph.
// The values are being populated by the GroupQuery when eager-loading is set.
Edges GroupEdges `json:"edges"`
@@ -161,7 +166,9 @@ func (*Group) scanValues(columns []string) ([]any, error) {
values := make([]any, len(columns))
for i := range columns {
switch columns[i] {
- case group.FieldIsExclusive, group.FieldClaudeCodeOnly:
+ case group.FieldModelRouting:
+ values[i] = new([]byte)
+ case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled:
values[i] = new(sql.NullBool)
case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k:
values[i] = new(sql.NullFloat64)
@@ -315,6 +322,20 @@ func (_m *Group) assignValues(columns []string, values []any) error {
_m.FallbackGroupID = new(int64)
*_m.FallbackGroupID = value.Int64
}
+ case group.FieldModelRouting:
+ if value, ok := values[i].(*[]byte); !ok {
+ return fmt.Errorf("unexpected type %T for field model_routing", values[i])
+ } else if value != nil && len(*value) > 0 {
+ if err := json.Unmarshal(*value, &_m.ModelRouting); err != nil {
+ return fmt.Errorf("unmarshal field model_routing: %w", err)
+ }
+ }
+ case group.FieldModelRoutingEnabled:
+ if value, ok := values[i].(*sql.NullBool); !ok {
+ return fmt.Errorf("unexpected type %T for field model_routing_enabled", values[i])
+ } else if value.Valid {
+ _m.ModelRoutingEnabled = value.Bool
+ }
default:
_m.selectValues.Set(columns[i], values[i])
}
@@ -465,6 +486,12 @@ func (_m *Group) String() string {
builder.WriteString("fallback_group_id=")
builder.WriteString(fmt.Sprintf("%v", *v))
}
+ builder.WriteString(", ")
+ builder.WriteString("model_routing=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ModelRouting))
+ builder.WriteString(", ")
+ builder.WriteString("model_routing_enabled=")
+ builder.WriteString(fmt.Sprintf("%v", _m.ModelRoutingEnabled))
builder.WriteByte(')')
return builder.String()
}
diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go
index c4317f00..d66d3edc 100644
--- a/backend/ent/group/group.go
+++ b/backend/ent/group/group.go
@@ -53,6 +53,10 @@ const (
FieldClaudeCodeOnly = "claude_code_only"
// FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database.
FieldFallbackGroupID = "fallback_group_id"
+ // FieldModelRouting holds the string denoting the model_routing field in the database.
+ FieldModelRouting = "model_routing"
+ // FieldModelRoutingEnabled holds the string denoting the model_routing_enabled field in the database.
+ FieldModelRoutingEnabled = "model_routing_enabled"
// EdgeAPIKeys holds the string denoting the api_keys edge name in mutations.
EdgeAPIKeys = "api_keys"
// EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations.
@@ -147,6 +151,8 @@ var Columns = []string{
FieldImagePrice4k,
FieldClaudeCodeOnly,
FieldFallbackGroupID,
+ FieldModelRouting,
+ FieldModelRoutingEnabled,
}
var (
@@ -204,6 +210,8 @@ var (
DefaultDefaultValidityDays int
// DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field.
DefaultClaudeCodeOnly bool
+ // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field.
+ DefaultModelRoutingEnabled bool
)
// OrderOption defines the ordering options for the Group queries.
@@ -309,6 +317,11 @@ func ByFallbackGroupID(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldFallbackGroupID, opts...).ToFunc()
}
+// ByModelRoutingEnabled orders the results by the model_routing_enabled field.
+func ByModelRoutingEnabled(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldModelRoutingEnabled, opts...).ToFunc()
+}
+
// ByAPIKeysCount orders the results by api_keys count.
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption {
return func(s *sql.Selector) {
diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go
index fb2f942f..6ce9e4c6 100644
--- a/backend/ent/group/where.go
+++ b/backend/ent/group/where.go
@@ -150,6 +150,11 @@ func FallbackGroupID(v int64) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldFallbackGroupID, v))
}
+// ModelRoutingEnabled applies equality check predicate on the "model_routing_enabled" field. It's identical to ModelRoutingEnabledEQ.
+func ModelRoutingEnabled(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
+}
+
// CreatedAtEQ applies the EQ predicate on the "created_at" field.
func CreatedAtEQ(v time.Time) predicate.Group {
return predicate.Group(sql.FieldEQ(FieldCreatedAt, v))
@@ -1065,6 +1070,26 @@ func FallbackGroupIDNotNil() predicate.Group {
return predicate.Group(sql.FieldNotNull(FieldFallbackGroupID))
}
+// ModelRoutingIsNil applies the IsNil predicate on the "model_routing" field.
+func ModelRoutingIsNil() predicate.Group {
+ return predicate.Group(sql.FieldIsNull(FieldModelRouting))
+}
+
+// ModelRoutingNotNil applies the NotNil predicate on the "model_routing" field.
+func ModelRoutingNotNil() predicate.Group {
+ return predicate.Group(sql.FieldNotNull(FieldModelRouting))
+}
+
+// ModelRoutingEnabledEQ applies the EQ predicate on the "model_routing_enabled" field.
+func ModelRoutingEnabledEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldEQ(FieldModelRoutingEnabled, v))
+}
+
+// ModelRoutingEnabledNEQ applies the NEQ predicate on the "model_routing_enabled" field.
+func ModelRoutingEnabledNEQ(v bool) predicate.Group {
+ return predicate.Group(sql.FieldNEQ(FieldModelRoutingEnabled, v))
+}
+
// HasAPIKeys applies the HasEdge predicate on the "api_keys" edge.
func HasAPIKeys() predicate.Group {
return predicate.Group(func(s *sql.Selector) {
diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go
index 59229402..0f251e0b 100644
--- a/backend/ent/group_create.go
+++ b/backend/ent/group_create.go
@@ -286,6 +286,26 @@ func (_c *GroupCreate) SetNillableFallbackGroupID(v *int64) *GroupCreate {
return _c
}
+// SetModelRouting sets the "model_routing" field.
+func (_c *GroupCreate) SetModelRouting(v map[string][]int64) *GroupCreate {
+ _c.mutation.SetModelRouting(v)
+ return _c
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (_c *GroupCreate) SetModelRoutingEnabled(v bool) *GroupCreate {
+ _c.mutation.SetModelRoutingEnabled(v)
+ return _c
+}
+
+// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
+func (_c *GroupCreate) SetNillableModelRoutingEnabled(v *bool) *GroupCreate {
+ if v != nil {
+ _c.SetModelRoutingEnabled(*v)
+ }
+ return _c
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate {
_c.mutation.AddAPIKeyIDs(ids...)
@@ -455,6 +475,10 @@ func (_c *GroupCreate) defaults() error {
v := group.DefaultClaudeCodeOnly
_c.mutation.SetClaudeCodeOnly(v)
}
+ if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
+ v := group.DefaultModelRoutingEnabled
+ _c.mutation.SetModelRoutingEnabled(v)
+ }
return nil
}
@@ -510,6 +534,9 @@ func (_c *GroupCreate) check() error {
if _, ok := _c.mutation.ClaudeCodeOnly(); !ok {
return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required field "Group.claude_code_only"`)}
}
+ if _, ok := _c.mutation.ModelRoutingEnabled(); !ok {
+ return &ValidationError{Name: "model_routing_enabled", err: errors.New(`ent: missing required field "Group.model_routing_enabled"`)}
+ }
return nil
}
@@ -613,6 +640,14 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) {
_spec.SetField(group.FieldFallbackGroupID, field.TypeInt64, value)
_node.FallbackGroupID = &value
}
+ if value, ok := _c.mutation.ModelRouting(); ok {
+ _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
+ _node.ModelRouting = value
+ }
+ if value, ok := _c.mutation.ModelRoutingEnabled(); ok {
+ _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
+ _node.ModelRoutingEnabled = value
+ }
if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1093,6 +1128,36 @@ func (u *GroupUpsert) ClearFallbackGroupID() *GroupUpsert {
return u
}
+// SetModelRouting sets the "model_routing" field.
+func (u *GroupUpsert) SetModelRouting(v map[string][]int64) *GroupUpsert {
+ u.Set(group.FieldModelRouting, v)
+ return u
+}
+
+// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateModelRouting() *GroupUpsert {
+ u.SetExcluded(group.FieldModelRouting)
+ return u
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (u *GroupUpsert) ClearModelRouting() *GroupUpsert {
+ u.SetNull(group.FieldModelRouting)
+ return u
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (u *GroupUpsert) SetModelRoutingEnabled(v bool) *GroupUpsert {
+ u.Set(group.FieldModelRoutingEnabled, v)
+ return u
+}
+
+// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
+func (u *GroupUpsert) UpdateModelRoutingEnabled() *GroupUpsert {
+ u.SetExcluded(group.FieldModelRoutingEnabled)
+ return u
+}
+
// UpdateNewValues updates the mutable fields using the new values that were set on create.
// Using this option is equivalent to using:
//
@@ -1516,6 +1581,41 @@ func (u *GroupUpsertOne) ClearFallbackGroupID() *GroupUpsertOne {
})
}
+// SetModelRouting sets the "model_routing" field.
+func (u *GroupUpsertOne) SetModelRouting(v map[string][]int64) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetModelRouting(v)
+ })
+}
+
+// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateModelRouting() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateModelRouting()
+ })
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (u *GroupUpsertOne) ClearModelRouting() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.ClearModelRouting()
+ })
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (u *GroupUpsertOne) SetModelRoutingEnabled(v bool) *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetModelRoutingEnabled(v)
+ })
+}
+
+// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
+func (u *GroupUpsertOne) UpdateModelRoutingEnabled() *GroupUpsertOne {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateModelRoutingEnabled()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertOne) Exec(ctx context.Context) error {
if len(u.create.conflict) == 0 {
@@ -2105,6 +2205,41 @@ func (u *GroupUpsertBulk) ClearFallbackGroupID() *GroupUpsertBulk {
})
}
+// SetModelRouting sets the "model_routing" field.
+func (u *GroupUpsertBulk) SetModelRouting(v map[string][]int64) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetModelRouting(v)
+ })
+}
+
+// UpdateModelRouting sets the "model_routing" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateModelRouting() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateModelRouting()
+ })
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (u *GroupUpsertBulk) ClearModelRouting() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.ClearModelRouting()
+ })
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (u *GroupUpsertBulk) SetModelRoutingEnabled(v bool) *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.SetModelRoutingEnabled(v)
+ })
+}
+
+// UpdateModelRoutingEnabled sets the "model_routing_enabled" field to the value that was provided on create.
+func (u *GroupUpsertBulk) UpdateModelRoutingEnabled() *GroupUpsertBulk {
+ return u.Update(func(s *GroupUpsert) {
+ s.UpdateModelRoutingEnabled()
+ })
+}
+
// Exec executes the query.
func (u *GroupUpsertBulk) Exec(ctx context.Context) error {
if u.create.err != nil {
diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go
index 1a6f15ec..c3cc2708 100644
--- a/backend/ent/group_update.go
+++ b/backend/ent/group_update.go
@@ -395,6 +395,32 @@ func (_u *GroupUpdate) ClearFallbackGroupID() *GroupUpdate {
return _u
}
+// SetModelRouting sets the "model_routing" field.
+func (_u *GroupUpdate) SetModelRouting(v map[string][]int64) *GroupUpdate {
+ _u.mutation.SetModelRouting(v)
+ return _u
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (_u *GroupUpdate) ClearModelRouting() *GroupUpdate {
+ _u.mutation.ClearModelRouting()
+ return _u
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (_u *GroupUpdate) SetModelRoutingEnabled(v bool) *GroupUpdate {
+ _u.mutation.SetModelRoutingEnabled(v)
+ return _u
+}
+
+// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
+func (_u *GroupUpdate) SetNillableModelRoutingEnabled(v *bool) *GroupUpdate {
+ if v != nil {
+ _u.SetModelRoutingEnabled(*v)
+ }
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -803,6 +829,15 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
+ if value, ok := _u.mutation.ModelRouting(); ok {
+ _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
+ }
+ if _u.mutation.ModelRoutingCleared() {
+ _spec.ClearField(group.FieldModelRouting, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
+ _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
@@ -1478,6 +1513,32 @@ func (_u *GroupUpdateOne) ClearFallbackGroupID() *GroupUpdateOne {
return _u
}
+// SetModelRouting sets the "model_routing" field.
+func (_u *GroupUpdateOne) SetModelRouting(v map[string][]int64) *GroupUpdateOne {
+ _u.mutation.SetModelRouting(v)
+ return _u
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (_u *GroupUpdateOne) ClearModelRouting() *GroupUpdateOne {
+ _u.mutation.ClearModelRouting()
+ return _u
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (_u *GroupUpdateOne) SetModelRoutingEnabled(v bool) *GroupUpdateOne {
+ _u.mutation.SetModelRoutingEnabled(v)
+ return _u
+}
+
+// SetNillableModelRoutingEnabled sets the "model_routing_enabled" field if the given value is not nil.
+func (_u *GroupUpdateOne) SetNillableModelRoutingEnabled(v *bool) *GroupUpdateOne {
+ if v != nil {
+ _u.SetModelRoutingEnabled(*v)
+ }
+ return _u
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs.
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne {
_u.mutation.AddAPIKeyIDs(ids...)
@@ -1916,6 +1977,15 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error)
if _u.mutation.FallbackGroupIDCleared() {
_spec.ClearField(group.FieldFallbackGroupID, field.TypeInt64)
}
+ if value, ok := _u.mutation.ModelRouting(); ok {
+ _spec.SetField(group.FieldModelRouting, field.TypeJSON, value)
+ }
+ if _u.mutation.ModelRoutingCleared() {
+ _spec.ClearField(group.FieldModelRouting, field.TypeJSON)
+ }
+ if value, ok := _u.mutation.ModelRoutingEnabled(); ok {
+ _spec.SetField(group.FieldModelRoutingEnabled, field.TypeBool, value)
+ }
if _u.mutation.APIKeysCleared() {
edge := &sqlgraph.EdgeSpec{
Rel: sqlgraph.O2M,
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index 41cd8b01..b377804f 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -79,6 +79,7 @@ var (
{Name: "extra", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}},
{Name: "concurrency", Type: field.TypeInt, Default: 3},
{Name: "priority", Type: field.TypeInt, Default: 50},
+ {Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "status", Type: field.TypeString, Size: 20, Default: "active"},
{Name: "error_message", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}},
{Name: "last_used_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}},
@@ -101,7 +102,7 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "accounts_proxies_proxy",
- Columns: []*schema.Column{AccountsColumns[24]},
+ Columns: []*schema.Column{AccountsColumns[25]},
RefColumns: []*schema.Column{ProxiesColumns[0]},
OnDelete: schema.SetNull,
},
@@ -120,12 +121,12 @@ var (
{
Name: "account_status",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[12]},
+ Columns: []*schema.Column{AccountsColumns[13]},
},
{
Name: "account_proxy_id",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[24]},
+ Columns: []*schema.Column{AccountsColumns[25]},
},
{
Name: "account_priority",
@@ -135,27 +136,27 @@ var (
{
Name: "account_last_used_at",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[14]},
+ Columns: []*schema.Column{AccountsColumns[15]},
},
{
Name: "account_schedulable",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[17]},
+ Columns: []*schema.Column{AccountsColumns[18]},
},
{
Name: "account_rate_limited_at",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[18]},
+ Columns: []*schema.Column{AccountsColumns[19]},
},
{
Name: "account_rate_limit_reset_at",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[19]},
+ Columns: []*schema.Column{AccountsColumns[20]},
},
{
Name: "account_overload_until",
Unique: false,
- Columns: []*schema.Column{AccountsColumns[20]},
+ Columns: []*schema.Column{AccountsColumns[21]},
},
{
Name: "account_deleted_at",
@@ -225,6 +226,8 @@ var (
{Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}},
{Name: "claude_code_only", Type: field.TypeBool, Default: false},
{Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true},
+ {Name: "model_routing", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}},
+ {Name: "model_routing_enabled", Type: field.TypeBool, Default: false},
}
// GroupsTable holds the schema information for the "groups" table.
GroupsTable = &schema.Table{
@@ -449,6 +452,7 @@ var (
{Name: "total_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "actual_cost", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,10)"}},
{Name: "rate_multiplier", Type: field.TypeFloat64, Default: 1, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
+ {Name: "account_rate_multiplier", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(10,4)"}},
{Name: "billing_type", Type: field.TypeInt8, Default: 0},
{Name: "stream", Type: field.TypeBool, Default: false},
{Name: "duration_ms", Type: field.TypeInt, Nullable: true},
@@ -472,31 +476,31 @@ var (
ForeignKeys: []*schema.ForeignKey{
{
Symbol: "usage_logs_api_keys_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[25]},
+ Columns: []*schema.Column{UsageLogsColumns[26]},
RefColumns: []*schema.Column{APIKeysColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_accounts_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[26]},
+ Columns: []*schema.Column{UsageLogsColumns[27]},
RefColumns: []*schema.Column{AccountsColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_groups_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[27]},
+ Columns: []*schema.Column{UsageLogsColumns[28]},
RefColumns: []*schema.Column{GroupsColumns[0]},
OnDelete: schema.SetNull,
},
{
Symbol: "usage_logs_users_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[28]},
+ Columns: []*schema.Column{UsageLogsColumns[29]},
RefColumns: []*schema.Column{UsersColumns[0]},
OnDelete: schema.NoAction,
},
{
Symbol: "usage_logs_user_subscriptions_usage_logs",
- Columns: []*schema.Column{UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[30]},
RefColumns: []*schema.Column{UserSubscriptionsColumns[0]},
OnDelete: schema.SetNull,
},
@@ -505,32 +509,32 @@ var (
{
Name: "usagelog_user_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[28]},
+ Columns: []*schema.Column{UsageLogsColumns[29]},
},
{
Name: "usagelog_api_key_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[25]},
+ Columns: []*schema.Column{UsageLogsColumns[26]},
},
{
Name: "usagelog_account_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[26]},
+ Columns: []*schema.Column{UsageLogsColumns[27]},
},
{
Name: "usagelog_group_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[27]},
+ Columns: []*schema.Column{UsageLogsColumns[28]},
},
{
Name: "usagelog_subscription_id",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[29]},
+ Columns: []*schema.Column{UsageLogsColumns[30]},
},
{
Name: "usagelog_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[24]},
+ Columns: []*schema.Column{UsageLogsColumns[25]},
},
{
Name: "usagelog_model",
@@ -545,12 +549,12 @@ var (
{
Name: "usagelog_user_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[24]},
+ Columns: []*schema.Column{UsageLogsColumns[29], UsageLogsColumns[25]},
},
{
Name: "usagelog_api_key_id_created_at",
Unique: false,
- Columns: []*schema.Column{UsageLogsColumns[25], UsageLogsColumns[24]},
+ Columns: []*schema.Column{UsageLogsColumns[26], UsageLogsColumns[25]},
},
},
}
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index 732abd1c..cd2fe8e0 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -1187,6 +1187,8 @@ type AccountMutation struct {
addconcurrency *int
priority *int
addpriority *int
+ rate_multiplier *float64
+ addrate_multiplier *float64
status *string
error_message *string
last_used_at *time.Time
@@ -1822,6 +1824,62 @@ func (m *AccountMutation) ResetPriority() {
m.addpriority = nil
}
+// SetRateMultiplier sets the "rate_multiplier" field.
+func (m *AccountMutation) SetRateMultiplier(f float64) {
+ m.rate_multiplier = &f
+ m.addrate_multiplier = nil
+}
+
+// RateMultiplier returns the value of the "rate_multiplier" field in the mutation.
+func (m *AccountMutation) RateMultiplier() (r float64, exists bool) {
+ v := m.rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldRateMultiplier returns the old "rate_multiplier" field's value of the Account entity.
+// If the Account object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *AccountMutation) OldRateMultiplier(ctx context.Context) (v float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldRateMultiplier is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldRateMultiplier requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldRateMultiplier: %w", err)
+ }
+ return oldValue.RateMultiplier, nil
+}
+
+// AddRateMultiplier adds f to the "rate_multiplier" field.
+func (m *AccountMutation) AddRateMultiplier(f float64) {
+ if m.addrate_multiplier != nil {
+ *m.addrate_multiplier += f
+ } else {
+ m.addrate_multiplier = &f
+ }
+}
+
+// AddedRateMultiplier returns the value that was added to the "rate_multiplier" field in this mutation.
+func (m *AccountMutation) AddedRateMultiplier() (r float64, exists bool) {
+ v := m.addrate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ResetRateMultiplier resets all changes to the "rate_multiplier" field.
+func (m *AccountMutation) ResetRateMultiplier() {
+ m.rate_multiplier = nil
+ m.addrate_multiplier = nil
+}
+
// SetStatus sets the "status" field.
func (m *AccountMutation) SetStatus(s string) {
m.status = &s
@@ -2540,7 +2598,7 @@ func (m *AccountMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *AccountMutation) Fields() []string {
- fields := make([]string, 0, 24)
+ fields := make([]string, 0, 25)
if m.created_at != nil {
fields = append(fields, account.FieldCreatedAt)
}
@@ -2577,6 +2635,9 @@ func (m *AccountMutation) Fields() []string {
if m.priority != nil {
fields = append(fields, account.FieldPriority)
}
+ if m.rate_multiplier != nil {
+ fields = append(fields, account.FieldRateMultiplier)
+ }
if m.status != nil {
fields = append(fields, account.FieldStatus)
}
@@ -2645,6 +2706,8 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) {
return m.Concurrency()
case account.FieldPriority:
return m.Priority()
+ case account.FieldRateMultiplier:
+ return m.RateMultiplier()
case account.FieldStatus:
return m.Status()
case account.FieldErrorMessage:
@@ -2702,6 +2765,8 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value,
return m.OldConcurrency(ctx)
case account.FieldPriority:
return m.OldPriority(ctx)
+ case account.FieldRateMultiplier:
+ return m.OldRateMultiplier(ctx)
case account.FieldStatus:
return m.OldStatus(ctx)
case account.FieldErrorMessage:
@@ -2819,6 +2884,13 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error {
}
m.SetPriority(v)
return nil
+ case account.FieldRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetRateMultiplier(v)
+ return nil
case account.FieldStatus:
v, ok := value.(string)
if !ok {
@@ -2917,6 +2989,9 @@ func (m *AccountMutation) AddedFields() []string {
if m.addpriority != nil {
fields = append(fields, account.FieldPriority)
}
+ if m.addrate_multiplier != nil {
+ fields = append(fields, account.FieldRateMultiplier)
+ }
return fields
}
@@ -2929,6 +3004,8 @@ func (m *AccountMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedConcurrency()
case account.FieldPriority:
return m.AddedPriority()
+ case account.FieldRateMultiplier:
+ return m.AddedRateMultiplier()
}
return nil, false
}
@@ -2952,6 +3029,13 @@ func (m *AccountMutation) AddField(name string, value ent.Value) error {
}
m.AddPriority(v)
return nil
+ case account.FieldRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddRateMultiplier(v)
+ return nil
}
return fmt.Errorf("unknown Account numeric field %s", name)
}
@@ -3090,6 +3174,9 @@ func (m *AccountMutation) ResetField(name string) error {
case account.FieldPriority:
m.ResetPriority()
return nil
+ case account.FieldRateMultiplier:
+ m.ResetRateMultiplier()
+ return nil
case account.FieldStatus:
m.ResetStatus()
return nil
@@ -3777,6 +3864,8 @@ type GroupMutation struct {
claude_code_only *bool
fallback_group_id *int64
addfallback_group_id *int64
+ model_routing *map[string][]int64
+ model_routing_enabled *bool
clearedFields map[string]struct{}
api_keys map[int64]struct{}
removedapi_keys map[int64]struct{}
@@ -4887,6 +4976,91 @@ func (m *GroupMutation) ResetFallbackGroupID() {
delete(m.clearedFields, group.FieldFallbackGroupID)
}
+// SetModelRouting sets the "model_routing" field.
+func (m *GroupMutation) SetModelRouting(value map[string][]int64) {
+ m.model_routing = &value
+}
+
+// ModelRouting returns the value of the "model_routing" field in the mutation.
+func (m *GroupMutation) ModelRouting() (r map[string][]int64, exists bool) {
+ v := m.model_routing
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModelRouting returns the old "model_routing" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldModelRouting(ctx context.Context) (v map[string][]int64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModelRouting is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModelRouting requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModelRouting: %w", err)
+ }
+ return oldValue.ModelRouting, nil
+}
+
+// ClearModelRouting clears the value of the "model_routing" field.
+func (m *GroupMutation) ClearModelRouting() {
+ m.model_routing = nil
+ m.clearedFields[group.FieldModelRouting] = struct{}{}
+}
+
+// ModelRoutingCleared returns if the "model_routing" field was cleared in this mutation.
+func (m *GroupMutation) ModelRoutingCleared() bool {
+ _, ok := m.clearedFields[group.FieldModelRouting]
+ return ok
+}
+
+// ResetModelRouting resets all changes to the "model_routing" field.
+func (m *GroupMutation) ResetModelRouting() {
+ m.model_routing = nil
+ delete(m.clearedFields, group.FieldModelRouting)
+}
+
+// SetModelRoutingEnabled sets the "model_routing_enabled" field.
+func (m *GroupMutation) SetModelRoutingEnabled(b bool) {
+ m.model_routing_enabled = &b
+}
+
+// ModelRoutingEnabled returns the value of the "model_routing_enabled" field in the mutation.
+func (m *GroupMutation) ModelRoutingEnabled() (r bool, exists bool) {
+ v := m.model_routing_enabled
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldModelRoutingEnabled returns the old "model_routing_enabled" field's value of the Group entity.
+// If the Group object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *GroupMutation) OldModelRoutingEnabled(ctx context.Context) (v bool, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldModelRoutingEnabled is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldModelRoutingEnabled requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldModelRoutingEnabled: %w", err)
+ }
+ return oldValue.ModelRoutingEnabled, nil
+}
+
+// ResetModelRoutingEnabled resets all changes to the "model_routing_enabled" field.
+func (m *GroupMutation) ResetModelRoutingEnabled() {
+ m.model_routing_enabled = nil
+}
+
// AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids.
func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) {
if m.api_keys == nil {
@@ -5245,7 +5419,7 @@ func (m *GroupMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *GroupMutation) Fields() []string {
- fields := make([]string, 0, 19)
+ fields := make([]string, 0, 21)
if m.created_at != nil {
fields = append(fields, group.FieldCreatedAt)
}
@@ -5303,6 +5477,12 @@ func (m *GroupMutation) Fields() []string {
if m.fallback_group_id != nil {
fields = append(fields, group.FieldFallbackGroupID)
}
+ if m.model_routing != nil {
+ fields = append(fields, group.FieldModelRouting)
+ }
+ if m.model_routing_enabled != nil {
+ fields = append(fields, group.FieldModelRoutingEnabled)
+ }
return fields
}
@@ -5349,6 +5529,10 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) {
return m.ClaudeCodeOnly()
case group.FieldFallbackGroupID:
return m.FallbackGroupID()
+ case group.FieldModelRouting:
+ return m.ModelRouting()
+ case group.FieldModelRoutingEnabled:
+ return m.ModelRoutingEnabled()
}
return nil, false
}
@@ -5396,6 +5580,10 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e
return m.OldClaudeCodeOnly(ctx)
case group.FieldFallbackGroupID:
return m.OldFallbackGroupID(ctx)
+ case group.FieldModelRouting:
+ return m.OldModelRouting(ctx)
+ case group.FieldModelRoutingEnabled:
+ return m.OldModelRoutingEnabled(ctx)
}
return nil, fmt.Errorf("unknown Group field %s", name)
}
@@ -5538,6 +5726,20 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error {
}
m.SetFallbackGroupID(v)
return nil
+ case group.FieldModelRouting:
+ v, ok := value.(map[string][]int64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModelRouting(v)
+ return nil
+ case group.FieldModelRoutingEnabled:
+ v, ok := value.(bool)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetModelRoutingEnabled(v)
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -5706,6 +5908,9 @@ func (m *GroupMutation) ClearedFields() []string {
if m.FieldCleared(group.FieldFallbackGroupID) {
fields = append(fields, group.FieldFallbackGroupID)
}
+ if m.FieldCleared(group.FieldModelRouting) {
+ fields = append(fields, group.FieldModelRouting)
+ }
return fields
}
@@ -5747,6 +5952,9 @@ func (m *GroupMutation) ClearField(name string) error {
case group.FieldFallbackGroupID:
m.ClearFallbackGroupID()
return nil
+ case group.FieldModelRouting:
+ m.ClearModelRouting()
+ return nil
}
return fmt.Errorf("unknown Group nullable field %s", name)
}
@@ -5812,6 +6020,12 @@ func (m *GroupMutation) ResetField(name string) error {
case group.FieldFallbackGroupID:
m.ResetFallbackGroupID()
return nil
+ case group.FieldModelRouting:
+ m.ResetModelRouting()
+ return nil
+ case group.FieldModelRoutingEnabled:
+ m.ResetModelRoutingEnabled()
+ return nil
}
return fmt.Errorf("unknown Group field %s", name)
}
@@ -10190,6 +10404,8 @@ type UsageLogMutation struct {
addactual_cost *float64
rate_multiplier *float64
addrate_multiplier *float64
+ account_rate_multiplier *float64
+ addaccount_rate_multiplier *float64
billing_type *int8
addbilling_type *int8
stream *bool
@@ -11323,6 +11539,76 @@ func (m *UsageLogMutation) ResetRateMultiplier() {
m.addrate_multiplier = nil
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (m *UsageLogMutation) SetAccountRateMultiplier(f float64) {
+ m.account_rate_multiplier = &f
+ m.addaccount_rate_multiplier = nil
+}
+
+// AccountRateMultiplier returns the value of the "account_rate_multiplier" field in the mutation.
+func (m *UsageLogMutation) AccountRateMultiplier() (r float64, exists bool) {
+ v := m.account_rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// OldAccountRateMultiplier returns the old "account_rate_multiplier" field's value of the UsageLog entity.
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageLogMutation) OldAccountRateMultiplier(ctx context.Context) (v *float64, err error) {
+ if !m.op.Is(OpUpdateOne) {
+ return v, errors.New("OldAccountRateMultiplier is only allowed on UpdateOne operations")
+ }
+ if m.id == nil || m.oldValue == nil {
+ return v, errors.New("OldAccountRateMultiplier requires an ID field in the mutation")
+ }
+ oldValue, err := m.oldValue(ctx)
+ if err != nil {
+ return v, fmt.Errorf("querying old value for OldAccountRateMultiplier: %w", err)
+ }
+ return oldValue.AccountRateMultiplier, nil
+}
+
+// AddAccountRateMultiplier adds f to the "account_rate_multiplier" field.
+func (m *UsageLogMutation) AddAccountRateMultiplier(f float64) {
+ if m.addaccount_rate_multiplier != nil {
+ *m.addaccount_rate_multiplier += f
+ } else {
+ m.addaccount_rate_multiplier = &f
+ }
+}
+
+// AddedAccountRateMultiplier returns the value that was added to the "account_rate_multiplier" field in this mutation.
+func (m *UsageLogMutation) AddedAccountRateMultiplier() (r float64, exists bool) {
+ v := m.addaccount_rate_multiplier
+ if v == nil {
+ return
+ }
+ return *v, true
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (m *UsageLogMutation) ClearAccountRateMultiplier() {
+ m.account_rate_multiplier = nil
+ m.addaccount_rate_multiplier = nil
+ m.clearedFields[usagelog.FieldAccountRateMultiplier] = struct{}{}
+}
+
+// AccountRateMultiplierCleared returns if the "account_rate_multiplier" field was cleared in this mutation.
+func (m *UsageLogMutation) AccountRateMultiplierCleared() bool {
+ _, ok := m.clearedFields[usagelog.FieldAccountRateMultiplier]
+ return ok
+}
+
+// ResetAccountRateMultiplier resets all changes to the "account_rate_multiplier" field.
+func (m *UsageLogMutation) ResetAccountRateMultiplier() {
+ m.account_rate_multiplier = nil
+ m.addaccount_rate_multiplier = nil
+ delete(m.clearedFields, usagelog.FieldAccountRateMultiplier)
+}
+
// SetBillingType sets the "billing_type" field.
func (m *UsageLogMutation) SetBillingType(i int8) {
m.billing_type = &i
@@ -11963,7 +12249,7 @@ func (m *UsageLogMutation) Type() string {
// order to get all numeric fields that were incremented/decremented, call
// AddedFields().
func (m *UsageLogMutation) Fields() []string {
- fields := make([]string, 0, 29)
+ fields := make([]string, 0, 30)
if m.user != nil {
fields = append(fields, usagelog.FieldUserID)
}
@@ -12024,6 +12310,9 @@ func (m *UsageLogMutation) Fields() []string {
if m.rate_multiplier != nil {
fields = append(fields, usagelog.FieldRateMultiplier)
}
+ if m.account_rate_multiplier != nil {
+ fields = append(fields, usagelog.FieldAccountRateMultiplier)
+ }
if m.billing_type != nil {
fields = append(fields, usagelog.FieldBillingType)
}
@@ -12099,6 +12388,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) {
return m.ActualCost()
case usagelog.FieldRateMultiplier:
return m.RateMultiplier()
+ case usagelog.FieldAccountRateMultiplier:
+ return m.AccountRateMultiplier()
case usagelog.FieldBillingType:
return m.BillingType()
case usagelog.FieldStream:
@@ -12166,6 +12457,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value
return m.OldActualCost(ctx)
case usagelog.FieldRateMultiplier:
return m.OldRateMultiplier(ctx)
+ case usagelog.FieldAccountRateMultiplier:
+ return m.OldAccountRateMultiplier(ctx)
case usagelog.FieldBillingType:
return m.OldBillingType(ctx)
case usagelog.FieldStream:
@@ -12333,6 +12626,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error {
}
m.SetRateMultiplier(v)
return nil
+ case usagelog.FieldAccountRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.SetAccountRateMultiplier(v)
+ return nil
case usagelog.FieldBillingType:
v, ok := value.(int8)
if !ok {
@@ -12443,6 +12743,9 @@ func (m *UsageLogMutation) AddedFields() []string {
if m.addrate_multiplier != nil {
fields = append(fields, usagelog.FieldRateMultiplier)
}
+ if m.addaccount_rate_multiplier != nil {
+ fields = append(fields, usagelog.FieldAccountRateMultiplier)
+ }
if m.addbilling_type != nil {
fields = append(fields, usagelog.FieldBillingType)
}
@@ -12489,6 +12792,8 @@ func (m *UsageLogMutation) AddedField(name string) (ent.Value, bool) {
return m.AddedActualCost()
case usagelog.FieldRateMultiplier:
return m.AddedRateMultiplier()
+ case usagelog.FieldAccountRateMultiplier:
+ return m.AddedAccountRateMultiplier()
case usagelog.FieldBillingType:
return m.AddedBillingType()
case usagelog.FieldDurationMs:
@@ -12597,6 +12902,13 @@ func (m *UsageLogMutation) AddField(name string, value ent.Value) error {
}
m.AddRateMultiplier(v)
return nil
+ case usagelog.FieldAccountRateMultiplier:
+ v, ok := value.(float64)
+ if !ok {
+ return fmt.Errorf("unexpected type %T for field %s", value, name)
+ }
+ m.AddAccountRateMultiplier(v)
+ return nil
case usagelog.FieldBillingType:
v, ok := value.(int8)
if !ok {
@@ -12639,6 +12951,9 @@ func (m *UsageLogMutation) ClearedFields() []string {
if m.FieldCleared(usagelog.FieldSubscriptionID) {
fields = append(fields, usagelog.FieldSubscriptionID)
}
+ if m.FieldCleared(usagelog.FieldAccountRateMultiplier) {
+ fields = append(fields, usagelog.FieldAccountRateMultiplier)
+ }
if m.FieldCleared(usagelog.FieldDurationMs) {
fields = append(fields, usagelog.FieldDurationMs)
}
@@ -12674,6 +12989,9 @@ func (m *UsageLogMutation) ClearField(name string) error {
case usagelog.FieldSubscriptionID:
m.ClearSubscriptionID()
return nil
+ case usagelog.FieldAccountRateMultiplier:
+ m.ClearAccountRateMultiplier()
+ return nil
case usagelog.FieldDurationMs:
m.ClearDurationMs()
return nil
@@ -12757,6 +13075,9 @@ func (m *UsageLogMutation) ResetField(name string) error {
case usagelog.FieldRateMultiplier:
m.ResetRateMultiplier()
return nil
+ case usagelog.FieldAccountRateMultiplier:
+ m.ResetAccountRateMultiplier()
+ return nil
case usagelog.FieldBillingType:
m.ResetBillingType()
return nil
diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go
index ad1aa626..0cb10775 100644
--- a/backend/ent/runtime/runtime.go
+++ b/backend/ent/runtime/runtime.go
@@ -177,22 +177,26 @@ func init() {
accountDescPriority := accountFields[8].Descriptor()
// account.DefaultPriority holds the default value on creation for the priority field.
account.DefaultPriority = accountDescPriority.Default.(int)
+ // accountDescRateMultiplier is the schema descriptor for rate_multiplier field.
+ accountDescRateMultiplier := accountFields[9].Descriptor()
+ // account.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
+ account.DefaultRateMultiplier = accountDescRateMultiplier.Default.(float64)
// accountDescStatus is the schema descriptor for status field.
- accountDescStatus := accountFields[9].Descriptor()
+ accountDescStatus := accountFields[10].Descriptor()
// account.DefaultStatus holds the default value on creation for the status field.
account.DefaultStatus = accountDescStatus.Default.(string)
// account.StatusValidator is a validator for the "status" field. It is called by the builders before save.
account.StatusValidator = accountDescStatus.Validators[0].(func(string) error)
// accountDescAutoPauseOnExpired is the schema descriptor for auto_pause_on_expired field.
- accountDescAutoPauseOnExpired := accountFields[13].Descriptor()
+ accountDescAutoPauseOnExpired := accountFields[14].Descriptor()
// account.DefaultAutoPauseOnExpired holds the default value on creation for the auto_pause_on_expired field.
account.DefaultAutoPauseOnExpired = accountDescAutoPauseOnExpired.Default.(bool)
// accountDescSchedulable is the schema descriptor for schedulable field.
- accountDescSchedulable := accountFields[14].Descriptor()
+ accountDescSchedulable := accountFields[15].Descriptor()
// account.DefaultSchedulable holds the default value on creation for the schedulable field.
account.DefaultSchedulable = accountDescSchedulable.Default.(bool)
// accountDescSessionWindowStatus is the schema descriptor for session_window_status field.
- accountDescSessionWindowStatus := accountFields[20].Descriptor()
+ accountDescSessionWindowStatus := accountFields[21].Descriptor()
// account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save.
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error)
accountgroupFields := schema.AccountGroup{}.Fields()
@@ -276,6 +280,10 @@ func init() {
groupDescClaudeCodeOnly := groupFields[14].Descriptor()
// group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field.
group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool)
+ // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field.
+ groupDescModelRoutingEnabled := groupFields[17].Descriptor()
+ // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field.
+ group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool)
promocodeFields := schema.PromoCode{}.Fields()
_ = promocodeFields
// promocodeDescCode is the schema descriptor for code field.
@@ -578,31 +586,31 @@ func init() {
// usagelog.DefaultRateMultiplier holds the default value on creation for the rate_multiplier field.
usagelog.DefaultRateMultiplier = usagelogDescRateMultiplier.Default.(float64)
// usagelogDescBillingType is the schema descriptor for billing_type field.
- usagelogDescBillingType := usagelogFields[20].Descriptor()
+ usagelogDescBillingType := usagelogFields[21].Descriptor()
// usagelog.DefaultBillingType holds the default value on creation for the billing_type field.
usagelog.DefaultBillingType = usagelogDescBillingType.Default.(int8)
// usagelogDescStream is the schema descriptor for stream field.
- usagelogDescStream := usagelogFields[21].Descriptor()
+ usagelogDescStream := usagelogFields[22].Descriptor()
// usagelog.DefaultStream holds the default value on creation for the stream field.
usagelog.DefaultStream = usagelogDescStream.Default.(bool)
// usagelogDescUserAgent is the schema descriptor for user_agent field.
- usagelogDescUserAgent := usagelogFields[24].Descriptor()
+ usagelogDescUserAgent := usagelogFields[25].Descriptor()
// usagelog.UserAgentValidator is a validator for the "user_agent" field. It is called by the builders before save.
usagelog.UserAgentValidator = usagelogDescUserAgent.Validators[0].(func(string) error)
// usagelogDescIPAddress is the schema descriptor for ip_address field.
- usagelogDescIPAddress := usagelogFields[25].Descriptor()
+ usagelogDescIPAddress := usagelogFields[26].Descriptor()
// usagelog.IPAddressValidator is a validator for the "ip_address" field. It is called by the builders before save.
usagelog.IPAddressValidator = usagelogDescIPAddress.Validators[0].(func(string) error)
// usagelogDescImageCount is the schema descriptor for image_count field.
- usagelogDescImageCount := usagelogFields[26].Descriptor()
+ usagelogDescImageCount := usagelogFields[27].Descriptor()
// usagelog.DefaultImageCount holds the default value on creation for the image_count field.
usagelog.DefaultImageCount = usagelogDescImageCount.Default.(int)
// usagelogDescImageSize is the schema descriptor for image_size field.
- usagelogDescImageSize := usagelogFields[27].Descriptor()
+ usagelogDescImageSize := usagelogFields[28].Descriptor()
// usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save.
usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error)
// usagelogDescCreatedAt is the schema descriptor for created_at field.
- usagelogDescCreatedAt := usagelogFields[28].Descriptor()
+ usagelogDescCreatedAt := usagelogFields[29].Descriptor()
// usagelog.DefaultCreatedAt holds the default value on creation for the created_at field.
usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time)
userMixin := schema.User{}.Mixin()
diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go
index ec192a97..dd79ba96 100644
--- a/backend/ent/schema/account.go
+++ b/backend/ent/schema/account.go
@@ -102,6 +102,12 @@ func (Account) Fields() []ent.Field {
field.Int("priority").
Default(50),
+ // rate_multiplier: 账号计费倍率(>=0,允许 0 表示该账号计费为 0)
+ // 仅影响账号维度计费口径,不影响用户/API Key 扣费(分组倍率)
+ field.Float("rate_multiplier").
+ SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}).
+ Default(1.0),
+
// status: 账户状态,如 "active", "error", "disabled"
field.String("status").
MaxLen(20).
diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go
index d38925b1..5d0a1e9a 100644
--- a/backend/ent/schema/group.go
+++ b/backend/ent/schema/group.go
@@ -95,6 +95,17 @@ func (Group) Fields() []ent.Field {
Optional().
Nillable().
Comment("非 Claude Code 请求降级使用的分组 ID"),
+
+ // 模型路由配置 (added by migration 040)
+ field.JSON("model_routing", map[string][]int64{}).
+ Optional().
+ SchemaType(map[string]string{dialect.Postgres: "jsonb"}).
+ Comment("模型路由配置:模型模式 -> 优先账号ID列表"),
+
+ // 模型路由开关 (added by migration 041)
+ field.Bool("model_routing_enabled").
+ Default(false).
+ Comment("是否启用模型路由配置"),
}
}
diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go
index 264a4087..fc7c7165 100644
--- a/backend/ent/schema/usage_log.go
+++ b/backend/ent/schema/usage_log.go
@@ -85,6 +85,12 @@ func (UsageLog) Fields() []ent.Field {
Default(1).
SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
+ // account_rate_multiplier: 账号计费倍率快照(NULL 表示按 1.0 处理)
+ field.Float("account_rate_multiplier").
+ Optional().
+ Nillable().
+ SchemaType(map[string]string{dialect.Postgres: "decimal(10,4)"}),
+
// 其他字段
field.Int8("billing_type").
Default(0),
diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go
index cd576466..81c466b4 100644
--- a/backend/ent/usagelog.go
+++ b/backend/ent/usagelog.go
@@ -62,6 +62,8 @@ type UsageLog struct {
ActualCost float64 `json:"actual_cost,omitempty"`
// RateMultiplier holds the value of the "rate_multiplier" field.
RateMultiplier float64 `json:"rate_multiplier,omitempty"`
+ // AccountRateMultiplier holds the value of the "account_rate_multiplier" field.
+ AccountRateMultiplier *float64 `json:"account_rate_multiplier,omitempty"`
// BillingType holds the value of the "billing_type" field.
BillingType int8 `json:"billing_type,omitempty"`
// Stream holds the value of the "stream" field.
@@ -165,7 +167,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) {
switch columns[i] {
case usagelog.FieldStream:
values[i] = new(sql.NullBool)
- case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier:
+ case usagelog.FieldInputCost, usagelog.FieldOutputCost, usagelog.FieldCacheCreationCost, usagelog.FieldCacheReadCost, usagelog.FieldTotalCost, usagelog.FieldActualCost, usagelog.FieldRateMultiplier, usagelog.FieldAccountRateMultiplier:
values[i] = new(sql.NullFloat64)
case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount:
values[i] = new(sql.NullInt64)
@@ -316,6 +318,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error {
} else if value.Valid {
_m.RateMultiplier = value.Float64
}
+ case usagelog.FieldAccountRateMultiplier:
+ if value, ok := values[i].(*sql.NullFloat64); !ok {
+ return fmt.Errorf("unexpected type %T for field account_rate_multiplier", values[i])
+ } else if value.Valid {
+ _m.AccountRateMultiplier = new(float64)
+ *_m.AccountRateMultiplier = value.Float64
+ }
case usagelog.FieldBillingType:
if value, ok := values[i].(*sql.NullInt64); !ok {
return fmt.Errorf("unexpected type %T for field billing_type", values[i])
@@ -500,6 +509,11 @@ func (_m *UsageLog) String() string {
builder.WriteString("rate_multiplier=")
builder.WriteString(fmt.Sprintf("%v", _m.RateMultiplier))
builder.WriteString(", ")
+ if v := _m.AccountRateMultiplier; v != nil {
+ builder.WriteString("account_rate_multiplier=")
+ builder.WriteString(fmt.Sprintf("%v", *v))
+ }
+ builder.WriteString(", ")
builder.WriteString("billing_type=")
builder.WriteString(fmt.Sprintf("%v", _m.BillingType))
builder.WriteString(", ")
diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go
index c06925c4..980f1e58 100644
--- a/backend/ent/usagelog/usagelog.go
+++ b/backend/ent/usagelog/usagelog.go
@@ -54,6 +54,8 @@ const (
FieldActualCost = "actual_cost"
// FieldRateMultiplier holds the string denoting the rate_multiplier field in the database.
FieldRateMultiplier = "rate_multiplier"
+ // FieldAccountRateMultiplier holds the string denoting the account_rate_multiplier field in the database.
+ FieldAccountRateMultiplier = "account_rate_multiplier"
// FieldBillingType holds the string denoting the billing_type field in the database.
FieldBillingType = "billing_type"
// FieldStream holds the string denoting the stream field in the database.
@@ -144,6 +146,7 @@ var Columns = []string{
FieldTotalCost,
FieldActualCost,
FieldRateMultiplier,
+ FieldAccountRateMultiplier,
FieldBillingType,
FieldStream,
FieldDurationMs,
@@ -320,6 +323,11 @@ func ByRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldRateMultiplier, opts...).ToFunc()
}
+// ByAccountRateMultiplier orders the results by the account_rate_multiplier field.
+func ByAccountRateMultiplier(opts ...sql.OrderTermOption) OrderOption {
+ return sql.OrderByField(FieldAccountRateMultiplier, opts...).ToFunc()
+}
+
// ByBillingType orders the results by the billing_type field.
func ByBillingType(opts ...sql.OrderTermOption) OrderOption {
return sql.OrderByField(FieldBillingType, opts...).ToFunc()
diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go
index 96b7a19c..28e2ab4c 100644
--- a/backend/ent/usagelog/where.go
+++ b/backend/ent/usagelog/where.go
@@ -155,6 +155,11 @@ func RateMultiplier(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldRateMultiplier, v))
}
+// AccountRateMultiplier applies equality check predicate on the "account_rate_multiplier" field. It's identical to AccountRateMultiplierEQ.
+func AccountRateMultiplier(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
+}
+
// BillingType applies equality check predicate on the "billing_type" field. It's identical to BillingTypeEQ.
func BillingType(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
@@ -970,6 +975,56 @@ func RateMultiplierLTE(v float64) predicate.UsageLog {
return predicate.UsageLog(sql.FieldLTE(FieldRateMultiplier, v))
}
+// AccountRateMultiplierEQ applies the EQ predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierEQ(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldEQ(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierNEQ applies the NEQ predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierNEQ(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNEQ(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierIn applies the In predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierIn(vs ...float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIn(FieldAccountRateMultiplier, vs...))
+}
+
+// AccountRateMultiplierNotIn applies the NotIn predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierNotIn(vs ...float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotIn(FieldAccountRateMultiplier, vs...))
+}
+
+// AccountRateMultiplierGT applies the GT predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierGT(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGT(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierGTE applies the GTE predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierGTE(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldGTE(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierLT applies the LT predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierLT(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLT(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierLTE applies the LTE predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierLTE(v float64) predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldLTE(FieldAccountRateMultiplier, v))
+}
+
+// AccountRateMultiplierIsNil applies the IsNil predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierIsNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldIsNull(FieldAccountRateMultiplier))
+}
+
+// AccountRateMultiplierNotNil applies the NotNil predicate on the "account_rate_multiplier" field.
+func AccountRateMultiplierNotNil() predicate.UsageLog {
+ return predicate.UsageLog(sql.FieldNotNull(FieldAccountRateMultiplier))
+}
+
// BillingTypeEQ applies the EQ predicate on the "billing_type" field.
func BillingTypeEQ(v int8) predicate.UsageLog {
return predicate.UsageLog(sql.FieldEQ(FieldBillingType, v))
diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go
index e63fab05..a17d6507 100644
--- a/backend/ent/usagelog_create.go
+++ b/backend/ent/usagelog_create.go
@@ -267,6 +267,20 @@ func (_c *UsageLogCreate) SetNillableRateMultiplier(v *float64) *UsageLogCreate
return _c
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_c *UsageLogCreate) SetAccountRateMultiplier(v float64) *UsageLogCreate {
+ _c.mutation.SetAccountRateMultiplier(v)
+ return _c
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_c *UsageLogCreate) SetNillableAccountRateMultiplier(v *float64) *UsageLogCreate {
+ if v != nil {
+ _c.SetAccountRateMultiplier(*v)
+ }
+ return _c
+}
+
// SetBillingType sets the "billing_type" field.
func (_c *UsageLogCreate) SetBillingType(v int8) *UsageLogCreate {
_c.mutation.SetBillingType(v)
@@ -712,6 +726,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) {
_spec.SetField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
_node.RateMultiplier = value
}
+ if value, ok := _c.mutation.AccountRateMultiplier(); ok {
+ _spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+ _node.AccountRateMultiplier = &value
+ }
if value, ok := _c.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
_node.BillingType = value
@@ -1215,6 +1233,30 @@ func (u *UsageLogUpsert) AddRateMultiplier(v float64) *UsageLogUpsert {
return u
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) SetAccountRateMultiplier(v float64) *UsageLogUpsert {
+ u.Set(usagelog.FieldAccountRateMultiplier, v)
+ return u
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsert) UpdateAccountRateMultiplier() *UsageLogUpsert {
+ u.SetExcluded(usagelog.FieldAccountRateMultiplier)
+ return u
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) AddAccountRateMultiplier(v float64) *UsageLogUpsert {
+ u.Add(usagelog.FieldAccountRateMultiplier, v)
+ return u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsert) ClearAccountRateMultiplier() *UsageLogUpsert {
+ u.SetNull(usagelog.FieldAccountRateMultiplier)
+ return u
+}
+
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsert) SetBillingType(v int8) *UsageLogUpsert {
u.Set(usagelog.FieldBillingType, v)
@@ -1795,6 +1837,34 @@ func (u *UsageLogUpsertOne) UpdateRateMultiplier() *UsageLogUpsertOne {
})
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) SetAccountRateMultiplier(v float64) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetAccountRateMultiplier(v)
+ })
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) AddAccountRateMultiplier(v float64) *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.AddAccountRateMultiplier(v)
+ })
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsertOne) UpdateAccountRateMultiplier() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateAccountRateMultiplier()
+ })
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsertOne) ClearAccountRateMultiplier() *UsageLogUpsertOne {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearAccountRateMultiplier()
+ })
+}
+
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertOne) SetBillingType(v int8) *UsageLogUpsertOne {
return u.Update(func(s *UsageLogUpsert) {
@@ -2566,6 +2636,34 @@ func (u *UsageLogUpsertBulk) UpdateRateMultiplier() *UsageLogUpsertBulk {
})
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) SetAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.SetAccountRateMultiplier(v)
+ })
+}
+
+// AddAccountRateMultiplier adds v to the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) AddAccountRateMultiplier(v float64) *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.AddAccountRateMultiplier(v)
+ })
+}
+
+// UpdateAccountRateMultiplier sets the "account_rate_multiplier" field to the value that was provided on create.
+func (u *UsageLogUpsertBulk) UpdateAccountRateMultiplier() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.UpdateAccountRateMultiplier()
+ })
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (u *UsageLogUpsertBulk) ClearAccountRateMultiplier() *UsageLogUpsertBulk {
+ return u.Update(func(s *UsageLogUpsert) {
+ s.ClearAccountRateMultiplier()
+ })
+}
+
// SetBillingType sets the "billing_type" field.
func (u *UsageLogUpsertBulk) SetBillingType(v int8) *UsageLogUpsertBulk {
return u.Update(func(s *UsageLogUpsert) {
diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go
index ec2acbbb..571a7b3c 100644
--- a/backend/ent/usagelog_update.go
+++ b/backend/ent/usagelog_update.go
@@ -415,6 +415,33 @@ func (_u *UsageLogUpdate) AddRateMultiplier(v float64) *UsageLogUpdate {
return _u
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) SetAccountRateMultiplier(v float64) *UsageLogUpdate {
+ _u.mutation.ResetAccountRateMultiplier()
+ _u.mutation.SetAccountRateMultiplier(v)
+ return _u
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_u *UsageLogUpdate) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdate {
+ if v != nil {
+ _u.SetAccountRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) AddAccountRateMultiplier(v float64) *UsageLogUpdate {
+ _u.mutation.AddAccountRateMultiplier(v)
+ return _u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (_u *UsageLogUpdate) ClearAccountRateMultiplier() *UsageLogUpdate {
+ _u.mutation.ClearAccountRateMultiplier()
+ return _u
+}
+
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdate) SetBillingType(v int8) *UsageLogUpdate {
_u.mutation.ResetBillingType()
@@ -807,6 +834,15 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) {
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
+ if value, ok := _u.mutation.AccountRateMultiplier(); ok {
+ _spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
+ _spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+ }
+ if _u.mutation.AccountRateMultiplierCleared() {
+ _spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
+ }
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
@@ -1406,6 +1442,33 @@ func (_u *UsageLogUpdateOne) AddRateMultiplier(v float64) *UsageLogUpdateOne {
return _u
}
+// SetAccountRateMultiplier sets the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) SetAccountRateMultiplier(v float64) *UsageLogUpdateOne {
+ _u.mutation.ResetAccountRateMultiplier()
+ _u.mutation.SetAccountRateMultiplier(v)
+ return _u
+}
+
+// SetNillableAccountRateMultiplier sets the "account_rate_multiplier" field if the given value is not nil.
+func (_u *UsageLogUpdateOne) SetNillableAccountRateMultiplier(v *float64) *UsageLogUpdateOne {
+ if v != nil {
+ _u.SetAccountRateMultiplier(*v)
+ }
+ return _u
+}
+
+// AddAccountRateMultiplier adds value to the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) AddAccountRateMultiplier(v float64) *UsageLogUpdateOne {
+ _u.mutation.AddAccountRateMultiplier(v)
+ return _u
+}
+
+// ClearAccountRateMultiplier clears the value of the "account_rate_multiplier" field.
+func (_u *UsageLogUpdateOne) ClearAccountRateMultiplier() *UsageLogUpdateOne {
+ _u.mutation.ClearAccountRateMultiplier()
+ return _u
+}
+
// SetBillingType sets the "billing_type" field.
func (_u *UsageLogUpdateOne) SetBillingType(v int8) *UsageLogUpdateOne {
_u.mutation.ResetBillingType()
@@ -1828,6 +1891,15 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err
if value, ok := _u.mutation.AddedRateMultiplier(); ok {
_spec.AddField(usagelog.FieldRateMultiplier, field.TypeFloat64, value)
}
+ if value, ok := _u.mutation.AccountRateMultiplier(); ok {
+ _spec.SetField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+ }
+ if value, ok := _u.mutation.AddedAccountRateMultiplier(); ok {
+ _spec.AddField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64, value)
+ }
+ if _u.mutation.AccountRateMultiplierCleared() {
+ _spec.ClearField(usagelog.FieldAccountRateMultiplier, field.TypeFloat64)
+ }
if value, ok := _u.mutation.BillingType(); ok {
_spec.SetField(usagelog.FieldBillingType, field.TypeInt8, value)
}
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 944e0f84..655169cc 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -19,7 +19,9 @@ const (
RunModeSimple = "simple"
)
-const DefaultCSPPolicy = "default-src 'self'; script-src 'self' https://challenges.cloudflare.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
+// DefaultCSPPolicy is the default Content-Security-Policy with nonce support
+// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware
+const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'"
// 连接池隔离策略常量
// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗
@@ -232,6 +234,10 @@ type GatewayConfig struct {
// ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟)
// 应大于最长 LLM 请求时间,防止请求完成前槽位过期
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
+ // SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟
+ // 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能
+ // 空闲超过此时间的会话将被自动释放
+ SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"`
// StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"`
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 8a7270e5..33c91dae 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -44,6 +44,7 @@ type AccountHandler struct {
accountTestService *service.AccountTestService
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
+ sessionLimitCache service.SessionLimitCache
}
// NewAccountHandler creates a new admin account handler
@@ -58,6 +59,7 @@ func NewAccountHandler(
accountTestService *service.AccountTestService,
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
+ sessionLimitCache service.SessionLimitCache,
) *AccountHandler {
return &AccountHandler{
adminService: adminService,
@@ -70,6 +72,7 @@ func NewAccountHandler(
accountTestService: accountTestService,
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
+ sessionLimitCache: sessionLimitCache,
}
}
@@ -84,6 +87,7 @@ type CreateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
@@ -101,6 +105,7 @@ type UpdateAccountRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *int64 `json:"expires_at"`
@@ -115,6 +120,7 @@ type BulkUpdateAccountsRequest struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
+ RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
GroupIDs *[]int64 `json:"group_ids"`
@@ -127,6 +133,9 @@ type BulkUpdateAccountsRequest struct {
type AccountWithConcurrency struct {
*dto.Account
CurrentConcurrency int `json:"current_concurrency"`
+ // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
+ CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
+ ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
}
// List handles listing all accounts with pagination
@@ -161,13 +170,89 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
+ // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
+ windowCostAccountIDs := make([]int64, 0)
+ sessionLimitAccountIDs := make([]int64, 0)
+ for i := range accounts {
+ acc := &accounts[i]
+ if acc.IsAnthropicOAuthOrSetupToken() {
+ if acc.GetWindowCostLimit() > 0 {
+ windowCostAccountIDs = append(windowCostAccountIDs, acc.ID)
+ }
+ if acc.GetMaxSessions() > 0 {
+ sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
+ }
+ }
+ }
+
+ // 并行获取窗口费用和活跃会话数
+ var windowCosts map[int64]float64
+ var activeSessions map[int64]int
+
+ // 获取活跃会话数(批量查询)
+ if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
+ activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
+ if activeSessions == nil {
+ activeSessions = make(map[int64]int)
+ }
+ }
+
+ // 获取窗口费用(并行查询)
+ if len(windowCostAccountIDs) > 0 {
+ windowCosts = make(map[int64]float64)
+ var mu sync.Mutex
+ g, gctx := errgroup.WithContext(c.Request.Context())
+ g.SetLimit(10) // 限制并发数
+
+ for i := range accounts {
+ acc := &accounts[i]
+ if !acc.IsAnthropicOAuthOrSetupToken() || acc.GetWindowCostLimit() <= 0 {
+ continue
+ }
+ accCopy := acc // 闭包捕获
+ g.Go(func() error {
+ var startTime time.Time
+ if accCopy.SessionWindowStart != nil {
+ startTime = *accCopy.SessionWindowStart
+ } else {
+ startTime = time.Now().Add(-5 * time.Hour)
+ }
+ stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
+ if err == nil && stats != nil {
+ mu.Lock()
+ windowCosts[accCopy.ID] = stats.StandardCost // 使用标准费用
+ mu.Unlock()
+ }
+ return nil // 不返回错误,允许部分失败
+ })
+ }
+ _ = g.Wait()
+ }
+
// Build response with concurrency info
result := make([]AccountWithConcurrency, len(accounts))
for i := range accounts {
- result[i] = AccountWithConcurrency{
- Account: dto.AccountFromService(&accounts[i]),
- CurrentConcurrency: concurrencyCounts[accounts[i].ID],
+ acc := &accounts[i]
+ item := AccountWithConcurrency{
+ Account: dto.AccountFromService(acc),
+ CurrentConcurrency: concurrencyCounts[acc.ID],
}
+
+ // 添加窗口费用(仅当启用时)
+ if windowCosts != nil {
+ if cost, ok := windowCosts[acc.ID]; ok {
+ item.CurrentWindowCost = &cost
+ }
+ }
+
+ // 添加活跃会话数(仅当启用时)
+ if activeSessions != nil {
+ if count, ok := activeSessions[acc.ID]; ok {
+ item.ActiveSessions = &count
+ }
+ }
+
+ result[i] = item
}
response.Paginated(c, result, total, page, pageSize)
@@ -199,6 +284,10 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+ response.BadRequest(c, "rate_multiplier must be >= 0")
+ return
+ }
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -213,6 +302,7 @@ func (h *AccountHandler) Create(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
+ RateMultiplier: req.RateMultiplier,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
AutoPauseOnExpired: req.AutoPauseOnExpired,
@@ -258,6 +348,10 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+ response.BadRequest(c, "rate_multiplier must be >= 0")
+ return
+ }
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -271,6 +365,7 @@ func (h *AccountHandler) Update(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency, // 指针类型,nil 表示未提供
Priority: req.Priority, // 指针类型,nil 表示未提供
+ RateMultiplier: req.RateMultiplier,
Status: req.Status,
GroupIDs: req.GroupIDs,
ExpiresAt: req.ExpiresAt,
@@ -652,6 +747,10 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
+ if req.RateMultiplier != nil && *req.RateMultiplier < 0 {
+ response.BadRequest(c, "rate_multiplier must be >= 0")
+ return
+ }
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -660,6 +759,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.ProxyID != nil ||
req.Concurrency != nil ||
req.Priority != nil ||
+ req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
req.GroupIDs != nil ||
@@ -677,6 +777,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
+ RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
GroupIDs: req.GroupIDs,
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 9b675974..3f07403d 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -186,13 +186,16 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) {
// GetUsageTrend handles getting usage trend data
// GET /api/v1/admin/dashboard/trend
-// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id
+// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream
func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
granularity := c.DefaultQuery("granularity", "day")
// Parse optional filter params
- var userID, apiKeyID int64
+ var userID, apiKeyID, accountID, groupID int64
+ var model string
+ var stream *bool
+
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -203,8 +206,26 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
apiKeyID = id
}
}
+ if accountIDStr := c.Query("account_id"); accountIDStr != "" {
+ if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
+ accountID = id
+ }
+ }
+ if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+ if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
+ groupID = id
+ }
+ }
+ if modelStr := c.Query("model"); modelStr != "" {
+ model = modelStr
+ }
+ if streamStr := c.Query("stream"); streamStr != "" {
+ if streamVal, err := strconv.ParseBool(streamStr); err == nil {
+ stream = &streamVal
+ }
+ }
- trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID)
+ trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil {
response.Error(c, 500, "Failed to get usage trend")
return
@@ -220,12 +241,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) {
// GetModelStats handles getting model usage statistics
// GET /api/v1/admin/dashboard/models
-// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id
+// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream
func (h *DashboardHandler) GetModelStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
// Parse optional filter params
- var userID, apiKeyID int64
+ var userID, apiKeyID, accountID, groupID int64
+ var stream *bool
+
if userIDStr := c.Query("user_id"); userIDStr != "" {
if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
userID = id
@@ -236,8 +259,23 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
apiKeyID = id
}
}
+ if accountIDStr := c.Query("account_id"); accountIDStr != "" {
+ if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
+ accountID = id
+ }
+ }
+ if groupIDStr := c.Query("group_id"); groupIDStr != "" {
+ if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
+ groupID = id
+ }
+ }
+ if streamStr := c.Query("stream"); streamStr != "" {
+ if streamVal, err := strconv.ParseBool(streamStr); err == nil {
+ stream = &streamVal
+ }
+ }
- stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID)
+ stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil {
response.Error(c, 500, "Failed to get model statistics")
return
diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go
index a8bae35e..f6780dee 100644
--- a/backend/internal/handler/admin/group_handler.go
+++ b/backend/internal/handler/admin/group_handler.go
@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
+ // 模型路由配置(仅 anthropic 平台使用)
+ ModelRouting map[string][]int64 `json:"model_routing"`
+ ModelRoutingEnabled bool `json:"model_routing_enabled"`
}
// UpdateGroupRequest represents update group request
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
ImagePrice4K *float64 `json:"image_price_4k"`
ClaudeCodeOnly *bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
+ // 模型路由配置(仅 anthropic 平台使用)
+ ModelRouting map[string][]int64 `json:"model_routing"`
+ ModelRoutingEnabled *bool `json:"model_routing_enabled"`
}
// List handles listing all groups with pagination
@@ -149,20 +155,22 @@ func (h *GroupHandler) Create(c *gin.Context) {
}
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
- Name: req.Name,
- Description: req.Description,
- Platform: req.Platform,
- RateMultiplier: req.RateMultiplier,
- IsExclusive: req.IsExclusive,
- SubscriptionType: req.SubscriptionType,
- DailyLimitUSD: req.DailyLimitUSD,
- WeeklyLimitUSD: req.WeeklyLimitUSD,
- MonthlyLimitUSD: req.MonthlyLimitUSD,
- ImagePrice1K: req.ImagePrice1K,
- ImagePrice2K: req.ImagePrice2K,
- ImagePrice4K: req.ImagePrice4K,
- ClaudeCodeOnly: req.ClaudeCodeOnly,
- FallbackGroupID: req.FallbackGroupID,
+ Name: req.Name,
+ Description: req.Description,
+ Platform: req.Platform,
+ RateMultiplier: req.RateMultiplier,
+ IsExclusive: req.IsExclusive,
+ SubscriptionType: req.SubscriptionType,
+ DailyLimitUSD: req.DailyLimitUSD,
+ WeeklyLimitUSD: req.WeeklyLimitUSD,
+ MonthlyLimitUSD: req.MonthlyLimitUSD,
+ ImagePrice1K: req.ImagePrice1K,
+ ImagePrice2K: req.ImagePrice2K,
+ ImagePrice4K: req.ImagePrice4K,
+ ClaudeCodeOnly: req.ClaudeCodeOnly,
+ FallbackGroupID: req.FallbackGroupID,
+ ModelRouting: req.ModelRouting,
+ ModelRoutingEnabled: req.ModelRoutingEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
@@ -188,21 +196,23 @@ func (h *GroupHandler) Update(c *gin.Context) {
}
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
- Name: req.Name,
- Description: req.Description,
- Platform: req.Platform,
- RateMultiplier: req.RateMultiplier,
- IsExclusive: req.IsExclusive,
- Status: req.Status,
- SubscriptionType: req.SubscriptionType,
- DailyLimitUSD: req.DailyLimitUSD,
- WeeklyLimitUSD: req.WeeklyLimitUSD,
- MonthlyLimitUSD: req.MonthlyLimitUSD,
- ImagePrice1K: req.ImagePrice1K,
- ImagePrice2K: req.ImagePrice2K,
- ImagePrice4K: req.ImagePrice4K,
- ClaudeCodeOnly: req.ClaudeCodeOnly,
- FallbackGroupID: req.FallbackGroupID,
+ Name: req.Name,
+ Description: req.Description,
+ Platform: req.Platform,
+ RateMultiplier: req.RateMultiplier,
+ IsExclusive: req.IsExclusive,
+ Status: req.Status,
+ SubscriptionType: req.SubscriptionType,
+ DailyLimitUSD: req.DailyLimitUSD,
+ WeeklyLimitUSD: req.WeeklyLimitUSD,
+ MonthlyLimitUSD: req.MonthlyLimitUSD,
+ ImagePrice1K: req.ImagePrice1K,
+ ImagePrice2K: req.ImagePrice2K,
+ ImagePrice4K: req.ImagePrice4K,
+ ClaudeCodeOnly: req.ClaudeCodeOnly,
+ FallbackGroupID: req.FallbackGroupID,
+ ModelRouting: req.ModelRouting,
+ ModelRoutingEnabled: req.ModelRoutingEnabled,
})
if err != nil {
response.ErrorFrom(c, err)
diff --git a/backend/internal/handler/admin/ops_alerts_handler.go b/backend/internal/handler/admin/ops_alerts_handler.go
index 1e33ddd5..c9da19c7 100644
--- a/backend/internal/handler/admin/ops_alerts_handler.go
+++ b/backend/internal/handler/admin/ops_alerts_handler.go
@@ -7,8 +7,10 @@ import (
"net/http"
"strconv"
"strings"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/binding"
@@ -18,8 +20,6 @@ var validOpsAlertMetricTypes = []string{
"success_rate",
"error_rate",
"upstream_error_rate",
- "p95_latency_ms",
- "p99_latency_ms",
"cpu_usage_percent",
"memory_usage_percent",
"concurrency_queue_depth",
@@ -372,8 +372,135 @@ func (h *OpsHandler) DeleteAlertRule(c *gin.Context) {
response.Success(c, gin.H{"deleted": true})
}
+// GetAlertEvent returns a single ops alert event.
+// GET /api/v1/admin/ops/alert-events/:id
+func (h *OpsHandler) GetAlertEvent(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid event ID")
+ return
+ }
+
+ ev, err := h.opsService.GetAlertEventByID(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, ev)
+}
+
+// UpdateAlertEventStatus updates an ops alert event status.
+// PUT /api/v1/admin/ops/alert-events/:id/status
+func (h *OpsHandler) UpdateAlertEventStatus(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ id, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid event ID")
+ return
+ }
+
+ var payload struct {
+ Status string `json:"status"`
+ }
+ if err := c.ShouldBindJSON(&payload); err != nil {
+ response.BadRequest(c, "Invalid request body")
+ return
+ }
+ payload.Status = strings.TrimSpace(payload.Status)
+ if payload.Status == "" {
+ response.BadRequest(c, "Invalid status")
+ return
+ }
+ if payload.Status != service.OpsAlertStatusResolved && payload.Status != service.OpsAlertStatusManualResolved {
+ response.BadRequest(c, "Invalid status")
+ return
+ }
+
+ var resolvedAt *time.Time
+ if payload.Status == service.OpsAlertStatusResolved || payload.Status == service.OpsAlertStatusManualResolved {
+ now := time.Now().UTC()
+ resolvedAt = &now
+ }
+ if err := h.opsService.UpdateAlertEventStatus(c.Request.Context(), id, payload.Status, resolvedAt); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, gin.H{"updated": true})
+}
+
+// CreateAlertSilence creates a scoped silence for ops alerts.
+// POST /api/v1/admin/ops/alert-silences
+func (h *OpsHandler) CreateAlertSilence(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	var payload struct {
+		RuleID   int64   `json:"rule_id"`
+		Platform string  `json:"platform"`
+		GroupID  *int64  `json:"group_id"`
+		Region   *string `json:"region"`
+		Until    string  `json:"until"`
+		Reason   string  `json:"reason"`
+	}
+	if err := c.ShouldBindJSON(&payload); err != nil {
+		response.BadRequest(c, "Invalid request body")
+		return
+	}
+	until, err := time.Parse(time.RFC3339, strings.TrimSpace(payload.Until))
+	if err != nil {
+		response.BadRequest(c, "Invalid until")
+		return
+	}
+
+	createdBy := (*int64)(nil)
+	if subject, ok := middleware.GetAuthSubjectFromContext(c); ok {
+		uid := subject.UserID
+		createdBy = &uid
+	}
+
+	silence := &service.OpsAlertSilence{
+		RuleID:    payload.RuleID,
+		Platform:  strings.TrimSpace(payload.Platform),
+		GroupID:   payload.GroupID,
+		Region:    payload.Region,
+		Until:     until,
+		Reason:    strings.TrimSpace(payload.Reason),
+		CreatedBy: createdBy,
+	}
+
+	created, err := h.opsService.CreateAlertSilence(c.Request.Context(), silence)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, created)
+}
+
 // ListAlertEvents lists recent ops alert events.
 // GET /api/v1/admin/ops/alert-events
 func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
@@ -384,7 +511,7 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
return
}
- limit := 100
+ limit := 20
if raw := strings.TrimSpace(c.Query("limit")); raw != "" {
n, err := strconv.Atoi(raw)
if err != nil || n <= 0 {
@@ -400,6 +527,49 @@ func (h *OpsHandler) ListAlertEvents(c *gin.Context) {
Severity: strings.TrimSpace(c.Query("severity")),
}
+ if v := strings.TrimSpace(c.Query("email_sent")); v != "" {
+ vv := strings.ToLower(v)
+ switch vv {
+ case "true", "1":
+ b := true
+ filter.EmailSent = &b
+ case "false", "0":
+ b := false
+ filter.EmailSent = &b
+ default:
+ response.BadRequest(c, "Invalid email_sent")
+ return
+ }
+ }
+
+ // Cursor pagination: both params must be provided together.
+ rawTS := strings.TrimSpace(c.Query("before_fired_at"))
+ rawID := strings.TrimSpace(c.Query("before_id"))
+ if (rawTS == "") != (rawID == "") {
+ response.BadRequest(c, "before_fired_at and before_id must be provided together")
+ return
+ }
+ if rawTS != "" {
+ ts, err := time.Parse(time.RFC3339Nano, rawTS)
+ if err != nil {
+ if t2, err2 := time.Parse(time.RFC3339, rawTS); err2 == nil {
+ ts = t2
+ } else {
+ response.BadRequest(c, "Invalid before_fired_at")
+ return
+ }
+ }
+ filter.BeforeFiredAt = &ts
+ }
+ if rawID != "" {
+ id, err := strconv.ParseInt(rawID, 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid before_id")
+ return
+ }
+ filter.BeforeID = &id
+ }
+
// Optional global filter support (platform/group/time range).
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
diff --git a/backend/internal/handler/admin/ops_handler.go b/backend/internal/handler/admin/ops_handler.go
index bff7426a..44accc8f 100644
--- a/backend/internal/handler/admin/ops_handler.go
+++ b/backend/internal/handler/admin/ops_handler.go
@@ -19,6 +19,57 @@ type OpsHandler struct {
opsService *service.OpsService
}
+// GetErrorLogByID returns ops error log detail.
+// GET /api/v1/admin/ops/errors/:id
+func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ idStr := strings.TrimSpace(c.Param("id"))
+ id, err := strconv.ParseInt(idStr, 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid error id")
+ return
+ }
+
+ detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, detail)
+}
+
+const (
+ opsListViewErrors = "errors"
+ opsListViewExcluded = "excluded"
+ opsListViewAll = "all"
+)
+
+func parseOpsViewParam(c *gin.Context) string {
+ if c == nil {
+ return ""
+ }
+ v := strings.ToLower(strings.TrimSpace(c.Query("view")))
+ switch v {
+ case "", opsListViewErrors:
+ return opsListViewErrors
+ case opsListViewExcluded:
+ return opsListViewExcluded
+ case opsListViewAll:
+ return opsListViewAll
+ default:
+ return opsListViewErrors
+ }
+}
+
func NewOpsHandler(opsService *service.OpsService) *OpsHandler {
return &OpsHandler{opsService: opsService}
}
@@ -47,16 +98,26 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
return
}
- filter := &service.OpsErrorLogFilter{
- Page: page,
- PageSize: pageSize,
- }
+ filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+
if !startTime.IsZero() {
filter.StartTime = &startTime
}
if !endTime.IsZero() {
filter.EndTime = &endTime
}
+ filter.View = parseOpsViewParam(c)
+ filter.Phase = strings.TrimSpace(c.Query("phase"))
+ filter.Owner = strings.TrimSpace(c.Query("error_owner"))
+ filter.Source = strings.TrimSpace(c.Query("error_source"))
+ filter.Query = strings.TrimSpace(c.Query("q"))
+ filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
+
+	// NOTE(review): duplicated from ListRequestErrors — this silently drops a
+	// caller-supplied phase=upstream filter on the general errors endpoint; confirm intended.
+ if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
+ filter.Phase = ""
+ }
if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
filter.Platform = platform
@@ -77,11 +138,19 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
}
filter.AccountID = &id
}
- if phase := strings.TrimSpace(c.Query("phase")); phase != "" {
- filter.Phase = phase
- }
- if q := strings.TrimSpace(c.Query("q")); q != "" {
- filter.Query = q
+
+ if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+ switch strings.ToLower(v) {
+ case "1", "true", "yes":
+ b := true
+ filter.Resolved = &b
+ case "0", "false", "no":
+ b := false
+ filter.Resolved = &b
+ default:
+ response.BadRequest(c, "Invalid resolved")
+ return
+ }
}
if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
parts := strings.Split(statusCodesStr, ",")
@@ -106,13 +175,120 @@ func (h *OpsHandler) GetErrorLogs(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
-
response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
-// GetErrorLogByID returns a single error log detail.
-// GET /api/v1/admin/ops/errors/:id
-func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
+// ListRequestErrors lists client-visible request errors.
+// GET /api/v1/admin/ops/request-errors
+func (h *OpsHandler) ListRequestErrors(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ page, pageSize := response.ParsePagination(c)
+ if pageSize > 500 {
+ pageSize = 500
+ }
+ startTime, endTime, err := parseOpsTimeRange(c, "1h")
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+ if !startTime.IsZero() {
+ filter.StartTime = &startTime
+ }
+ if !endTime.IsZero() {
+ filter.EndTime = &endTime
+ }
+ filter.View = parseOpsViewParam(c)
+ filter.Phase = strings.TrimSpace(c.Query("phase"))
+ filter.Owner = strings.TrimSpace(c.Query("error_owner"))
+ filter.Source = strings.TrimSpace(c.Query("error_source"))
+ filter.Query = strings.TrimSpace(c.Query("q"))
+ filter.UserQuery = strings.TrimSpace(c.Query("user_query"))
+
+ // Force request errors: client-visible status >= 400.
+ // buildOpsErrorLogsWhere already applies this for non-upstream phase.
+ if strings.EqualFold(strings.TrimSpace(filter.Phase), "upstream") {
+ filter.Phase = ""
+ }
+
+ if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
+ filter.Platform = platform
+ }
+ if v := strings.TrimSpace(c.Query("group_id")); v != "" {
+ id, err := strconv.ParseInt(v, 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid group_id")
+ return
+ }
+ filter.GroupID = &id
+ }
+ if v := strings.TrimSpace(c.Query("account_id")); v != "" {
+ id, err := strconv.ParseInt(v, 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid account_id")
+ return
+ }
+ filter.AccountID = &id
+ }
+
+ if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+ switch strings.ToLower(v) {
+ case "1", "true", "yes":
+ b := true
+ filter.Resolved = &b
+ case "0", "false", "no":
+ b := false
+ filter.Resolved = &b
+ default:
+ response.BadRequest(c, "Invalid resolved")
+ return
+ }
+ }
+ if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
+ parts := strings.Split(statusCodesStr, ",")
+ out := make([]int, 0, len(parts))
+ for _, part := range parts {
+ p := strings.TrimSpace(part)
+ if p == "" {
+ continue
+ }
+ n, err := strconv.Atoi(p)
+ if err != nil || n < 0 {
+ response.BadRequest(c, "Invalid status_codes")
+ return
+ }
+ out = append(out, n)
+ }
+ filter.StatusCodes = out
+ }
+
+ result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
+}
+
+// GetRequestError returns request error detail.
+// GET /api/v1/admin/ops/request-errors/:id
+func (h *OpsHandler) GetRequestError(c *gin.Context) {
+ // same storage; just proxy to existing detail
+ h.GetErrorLogByID(c)
+}
+
+// ListRequestErrorUpstreamErrors lists upstream error logs correlated to a request error.
+// GET /api/v1/admin/ops/request-errors/:id/upstream-errors
+func (h *OpsHandler) ListRequestErrorUpstreamErrors(c *gin.Context) {
if h.opsService == nil {
response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
return
@@ -129,15 +305,306 @@ func (h *OpsHandler) GetErrorLogByID(c *gin.Context) {
return
}
+ // Load request error to get correlation keys.
detail, err := h.opsService.GetErrorLogByID(c.Request.Context(), id)
if err != nil {
response.ErrorFrom(c, err)
return
}
- response.Success(c, detail)
+ // Correlate by request_id/client_request_id.
+ requestID := strings.TrimSpace(detail.RequestID)
+ clientRequestID := strings.TrimSpace(detail.ClientRequestID)
+ if requestID == "" && clientRequestID == "" {
+ response.Paginated(c, []*service.OpsErrorLog{}, 0, 1, 10)
+ return
+ }
+
+ page, pageSize := response.ParsePagination(c)
+ if pageSize > 500 {
+ pageSize = 500
+ }
+
+ // Keep correlation window wide enough so linked upstream errors
+ // are discoverable even when UI defaults to 1h elsewhere.
+ startTime, endTime, err := parseOpsTimeRange(c, "30d")
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+ if !startTime.IsZero() {
+ filter.StartTime = &startTime
+ }
+ if !endTime.IsZero() {
+ filter.EndTime = &endTime
+ }
+ filter.View = "all"
+ filter.Phase = "upstream"
+ filter.Owner = "provider"
+ filter.Source = strings.TrimSpace(c.Query("error_source"))
+ filter.Query = strings.TrimSpace(c.Query("q"))
+
+ if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
+ filter.Platform = platform
+ }
+
+ // Prefer exact match on request_id; if missing, fall back to client_request_id.
+ if requestID != "" {
+ filter.RequestID = requestID
+ } else {
+ filter.ClientRequestID = clientRequestID
+ }
+
+ result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ // If client asks for details, expand each upstream error log to include upstream response fields.
+ includeDetail := strings.TrimSpace(c.Query("include_detail"))
+ if includeDetail == "1" || strings.EqualFold(includeDetail, "true") || strings.EqualFold(includeDetail, "yes") {
+ details := make([]*service.OpsErrorLogDetail, 0, len(result.Errors))
+ for _, item := range result.Errors {
+ if item == nil {
+ continue
+ }
+ d, err := h.opsService.GetErrorLogByID(c.Request.Context(), item.ID)
+ if err != nil || d == nil {
+ continue
+ }
+ details = append(details, d)
+ }
+ response.Paginated(c, details, int64(result.Total), result.Page, result.PageSize)
+ return
+ }
+
+ response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
}
+// RetryRequestErrorClient retries the client request based on stored request body.
+// POST /api/v1/admin/ops/request-errors/:id/retry-client
+//
+// Responds 503 when the ops service is not wired in, an error when
+// monitoring is disabled, 401 without an authenticated subject, and 400
+// for a missing/non-positive id.
+func (h *OpsHandler) RetryRequestErrorClient(c *gin.Context) {
+	// Ops endpoints are optional; fail fast when the service is absent.
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// The acting admin is recorded as the initiator of the retry.
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	// Client-mode retry; the nil argument pins no specific account.
+	result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeClient, nil)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, result)
+}
+
+// RetryRequestErrorUpstreamEvent retries a specific upstream attempt using captured upstream_request_body.
+// POST /api/v1/admin/ops/request-errors/:id/upstream-errors/:idx/retry
+//
+// :id is the error log id; :idx is a non-negative index selecting which
+// captured upstream attempt to replay.
+func (h *OpsHandler) RetryRequestErrorUpstreamEvent(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// Attribute the retry to the authenticated admin.
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	// idx >= 0 is accepted; bounds against the actual attempt list are
+	// enforced by the service layer.
+	idxStr := strings.TrimSpace(c.Param("idx"))
+	idx, err := strconv.Atoi(idxStr)
+	if err != nil || idx < 0 {
+		response.BadRequest(c, "Invalid upstream idx")
+		return
+	}
+
+	result, err := h.opsService.RetryUpstreamEvent(c.Request.Context(), subject.UserID, id, idx)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, result)
+}
+
+// ResolveRequestError toggles resolved status.
+// PUT /api/v1/admin/ops/request-errors/:id/resolve
+//
+// Thin alias over the shared resolution handler so request errors and
+// upstream errors expose the same resolve semantics.
+func (h *OpsHandler) ResolveRequestError(c *gin.Context) {
+	h.UpdateErrorResolution(c)
+}
+
+// ListUpstreamErrors lists independent upstream errors.
+// GET /api/v1/admin/ops/upstream-errors
+//
+// Query params: pagination (page_size capped at 500), time range
+// (default 1h), view, error_source, q, platform, group_id, account_id,
+// resolved (tri-state), status_codes (comma-separated ints).
+func (h *OpsHandler) ListUpstreamErrors(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	page, pageSize := response.ParsePagination(c)
+	// Hard cap to keep a single page bounded.
+	if pageSize > 500 {
+		pageSize = 500
+	}
+	startTime, endTime, err := parseOpsTimeRange(c, "1h")
+	if err != nil {
+		response.BadRequest(c, err.Error())
+		return
+	}
+
+	filter := &service.OpsErrorLogFilter{Page: page, PageSize: pageSize}
+	if !startTime.IsZero() {
+		filter.StartTime = &startTime
+	}
+	if !endTime.IsZero() {
+		filter.EndTime = &endTime
+	}
+
+	// Upstream errors are selected via the standardized classification:
+	// phase=upstream, owner=provider.
+	filter.View = parseOpsViewParam(c)
+	filter.Phase = "upstream"
+	filter.Owner = "provider"
+	filter.Source = strings.TrimSpace(c.Query("error_source"))
+	filter.Query = strings.TrimSpace(c.Query("q"))
+
+	if platform := strings.TrimSpace(c.Query("platform")); platform != "" {
+		filter.Platform = platform
+	}
+	// Optional numeric scoping filters; must be positive integers.
+	if v := strings.TrimSpace(c.Query("group_id")); v != "" {
+		id, err := strconv.ParseInt(v, 10, 64)
+		if err != nil || id <= 0 {
+			response.BadRequest(c, "Invalid group_id")
+			return
+		}
+		filter.GroupID = &id
+	}
+	if v := strings.TrimSpace(c.Query("account_id")); v != "" {
+		id, err := strconv.ParseInt(v, 10, 64)
+		if err != nil || id <= 0 {
+			response.BadRequest(c, "Invalid account_id")
+			return
+		}
+		filter.AccountID = &id
+	}
+
+	// Tri-state resolved filter: absent means "either"; otherwise accept
+	// common truthy/falsy spellings and reject anything else.
+	if v := strings.TrimSpace(c.Query("resolved")); v != "" {
+		switch strings.ToLower(v) {
+		case "1", "true", "yes":
+			b := true
+			filter.Resolved = &b
+		case "0", "false", "no":
+			b := false
+			filter.Resolved = &b
+		default:
+			response.BadRequest(c, "Invalid resolved")
+			return
+		}
+	}
+	// Comma-separated status codes; empty segments are skipped, any
+	// non-numeric or negative value rejects the whole request.
+	if statusCodesStr := strings.TrimSpace(c.Query("status_codes")); statusCodesStr != "" {
+		parts := strings.Split(statusCodesStr, ",")
+		out := make([]int, 0, len(parts))
+		for _, part := range parts {
+			p := strings.TrimSpace(part)
+			if p == "" {
+				continue
+			}
+			n, err := strconv.Atoi(p)
+			if err != nil || n < 0 {
+				response.BadRequest(c, "Invalid status_codes")
+				return
+			}
+			out = append(out, n)
+		}
+		filter.StatusCodes = out
+	}
+
+	result, err := h.opsService.GetErrorLogs(c.Request.Context(), filter)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Paginated(c, result.Errors, int64(result.Total), result.Page, result.PageSize)
+}
+
+// GetUpstreamError returns upstream error detail.
+// GET /api/v1/admin/ops/upstream-errors/:id
+//
+// Upstream errors share storage with generic error logs, so this
+// delegates to GetErrorLogByID.
+func (h *OpsHandler) GetUpstreamError(c *gin.Context) {
+	h.GetErrorLogByID(c)
+}
+
+// RetryUpstreamError retries upstream error using the original account_id.
+// POST /api/v1/admin/ops/upstream-errors/:id/retry
+//
+// Mirrors RetryRequestErrorClient but replays in upstream mode
+// (service.OpsRetryModeUpstream) instead of client mode.
+func (h *OpsHandler) RetryUpstreamError(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	// Attribute the retry to the authenticated admin.
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	// Upstream-mode retry; the nil argument pins no specific account.
+	result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, service.OpsRetryModeUpstream, nil)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, result)
+}
+
+// ResolveUpstreamError toggles resolved status.
+// PUT /api/v1/admin/ops/upstream-errors/:id/resolve
+//
+// Thin alias over the shared resolution handler.
+func (h *OpsHandler) ResolveUpstreamError(c *gin.Context) {
+	h.UpdateErrorResolution(c)
+}
+
+// ==================== Existing endpoints ====================
+
// ListRequestDetails returns a request-level list (success + error) for drill-down.
// GET /api/v1/admin/ops/requests
func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
@@ -242,6 +709,11 @@ func (h *OpsHandler) ListRequestDetails(c *gin.Context) {
type opsRetryRequest struct {
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
+ Force bool `json:"force"`
+}
+
+type opsResolveRequest struct {
+ Resolved bool `json:"resolved"`
}
// RetryErrorRequest retries a failed request using stored request_body.
@@ -278,6 +750,16 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
req.Mode = service.OpsRetryModeClient
}
+ // Force flag is currently a UI-level acknowledgement. Server may still enforce safety constraints.
+ _ = req.Force
+
+ // Legacy endpoint safety: only allow retrying the client request here.
+ // Upstream retries must go through the split endpoints.
+ if strings.EqualFold(strings.TrimSpace(req.Mode), service.OpsRetryModeUpstream) {
+ response.BadRequest(c, "upstream retry is not supported on this endpoint")
+ return
+ }
+
result, err := h.opsService.RetryError(c.Request.Context(), subject.UserID, id, req.Mode, req.PinnedAccountID)
if err != nil {
response.ErrorFrom(c, err)
@@ -287,6 +769,81 @@ func (h *OpsHandler) RetryErrorRequest(c *gin.Context) {
response.Success(c, result)
}
+// ListRetryAttempts lists retry attempts for an error log.
+// GET /api/v1/admin/ops/errors/:id/retries
+func (h *OpsHandler) ListRetryAttempts(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ idStr := strings.TrimSpace(c.Param("id"))
+ id, err := strconv.ParseInt(idStr, 10, 64)
+ if err != nil || id <= 0 {
+ response.BadRequest(c, "Invalid error id")
+ return
+ }
+
+ limit := 50
+ if v := strings.TrimSpace(c.Query("limit")); v != "" {
+ n, err := strconv.Atoi(v)
+ if err != nil || n <= 0 {
+ response.BadRequest(c, "Invalid limit")
+ return
+ }
+ limit = n
+ }
+
+ items, err := h.opsService.ListRetryAttemptsByErrorID(c.Request.Context(), id, limit)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+ response.Success(c, items)
+}
+
+// UpdateErrorResolution allows manual resolve/unresolve.
+// PUT /api/v1/admin/ops/errors/:id/resolve
+//
+// Body: {"resolved": bool}. The authenticated admin's user id is recorded
+// as the actor of the change.
+func (h *OpsHandler) UpdateErrorResolution(c *gin.Context) {
+	if h.opsService == nil {
+		response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+		return
+	}
+	if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	subject, ok := middleware.GetAuthSubjectFromContext(c)
+	if !ok || subject.UserID <= 0 {
+		response.Error(c, http.StatusUnauthorized, "Unauthorized")
+		return
+	}
+
+	idStr := strings.TrimSpace(c.Param("id"))
+	id, err := strconv.ParseInt(idStr, 10, 64)
+	if err != nil || id <= 0 {
+		response.BadRequest(c, "Invalid error id")
+		return
+	}
+
+	var req opsResolveRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+	// Copy into a local so a stable pointer identifies the acting admin.
+	// NOTE(review): the final nil argument's meaning is not visible here —
+	// confirm against UpdateErrorResolution's service signature.
+	uid := subject.UserID
+	if err := h.opsService.UpdateErrorResolution(c.Request.Context(), id, req.Resolved, &uid, nil); err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, gin.H{"ok": true})
+}
+
func parseOpsTimeRange(c *gin.Context, defaultRange string) (time.Time, time.Time, error) {
startStr := strings.TrimSpace(c.Query("start_time"))
endStr := strings.TrimSpace(c.Query("end_time"))
@@ -358,6 +915,10 @@ func parseOpsDuration(v string) (time.Duration, bool) {
return 6 * time.Hour, true
case "24h":
return 24 * time.Hour, true
+ case "7d":
+ return 7 * 24 * time.Hour, true
+ case "30d":
+ return 30 * 24 * time.Hour, true
default:
return 0, false
}
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index 437e9300..a6758f69 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -196,6 +196,28 @@ func (h *ProxyHandler) Delete(c *gin.Context) {
response.Success(c, gin.H{"message": "Proxy deleted successfully"})
}
+// BatchDelete handles batch deleting proxies
+// POST /api/v1/admin/proxies/batch-delete
+//
+// Body: {"ids": [int64, ...]} — at least one id is required (enforced by
+// binding tags). The service result is returned as-is on success.
+func (h *ProxyHandler) BatchDelete(c *gin.Context) {
+	// Request shape is local to this handler; binding requires a
+	// non-empty ids list.
+	type batchDeletePayload struct {
+		IDs []int64 `json:"ids" binding:"required,min=1"`
+	}
+
+	var payload batchDeletePayload
+	if err := c.ShouldBindJSON(&payload); err != nil {
+		response.BadRequest(c, "Invalid request: "+err.Error())
+		return
+	}
+
+	result, err := h.adminService.BatchDeleteProxies(c.Request.Context(), payload.IDs)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+
+	response.Success(c, result)
+}
+
// Test handles testing proxy connectivity
// POST /api/v1/admin/proxies/:id/test
func (h *ProxyHandler) Test(c *gin.Context) {
@@ -243,19 +265,17 @@ func (h *ProxyHandler) GetProxyAccounts(c *gin.Context) {
return
}
- page, pageSize := response.ParsePagination(c)
-
- accounts, total, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID, page, pageSize)
+ accounts, err := h.adminService.GetProxyAccounts(c.Request.Context(), proxyID)
if err != nil {
response.ErrorFrom(c, err)
return
}
- out := make([]dto.Account, 0, len(accounts))
+ out := make([]dto.ProxyAccountSummary, 0, len(accounts))
for i := range accounts {
- out = append(out, *dto.AccountFromService(&accounts[i]))
+ out = append(out, *dto.ProxyAccountSummaryFromService(&accounts[i]))
}
- response.Paginated(c, out, total, page, pageSize)
+ response.Success(c, out)
}
// BatchCreateProxyItem represents a single proxy in batch create request
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index 6ffaedea..f5bdd008 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -73,25 +73,27 @@ func GroupFromServiceShallow(g *service.Group) *Group {
return nil
}
return &Group{
- ID: g.ID,
- Name: g.Name,
- Description: g.Description,
- Platform: g.Platform,
- RateMultiplier: g.RateMultiplier,
- IsExclusive: g.IsExclusive,
- Status: g.Status,
- SubscriptionType: g.SubscriptionType,
- DailyLimitUSD: g.DailyLimitUSD,
- WeeklyLimitUSD: g.WeeklyLimitUSD,
- MonthlyLimitUSD: g.MonthlyLimitUSD,
- ImagePrice1K: g.ImagePrice1K,
- ImagePrice2K: g.ImagePrice2K,
- ImagePrice4K: g.ImagePrice4K,
- ClaudeCodeOnly: g.ClaudeCodeOnly,
- FallbackGroupID: g.FallbackGroupID,
- CreatedAt: g.CreatedAt,
- UpdatedAt: g.UpdatedAt,
- AccountCount: g.AccountCount,
+ ID: g.ID,
+ Name: g.Name,
+ Description: g.Description,
+ Platform: g.Platform,
+ RateMultiplier: g.RateMultiplier,
+ IsExclusive: g.IsExclusive,
+ Status: g.Status,
+ SubscriptionType: g.SubscriptionType,
+ DailyLimitUSD: g.DailyLimitUSD,
+ WeeklyLimitUSD: g.WeeklyLimitUSD,
+ MonthlyLimitUSD: g.MonthlyLimitUSD,
+ ImagePrice1K: g.ImagePrice1K,
+ ImagePrice2K: g.ImagePrice2K,
+ ImagePrice4K: g.ImagePrice4K,
+ ClaudeCodeOnly: g.ClaudeCodeOnly,
+ FallbackGroupID: g.FallbackGroupID,
+ ModelRouting: g.ModelRouting,
+ ModelRoutingEnabled: g.ModelRoutingEnabled,
+ CreatedAt: g.CreatedAt,
+ UpdatedAt: g.UpdatedAt,
+ AccountCount: g.AccountCount,
}
}
@@ -114,7 +116,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if a == nil {
return nil
}
- return &Account{
+ out := &Account{
ID: a.ID,
Name: a.Name,
Notes: a.Notes,
@@ -125,6 +127,7 @@ func AccountFromServiceShallow(a *service.Account) *Account {
ProxyID: a.ProxyID,
Concurrency: a.Concurrency,
Priority: a.Priority,
+ RateMultiplier: a.BillingRateMultiplier(),
Status: a.Status,
ErrorMessage: a.ErrorMessage,
LastUsedAt: a.LastUsedAt,
@@ -143,6 +146,24 @@ func AccountFromServiceShallow(a *service.Account) *Account {
SessionWindowStatus: a.SessionWindowStatus,
GroupIDs: a.GroupIDs,
}
+
+ // 提取 5h 窗口费用控制和会话数量控制配置(仅 Anthropic OAuth/SetupToken 账号有效)
+ if a.IsAnthropicOAuthOrSetupToken() {
+ if limit := a.GetWindowCostLimit(); limit > 0 {
+ out.WindowCostLimit = &limit
+ }
+ if reserve := a.GetWindowCostStickyReserve(); reserve > 0 {
+ out.WindowCostStickyReserve = &reserve
+ }
+ if maxSessions := a.GetMaxSessions(); maxSessions > 0 {
+ out.MaxSessions = &maxSessions
+ }
+ if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
+ out.SessionIdleTimeoutMin = &idleTimeout
+ }
+ }
+
+ return out
}
func AccountFromService(a *service.Account) *Account {
@@ -212,8 +233,29 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
return nil
}
return &ProxyWithAccountCount{
- Proxy: *ProxyFromService(&p.Proxy),
- AccountCount: p.AccountCount,
+ Proxy: *ProxyFromService(&p.Proxy),
+ AccountCount: p.AccountCount,
+ LatencyMs: p.LatencyMs,
+ LatencyStatus: p.LatencyStatus,
+ LatencyMessage: p.LatencyMessage,
+ IPAddress: p.IPAddress,
+ Country: p.Country,
+ CountryCode: p.CountryCode,
+ Region: p.Region,
+ City: p.City,
+ }
+}
+
+// ProxyAccountSummaryFromService maps a service-layer proxy account
+// summary to its DTO form. Returns nil for nil input so optional values
+// pass through safely.
+func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
+	if a == nil {
+		return nil
+	}
+	return &ProxyAccountSummary{
+		ID:       a.ID,
+		Name:     a.Name,
+		Platform: a.Platform,
+		Type:     a.Type,
+		Notes:    a.Notes,
+	}
+}
@@ -279,6 +321,7 @@ func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, inclu
TotalCost: l.TotalCost,
ActualCost: l.ActualCost,
RateMultiplier: l.RateMultiplier,
+ AccountRateMultiplier: l.AccountRateMultiplier,
BillingType: l.BillingType,
Stream: l.Stream,
DurationMs: l.DurationMs,
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index a9b010b9..4519143c 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -58,6 +58,10 @@ type Group struct {
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id"`
+ // 模型路由配置(仅 anthropic 平台使用)
+ ModelRouting map[string][]int64 `json:"model_routing"`
+ ModelRoutingEnabled bool `json:"model_routing_enabled"`
+
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
@@ -76,6 +80,7 @@ type Account struct {
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
+ RateMultiplier float64 `json:"rate_multiplier"`
Status string `json:"status"`
ErrorMessage string `json:"error_message"`
LastUsedAt *time.Time `json:"last_used_at"`
@@ -97,6 +102,16 @@ type Account struct {
SessionWindowEnd *time.Time `json:"session_window_end"`
SessionWindowStatus string `json:"session_window_status"`
+ // 5h窗口费用控制(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ WindowCostLimit *float64 `json:"window_cost_limit,omitempty"`
+ WindowCostStickyReserve *float64 `json:"window_cost_sticky_reserve,omitempty"`
+
+ // 会话数量控制(仅 Anthropic OAuth/SetupToken 账号有效)
+ // 从 extra 字段提取,方便前端显示和编辑
+ MaxSessions *int `json:"max_sessions,omitempty"`
+ SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
+
Proxy *Proxy `json:"proxy,omitempty"`
AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -129,7 +144,23 @@ type Proxy struct {
type ProxyWithAccountCount struct {
Proxy
- AccountCount int64 `json:"account_count"`
+ AccountCount int64 `json:"account_count"`
+ LatencyMs *int64 `json:"latency_ms,omitempty"`
+ LatencyStatus string `json:"latency_status,omitempty"`
+ LatencyMessage string `json:"latency_message,omitempty"`
+ IPAddress string `json:"ip_address,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+}
+
+// ProxyAccountSummary is a lightweight account view returned by the
+// proxy-accounts endpoint: identity fields plus optional notes, without
+// the full Account payload.
+type ProxyAccountSummary struct {
+	ID       int64   `json:"id"`
+	Name     string  `json:"name"`
+	Platform string  `json:"platform"`
+	Type     string  `json:"type"`
+	Notes    *string `json:"notes,omitempty"`
+}
type RedeemCode struct {
@@ -169,13 +200,14 @@ type UsageLog struct {
CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
- InputCost float64 `json:"input_cost"`
- OutputCost float64 `json:"output_cost"`
- CacheCreationCost float64 `json:"cache_creation_cost"`
- CacheReadCost float64 `json:"cache_read_cost"`
- TotalCost float64 `json:"total_cost"`
- ActualCost float64 `json:"actual_cost"`
- RateMultiplier float64 `json:"rate_multiplier"`
+ InputCost float64 `json:"input_cost"`
+ OutputCost float64 `json:"output_cost"`
+ CacheCreationCost float64 `json:"cache_creation_cost"`
+ CacheReadCost float64 `json:"cache_read_cost"`
+ TotalCost float64 `json:"total_cost"`
+ ActualCost float64 `json:"actual_cost"`
+ RateMultiplier float64 `json:"rate_multiplier"`
+ AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
BillingType int8 `json:"billing_type"`
Stream bool `json:"stream"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 91d590bf..7605805a 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -185,7 +185,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
lastFailoverStatus := 0
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
@@ -320,7 +320,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
for {
// 选择支持该模型的账号
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID)
if err != nil {
if len(failedAccountIDs) == 0 {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted)
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index 2dddb856..ec943e61 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -226,7 +226,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
lastFailoverStatus := 0
for {
- selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs)
+ selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制
if err != nil {
if len(failedAccountIDs) == 0 {
googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error())
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index 13bd9d94..f62e6b3e 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -544,6 +544,11 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
body := w.buf.Bytes()
parsed := parseOpsErrorResponse(body)
+ // Skip logging if the error should be filtered based on settings
+ if shouldSkipOpsErrorLog(c.Request.Context(), ops, parsed.Message, string(body), c.Request.URL.Path) {
+ return
+ }
+
apiKey, _ := middleware2.GetAPIKeyFromContext(c)
clientRequestID, _ := c.Request.Context().Value(ctxkey.ClientRequestID).(string)
@@ -832,28 +837,30 @@ func normalizeOpsErrorType(errType string, code string) string {
func classifyOpsPhase(errType, message, code string) string {
msg := strings.ToLower(message)
+ // Standardized phases: request|auth|routing|upstream|network|internal
+ // Map billing/concurrency/response => request; scheduling => routing.
switch strings.TrimSpace(code) {
case "INSUFFICIENT_BALANCE", "USAGE_LIMIT_EXCEEDED", "SUBSCRIPTION_NOT_FOUND", "SUBSCRIPTION_INVALID":
- return "billing"
+ return "request"
}
switch errType {
case "authentication_error":
return "auth"
case "billing_error", "subscription_error":
- return "billing"
+ return "request"
case "rate_limit_error":
if strings.Contains(msg, "concurrency") || strings.Contains(msg, "pending") || strings.Contains(msg, "queue") {
- return "concurrency"
+ return "request"
}
return "upstream"
case "invalid_request_error":
- return "response"
+ return "request"
case "upstream_error", "overloaded_error":
return "upstream"
case "api_error":
if strings.Contains(msg, "no available accounts") {
- return "scheduling"
+ return "routing"
}
return "internal"
default:
@@ -914,34 +921,38 @@ func classifyOpsIsBusinessLimited(errType, phase, code string, status int, messa
}
func classifyOpsErrorOwner(phase string, message string) string {
+ // Standardized owners: client|provider|platform
switch phase {
case "upstream", "network":
return "provider"
- case "billing", "concurrency", "auth", "response":
+ case "request", "auth":
return "client"
+ case "routing", "internal":
+ return "platform"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "provider"
}
- return "sub2api"
+ return "platform"
}
}
func classifyOpsErrorSource(phase string, message string) string {
+ // Standardized sources: client_request|upstream_http|gateway
switch phase {
case "upstream":
return "upstream_http"
case "network":
- return "upstream_network"
- case "billing":
- return "billing"
- case "concurrency":
- return "concurrency"
+ return "gateway"
+ case "request", "auth":
+ return "client_request"
+ case "routing", "internal":
+ return "gateway"
default:
if strings.Contains(strings.ToLower(message), "upstream") {
return "upstream_http"
}
- return "internal"
+ return "gateway"
}
}
@@ -963,3 +974,42 @@ func truncateString(s string, max int) string {
func strconvItoa(v int) string {
return strconv.Itoa(v)
}
+
+// shouldSkipOpsErrorLog determines if an error should be skipped from logging based on settings.
+// Returns true for errors that should be filtered according to OpsAdvancedSettings.
+//
+// Filters (each opt-in via settings):
+//   - IgnoreCountTokensErrors: request path contains "/count_tokens"
+//   - IgnoreContextCanceled: message/body mentions "context canceled"
+//   - IgnoreNoAvailableAccounts: message/body mentions "no available accounts"
+func shouldSkipOpsErrorLog(ctx context.Context, ops *service.OpsService, message, body, requestPath string) bool {
+	if ops == nil {
+		return false
+	}
+
+	// Get advanced settings to check filter configuration.
+	settings, err := ops.GetOpsAdvancedSettings(ctx)
+	if err != nil || settings == nil {
+		// If we can't get settings, don't skip (fail open).
+		return false
+	}
+
+	// Cheap path-based filter first: count_tokens errors.
+	if settings.IgnoreCountTokensErrors && strings.Contains(requestPath, "/count_tokens") {
+		return true
+	}
+
+	// Only lowercase when a substring filter is actually enabled: the
+	// response body can be large and this runs on every logged error.
+	if !settings.IgnoreContextCanceled && !settings.IgnoreNoAvailableAccounts {
+		return false
+	}
+	msgLower := strings.ToLower(message)
+	bodyLower := strings.ToLower(body)
+
+	// Client disconnects surface as "context canceled".
+	if settings.IgnoreContextCanceled &&
+		(strings.Contains(msgLower, "context canceled") || strings.Contains(bodyLower, "context canceled")) {
+		return true
+	}
+
+	// Scheduling exhaustion ("no available accounts") noise.
+	if settings.IgnoreNoAvailableAccounts &&
+		(strings.Contains(msgLower, "no available accounts") || strings.Contains(bodyLower, "no available accounts")) {
+		return true
+	}
+
+	return false
+}
diff --git a/backend/internal/pkg/usagestats/account_stats.go b/backend/internal/pkg/usagestats/account_stats.go
index ed77dd27..9ac49625 100644
--- a/backend/internal/pkg/usagestats/account_stats.go
+++ b/backend/internal/pkg/usagestats/account_stats.go
@@ -1,8 +1,14 @@
package usagestats
// AccountStats 账号使用统计
+//
+// cost: 账号口径费用(使用 total_cost * account_rate_multiplier)
+// standard_cost: 标准费用(使用 total_cost,不含倍率)
+// user_cost: 用户/API Key 口径费用(使用 actual_cost,受分组倍率影响)
type AccountStats struct {
- Requests int64 `json:"requests"`
- Tokens int64 `json:"tokens"`
- Cost float64 `json:"cost"`
+ Requests int64 `json:"requests"`
+ Tokens int64 `json:"tokens"`
+ Cost float64 `json:"cost"`
+ StandardCost float64 `json:"standard_cost"`
+ UserCost float64 `json:"user_cost"`
}
diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go
index 3952785b..2f6c7fe0 100644
--- a/backend/internal/pkg/usagestats/usage_log_types.go
+++ b/backend/internal/pkg/usagestats/usage_log_types.go
@@ -147,14 +147,15 @@ type UsageLogFilters struct {
// UsageStats represents usage statistics
type UsageStats struct {
- TotalRequests int64 `json:"total_requests"`
- TotalInputTokens int64 `json:"total_input_tokens"`
- TotalOutputTokens int64 `json:"total_output_tokens"`
- TotalCacheTokens int64 `json:"total_cache_tokens"`
- TotalTokens int64 `json:"total_tokens"`
- TotalCost float64 `json:"total_cost"`
- TotalActualCost float64 `json:"total_actual_cost"`
- AverageDurationMs float64 `json:"average_duration_ms"`
+ TotalRequests int64 `json:"total_requests"`
+ TotalInputTokens int64 `json:"total_input_tokens"`
+ TotalOutputTokens int64 `json:"total_output_tokens"`
+ TotalCacheTokens int64 `json:"total_cache_tokens"`
+ TotalTokens int64 `json:"total_tokens"`
+ TotalCost float64 `json:"total_cost"`
+ TotalActualCost float64 `json:"total_actual_cost"`
+ TotalAccountCost *float64 `json:"total_account_cost,omitempty"`
+ AverageDurationMs float64 `json:"average_duration_ms"`
}
// BatchUserUsageStats represents usage stats for a single user
@@ -177,25 +178,29 @@ type AccountUsageHistory struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
- Cost float64 `json:"cost"`
- ActualCost float64 `json:"actual_cost"`
+ Cost float64 `json:"cost"` // 标准计费(total_cost)
+ ActualCost float64 `json:"actual_cost"` // 账号口径费用(total_cost * account_rate_multiplier)
+ UserCost float64 `json:"user_cost"` // 用户口径费用(actual_cost,受分组倍率影响)
}
// AccountUsageSummary represents summary statistics for an account
type AccountUsageSummary struct {
Days int `json:"days"`
ActualDaysUsed int `json:"actual_days_used"`
- TotalCost float64 `json:"total_cost"`
+ TotalCost float64 `json:"total_cost"` // 账号口径费用
+ TotalUserCost float64 `json:"total_user_cost"` // 用户口径费用
TotalStandardCost float64 `json:"total_standard_cost"`
TotalRequests int64 `json:"total_requests"`
TotalTokens int64 `json:"total_tokens"`
- AvgDailyCost float64 `json:"avg_daily_cost"`
+ AvgDailyCost float64 `json:"avg_daily_cost"` // 账号口径日均
+ AvgDailyUserCost float64 `json:"avg_daily_user_cost"`
AvgDailyRequests float64 `json:"avg_daily_requests"`
AvgDailyTokens float64 `json:"avg_daily_tokens"`
AvgDurationMs float64 `json:"avg_duration_ms"`
Today *struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
} `json:"today"`
@@ -203,6 +208,7 @@ type AccountUsageSummary struct {
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
} `json:"highest_cost_day"`
HighestRequestDay *struct {
@@ -210,6 +216,7 @@ type AccountUsageSummary struct {
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
} `json:"highest_request_day"`
}
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index aaa89f21..f7725820 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -80,6 +80,10 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
+ if account.RateMultiplier != nil {
+ builder.SetRateMultiplier(*account.RateMultiplier)
+ }
+
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
}
@@ -291,6 +295,10 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account
SetSchedulable(account.Schedulable).
SetAutoPauseOnExpired(account.AutoPauseOnExpired)
+ if account.RateMultiplier != nil {
+ builder.SetRateMultiplier(*account.RateMultiplier)
+ }
+
if account.ProxyID != nil {
builder.SetProxyID(*account.ProxyID)
} else {
@@ -786,6 +794,46 @@ func (r *accountRepository) SetAntigravityQuotaScopeLimit(ctx context.Context, i
return nil
}
+func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+ if scope == "" {
+ return nil
+ }
+ now := time.Now().UTC()
+ payload := map[string]string{
+ "rate_limited_at": now.Format(time.RFC3339),
+ "rate_limit_reset_at": resetAt.UTC().Format(time.RFC3339),
+ }
+ raw, err := json.Marshal(payload)
+ if err != nil {
+ return err
+ }
+
+ path := "{model_rate_limits," + scope + "}"
+ client := clientFromContext(ctx, r.client)
+ result, err := client.ExecContext(
+ ctx,
+ "UPDATE accounts SET extra = jsonb_set(COALESCE(extra, '{}'::jsonb), $1::text[], $2::jsonb, true), updated_at = NOW() WHERE id = $3 AND deleted_at IS NULL",
+ path,
+ raw,
+ id,
+ )
+ if err != nil {
+ return err
+ }
+
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrAccountNotFound
+ }
+ if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+ log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err)
+ }
+ return nil
+}
+
func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
_, err := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -877,6 +925,30 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id
return nil
}
+func (r *accountRepository) ClearModelRateLimits(ctx context.Context, id int64) error {
+ client := clientFromContext(ctx, r.client)
+ result, err := client.ExecContext(
+ ctx,
+ "UPDATE accounts SET extra = COALESCE(extra, '{}'::jsonb) - 'model_rate_limits', updated_at = NOW() WHERE id = $1 AND deleted_at IS NULL",
+ id,
+ )
+ if err != nil {
+ return err
+ }
+
+ affected, err := result.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrAccountNotFound
+ }
+ if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil {
+ log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err)
+ }
+ return nil
+}
+
func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
builder := r.client.Account.Update().
Where(dbaccount.IDEQ(id)).
@@ -999,6 +1071,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Priority)
idx++
}
+ if updates.RateMultiplier != nil {
+ setClauses = append(setClauses, "rate_multiplier = $"+itoa(idx))
+ args = append(args, *updates.RateMultiplier)
+ idx++
+ }
if updates.Status != nil {
setClauses = append(setClauses, "status = $"+itoa(idx))
args = append(args, *updates.Status)
@@ -1347,6 +1424,8 @@ func accountEntityToService(m *dbent.Account) *service.Account {
return nil
}
+ rateMultiplier := m.RateMultiplier
+
return &service.Account{
ID: m.ID,
Name: m.Name,
@@ -1358,6 +1437,7 @@ func accountEntityToService(m *dbent.Account) *service.Account {
ProxyID: m.ProxyID,
Concurrency: m.Concurrency,
Priority: m.Priority,
+ RateMultiplier: &rateMultiplier,
Status: m.Status,
ErrorMessage: derefString(m.ErrorMessage),
LastUsedAt: m.LastUsedAt,
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index 77a3f233..ab890844 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
group.FieldImagePrice4k,
group.FieldClaudeCodeOnly,
group.FieldFallbackGroupID,
+ group.FieldModelRoutingEnabled,
+ group.FieldModelRouting,
)
}).
Only(ctx)
@@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
DefaultValidityDays: g.DefaultValidityDays,
ClaudeCodeOnly: g.ClaudeCodeOnly,
FallbackGroupID: g.FallbackGroupID,
+ ModelRouting: g.ModelRouting,
+ ModelRoutingEnabled: g.ModelRoutingEnabled,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
}
diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go
index d238e320..3543e061 100644
--- a/backend/internal/repository/dashboard_aggregation_repo.go
+++ b/backend/internal/repository/dashboard_aggregation_repo.go
@@ -8,6 +8,7 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
@@ -41,21 +42,22 @@ func isPostgresDriver(db *sql.DB) bool {
}
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
- startUTC := start.UTC()
- endUTC := end.UTC()
- if !endUTC.After(startUTC) {
+ loc := timezone.Location()
+ startLocal := start.In(loc)
+ endLocal := end.In(loc)
+ if !endLocal.After(startLocal) {
return nil
}
- hourStart := startUTC.Truncate(time.Hour)
- hourEnd := endUTC.Truncate(time.Hour)
- if endUTC.After(hourEnd) {
+ hourStart := startLocal.Truncate(time.Hour)
+ hourEnd := endLocal.Truncate(time.Hour)
+ if endLocal.After(hourEnd) {
hourEnd = hourEnd.Add(time.Hour)
}
- dayStart := truncateToDayUTC(startUTC)
- dayEnd := truncateToDayUTC(endUTC)
- if endUTC.After(dayEnd) {
+ dayStart := truncateToDay(startLocal)
+ dayEnd := truncateToDay(endLocal)
+ if endLocal.After(dayEnd) {
dayEnd = dayEnd.Add(24 * time.Hour)
}
@@ -146,38 +148,41 @@ func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.C
}
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
+ tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
SELECT DISTINCT
- date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
+ date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
user_id
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
ON CONFLICT DO NOTHING
`
- _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+ _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
+ tzName := timezone.Name()
query := `
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
SELECT DISTINCT
- (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
+ (bucket_start AT TIME ZONE $3)::date AS bucket_date,
user_id
FROM usage_dashboard_hourly_users
WHERE bucket_start >= $1 AND bucket_start < $2
ON CONFLICT DO NOTHING
`
- _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+ _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
+ tzName := timezone.Name()
query := `
WITH hourly AS (
SELECT
- date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
+ date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
COUNT(*) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -236,15 +241,16 @@ func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Cont
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
- _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC())
+ _, err := r.sql.ExecContext(ctx, query, start, end, tzName)
return err
}
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
+ tzName := timezone.Name()
query := `
WITH daily AS (
SELECT
- (bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
+ (bucket_start AT TIME ZONE $5)::date AS bucket_date,
COALESCE(SUM(total_requests), 0) AS total_requests,
COALESCE(SUM(input_tokens), 0) AS input_tokens,
COALESCE(SUM(output_tokens), 0) AS output_tokens,
@@ -255,7 +261,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
FROM usage_dashboard_hourly
WHERE bucket_start >= $1 AND bucket_start < $2
- GROUP BY (bucket_start AT TIME ZONE 'UTC')::date
+ GROUP BY (bucket_start AT TIME ZONE $5)::date
),
user_counts AS (
SELECT bucket_date, COUNT(*) AS active_users
@@ -303,7 +309,7 @@ func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Conte
active_users = EXCLUDED.active_users,
computed_at = EXCLUDED.computed_at
`
- _, err := r.sql.ExecContext(ctx, query, start.UTC(), end.UTC(), start.UTC(), end.UTC())
+ _, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
return err
}
@@ -376,9 +382,8 @@ func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Co
return err
}
-func truncateToDayUTC(t time.Time) time.Time {
- t = t.UTC()
- return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
+func truncateToDay(t time.Time) time.Time {
+ return timezone.StartOfDay(t)
}
func truncateToMonthUTC(t time.Time) time.Time {
diff --git a/backend/internal/repository/gemini_token_cache.go b/backend/internal/repository/gemini_token_cache.go
index a7270556..d4f552bc 100644
--- a/backend/internal/repository/gemini_token_cache.go
+++ b/backend/internal/repository/gemini_token_cache.go
@@ -11,8 +11,8 @@ import (
)
const (
- geminiTokenKeyPrefix = "gemini:token:"
- geminiRefreshLockKeyPrefix = "gemini:refresh_lock:"
+ oauthTokenKeyPrefix = "oauth:token:"
+ oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
)
type geminiTokenCache struct {
@@ -24,21 +24,26 @@ func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
}
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
- key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
+ key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Get(ctx, key).Result()
}
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
- key := fmt.Sprintf("%s%s", geminiTokenKeyPrefix, cacheKey)
+ key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
return c.rdb.Set(ctx, key, token, ttl).Err()
}
+func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
+ key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
+ return c.rdb.Del(ctx, key).Err()
+}
+
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
- key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
+ key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
}
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
- key := fmt.Sprintf("%s%s", geminiRefreshLockKeyPrefix, cacheKey)
+ key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
return c.rdb.Del(ctx, key).Err()
}
diff --git a/backend/internal/repository/gemini_token_cache_integration_test.go b/backend/internal/repository/gemini_token_cache_integration_test.go
new file mode 100644
index 00000000..4fe89865
--- /dev/null
+++ b/backend/internal/repository/gemini_token_cache_integration_test.go
@@ -0,0 +1,47 @@
+//go:build integration
+
+package repository
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
+)
+
+type GeminiTokenCacheSuite struct {
+ IntegrationRedisSuite
+ cache service.GeminiTokenCache
+}
+
+func (s *GeminiTokenCacheSuite) SetupTest() {
+ s.IntegrationRedisSuite.SetupTest()
+ s.cache = NewGeminiTokenCache(s.rdb)
+}
+
+func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
+ cacheKey := "project-123"
+ token := "token-value"
+ require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
+
+ got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), token, got)
+
+ require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
+
+ _, err = s.cache.GetAccessToken(s.ctx, cacheKey)
+ require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
+}
+
+func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
+ require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
+}
+
+func TestGeminiTokenCacheSuite(t *testing.T) {
+ suite.Run(t, new(GeminiTokenCacheSuite))
+}
diff --git a/backend/internal/repository/gemini_token_cache_test.go b/backend/internal/repository/gemini_token_cache_test.go
new file mode 100644
index 00000000..4fcebfdd
--- /dev/null
+++ b/backend/internal/repository/gemini_token_cache_test.go
@@ -0,0 +1,28 @@
+//go:build unit
+
+package repository
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/redis/go-redis/v9"
+ "github.com/stretchr/testify/require"
+)
+
+func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
+ rdb := redis.NewClient(&redis.Options{
+ Addr: "127.0.0.1:1",
+ DialTimeout: 50 * time.Millisecond,
+ ReadTimeout: 50 * time.Millisecond,
+ WriteTimeout: 50 * time.Millisecond,
+ })
+ t.Cleanup(func() {
+ _ = rdb.Close()
+ })
+
+ cache := NewGeminiTokenCache(rdb)
+ err := cache.DeleteAccessToken(context.Background(), "broken")
+ require.Error(t, err)
+}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 9f3c1a57..5c4d6cf4 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
- SetNillableFallbackGroupID(groupIn.FallbackGroupID)
+ SetNillableFallbackGroupID(groupIn.FallbackGroupID).
+ SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
+
+ // 设置模型路由配置
+ if groupIn.ModelRouting != nil {
+ builder = builder.SetModelRouting(groupIn.ModelRouting)
+ }
created, err := builder.Save(ctx)
if err == nil {
@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetNillableImagePrice2k(groupIn.ImagePrice2K).
SetNillableImagePrice4k(groupIn.ImagePrice4K).
SetDefaultValidityDays(groupIn.DefaultValidityDays).
- SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
+ SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
+ SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
builder = builder.ClearFallbackGroupID()
}
+ // 处理 ModelRouting:nil 时清除,否则设置
+ if groupIn.ModelRouting != nil {
+ builder = builder.SetModelRouting(groupIn.ModelRouting)
+ } else {
+ builder = builder.ClearModelRouting()
+ }
+
updated, err := builder.Save(ctx)
if err != nil {
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go
index f9cb6b4d..613c5bd5 100644
--- a/backend/internal/repository/ops_repo.go
+++ b/backend/internal/repository/ops_repo.go
@@ -55,7 +55,6 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
- duration_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
@@ -65,7 +64,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
- $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35
+ $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
) RETURNING id`
var id int64
@@ -98,7 +97,6 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
- opsNullInt(input.DurationMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,
@@ -135,7 +133,7 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
}
where, args := buildOpsErrorLogsWhere(filter)
- countSQL := "SELECT COUNT(*) FROM ops_error_logs " + where
+ countSQL := "SELECT COUNT(*) FROM ops_error_logs e " + where
var total int
if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil {
@@ -146,28 +144,43 @@ func (r *opsRepository) ListErrorLogs(ctx context.Context, filter *service.OpsEr
argsWithLimit := append(args, pageSize, offset)
selectSQL := `
SELECT
- id,
- created_at,
- error_phase,
- error_type,
- severity,
- COALESCE(upstream_status_code, status_code, 0),
- COALESCE(platform, ''),
- COALESCE(model, ''),
- duration_ms,
- COALESCE(client_request_id, ''),
- COALESCE(request_id, ''),
- COALESCE(error_message, ''),
- user_id,
- api_key_id,
- account_id,
- group_id,
- CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
- COALESCE(request_path, ''),
- stream
-FROM ops_error_logs
+ e.id,
+ e.created_at,
+ e.error_phase,
+ e.error_type,
+ COALESCE(e.error_owner, ''),
+ COALESCE(e.error_source, ''),
+ e.severity,
+ COALESCE(e.upstream_status_code, e.status_code, 0),
+ COALESCE(e.platform, ''),
+ COALESCE(e.model, ''),
+ COALESCE(e.is_retryable, false),
+ COALESCE(e.retry_count, 0),
+ COALESCE(e.resolved, false),
+ e.resolved_at,
+ e.resolved_by_user_id,
+ COALESCE(u2.email, ''),
+ e.resolved_retry_id,
+ COALESCE(e.client_request_id, ''),
+ COALESCE(e.request_id, ''),
+ COALESCE(e.error_message, ''),
+ e.user_id,
+ COALESCE(u.email, ''),
+ e.api_key_id,
+ e.account_id,
+ COALESCE(a.name, ''),
+ e.group_id,
+ COALESCE(g.name, ''),
+ CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
+ COALESCE(e.request_path, ''),
+ e.stream
+FROM ops_error_logs e
+LEFT JOIN accounts a ON e.account_id = a.id
+LEFT JOIN groups g ON e.group_id = g.id
+LEFT JOIN users u ON e.user_id = u.id
+LEFT JOIN users u2 ON e.resolved_by_user_id = u2.id
` + where + `
-ORDER BY created_at DESC
+ORDER BY e.created_at DESC
LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
rows, err := r.db.QueryContext(ctx, selectSQL, argsWithLimit...)
@@ -179,39 +192,65 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
out := make([]*service.OpsErrorLog, 0, pageSize)
for rows.Next() {
var item service.OpsErrorLog
- var latency sql.NullInt64
var statusCode sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
var accountID sql.NullInt64
+ var accountName string
var groupID sql.NullInt64
+ var groupName string
+ var userEmail string
+ var resolvedAt sql.NullTime
+ var resolvedBy sql.NullInt64
+ var resolvedByName string
+ var resolvedRetryID sql.NullInt64
if err := rows.Scan(
&item.ID,
&item.CreatedAt,
&item.Phase,
&item.Type,
+ &item.Owner,
+ &item.Source,
&item.Severity,
&statusCode,
&item.Platform,
&item.Model,
- &latency,
+ &item.IsRetryable,
+ &item.RetryCount,
+ &item.Resolved,
+ &resolvedAt,
+ &resolvedBy,
+ &resolvedByName,
+ &resolvedRetryID,
&item.ClientRequestID,
&item.RequestID,
&item.Message,
&userID,
+ &userEmail,
&apiKeyID,
&accountID,
+ &accountName,
&groupID,
+ &groupName,
&clientIP,
&item.RequestPath,
&item.Stream,
); err != nil {
return nil, err
}
- if latency.Valid {
- v := int(latency.Int64)
- item.LatencyMs = &v
+ if resolvedAt.Valid {
+ t := resolvedAt.Time
+ item.ResolvedAt = &t
+ }
+ if resolvedBy.Valid {
+ v := resolvedBy.Int64
+ item.ResolvedByUserID = &v
+ }
+ item.ResolvedByUserName = resolvedByName
+ if resolvedRetryID.Valid {
+ v := resolvedRetryID.Int64
+ item.ResolvedRetryID = &v
}
item.StatusCode = int(statusCode.Int64)
if clientIP.Valid {
@@ -222,6 +261,7 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := userID.Int64
item.UserID = &v
}
+ item.UserEmail = userEmail
if apiKeyID.Valid {
v := apiKeyID.Int64
item.APIKeyID = &v
@@ -230,10 +270,12 @@ LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2)
v := accountID.Int64
item.AccountID = &v
}
+ item.AccountName = accountName
if groupID.Valid {
v := groupID.Int64
item.GroupID = &v
}
+ item.GroupName = groupName
out = append(out, &item)
}
if err := rows.Err(); err != nil {
@@ -258,49 +300,64 @@ func (r *opsRepository) GetErrorLogByID(ctx context.Context, id int64) (*service
q := `
SELECT
- id,
- created_at,
- error_phase,
- error_type,
- severity,
- COALESCE(upstream_status_code, status_code, 0),
- COALESCE(platform, ''),
- COALESCE(model, ''),
- duration_ms,
- COALESCE(client_request_id, ''),
- COALESCE(request_id, ''),
- COALESCE(error_message, ''),
- COALESCE(error_body, ''),
- upstream_status_code,
- COALESCE(upstream_error_message, ''),
- COALESCE(upstream_error_detail, ''),
- COALESCE(upstream_errors::text, ''),
- is_business_limited,
- user_id,
- api_key_id,
- account_id,
- group_id,
- CASE WHEN client_ip IS NULL THEN NULL ELSE client_ip::text END,
- COALESCE(request_path, ''),
- stream,
- COALESCE(user_agent, ''),
- auth_latency_ms,
- routing_latency_ms,
- upstream_latency_ms,
- response_latency_ms,
- time_to_first_token_ms,
- COALESCE(request_body::text, ''),
- request_body_truncated,
- request_body_bytes,
- COALESCE(request_headers::text, '')
-FROM ops_error_logs
-WHERE id = $1
+ e.id,
+ e.created_at,
+ e.error_phase,
+ e.error_type,
+ COALESCE(e.error_owner, ''),
+ COALESCE(e.error_source, ''),
+ e.severity,
+ COALESCE(e.upstream_status_code, e.status_code, 0),
+ COALESCE(e.platform, ''),
+ COALESCE(e.model, ''),
+ COALESCE(e.is_retryable, false),
+ COALESCE(e.retry_count, 0),
+ COALESCE(e.resolved, false),
+ e.resolved_at,
+ e.resolved_by_user_id,
+ e.resolved_retry_id,
+ COALESCE(e.client_request_id, ''),
+ COALESCE(e.request_id, ''),
+ COALESCE(e.error_message, ''),
+ COALESCE(e.error_body, ''),
+ e.upstream_status_code,
+ COALESCE(e.upstream_error_message, ''),
+ COALESCE(e.upstream_error_detail, ''),
+ COALESCE(e.upstream_errors::text, ''),
+ e.is_business_limited,
+ e.user_id,
+ COALESCE(u.email, ''),
+ e.api_key_id,
+ e.account_id,
+ COALESCE(a.name, ''),
+ e.group_id,
+ COALESCE(g.name, ''),
+ CASE WHEN e.client_ip IS NULL THEN NULL ELSE e.client_ip::text END,
+ COALESCE(e.request_path, ''),
+ e.stream,
+ COALESCE(e.user_agent, ''),
+ e.auth_latency_ms,
+ e.routing_latency_ms,
+ e.upstream_latency_ms,
+ e.response_latency_ms,
+ e.time_to_first_token_ms,
+ COALESCE(e.request_body::text, ''),
+ e.request_body_truncated,
+ e.request_body_bytes,
+ COALESCE(e.request_headers::text, '')
+FROM ops_error_logs e
+LEFT JOIN users u ON e.user_id = u.id
+LEFT JOIN accounts a ON e.account_id = a.id
+LEFT JOIN groups g ON e.group_id = g.id
+WHERE e.id = $1
LIMIT 1`
var out service.OpsErrorLogDetail
- var latency sql.NullInt64
var statusCode sql.NullInt64
var upstreamStatusCode sql.NullInt64
+ var resolvedAt sql.NullTime
+ var resolvedBy sql.NullInt64
+ var resolvedRetryID sql.NullInt64
var clientIP sql.NullString
var userID sql.NullInt64
var apiKeyID sql.NullInt64
@@ -318,11 +375,18 @@ LIMIT 1`
&out.CreatedAt,
&out.Phase,
&out.Type,
+ &out.Owner,
+ &out.Source,
&out.Severity,
&statusCode,
&out.Platform,
&out.Model,
- &latency,
+ &out.IsRetryable,
+ &out.RetryCount,
+ &out.Resolved,
+ &resolvedAt,
+ &resolvedBy,
+ &resolvedRetryID,
&out.ClientRequestID,
&out.RequestID,
&out.Message,
@@ -333,9 +397,12 @@ LIMIT 1`
&out.UpstreamErrors,
&out.IsBusinessLimited,
&userID,
+ &out.UserEmail,
&apiKeyID,
&accountID,
+ &out.AccountName,
&groupID,
+ &out.GroupName,
&clientIP,
&out.RequestPath,
&out.Stream,
@@ -355,9 +422,17 @@ LIMIT 1`
}
out.StatusCode = int(statusCode.Int64)
- if latency.Valid {
- v := int(latency.Int64)
- out.LatencyMs = &v
+ if resolvedAt.Valid {
+ t := resolvedAt.Time
+ out.ResolvedAt = &t
+ }
+ if resolvedBy.Valid {
+ v := resolvedBy.Int64
+ out.ResolvedByUserID = &v
+ }
+ if resolvedRetryID.Valid {
+ v := resolvedRetryID.Int64
+ out.ResolvedRetryID = &v
}
if clientIP.Valid {
s := clientIP.String
@@ -487,9 +562,15 @@ SET
status = $2,
finished_at = $3,
duration_ms = $4,
- result_request_id = $5,
- result_error_id = $6,
- error_message = $7
+ success = $5,
+ http_status_code = $6,
+ upstream_request_id = $7,
+ used_account_id = $8,
+ response_preview = $9,
+ response_truncated = $10,
+ result_request_id = $11,
+ result_error_id = $12,
+ error_message = $13
WHERE id = $1`
_, err := r.db.ExecContext(
@@ -499,8 +580,14 @@ WHERE id = $1`
strings.TrimSpace(input.Status),
nullTime(input.FinishedAt),
input.DurationMs,
+ nullBool(input.Success),
+ nullInt(input.HTTPStatusCode),
+ opsNullString(input.UpstreamRequestID),
+ nullInt64(input.UsedAccountID),
+ opsNullString(input.ResponsePreview),
+ nullBool(input.ResponseTruncated),
opsNullString(input.ResultRequestID),
- opsNullInt64(input.ResultErrorID),
+ nullInt64(input.ResultErrorID),
opsNullString(input.ErrorMessage),
)
return err
@@ -526,6 +613,12 @@ SELECT
started_at,
finished_at,
duration_ms,
+ success,
+ http_status_code,
+ upstream_request_id,
+ used_account_id,
+ response_preview,
+ response_truncated,
result_request_id,
result_error_id,
error_message
@@ -540,6 +633,12 @@ LIMIT 1`
var startedAt sql.NullTime
var finishedAt sql.NullTime
var durationMs sql.NullInt64
+ var success sql.NullBool
+ var httpStatusCode sql.NullInt64
+ var upstreamRequestID sql.NullString
+ var usedAccountID sql.NullInt64
+ var responsePreview sql.NullString
+ var responseTruncated sql.NullBool
var resultRequestID sql.NullString
var resultErrorID sql.NullInt64
var errorMessage sql.NullString
@@ -555,6 +654,12 @@ LIMIT 1`
&startedAt,
&finishedAt,
&durationMs,
+ &success,
+ &httpStatusCode,
+ &upstreamRequestID,
+ &usedAccountID,
+ &responsePreview,
+ &responseTruncated,
&resultRequestID,
&resultErrorID,
&errorMessage,
@@ -579,6 +684,30 @@ LIMIT 1`
v := durationMs.Int64
out.DurationMs = &v
}
+ if success.Valid {
+ v := success.Bool
+ out.Success = &v
+ }
+ if httpStatusCode.Valid {
+ v := int(httpStatusCode.Int64)
+ out.HTTPStatusCode = &v
+ }
+ if upstreamRequestID.Valid {
+ s := upstreamRequestID.String
+ out.UpstreamRequestID = &s
+ }
+ if usedAccountID.Valid {
+ v := usedAccountID.Int64
+ out.UsedAccountID = &v
+ }
+ if responsePreview.Valid {
+ s := responsePreview.String
+ out.ResponsePreview = &s
+ }
+ if responseTruncated.Valid {
+ v := responseTruncated.Bool
+ out.ResponseTruncated = &v
+ }
if resultRequestID.Valid {
s := resultRequestID.String
out.ResultRequestID = &s
@@ -602,30 +731,234 @@ func nullTime(t time.Time) sql.NullTime {
return sql.NullTime{Time: t, Valid: true}
}
+func nullBool(v *bool) sql.NullBool {
+ if v == nil {
+ return sql.NullBool{}
+ }
+ return sql.NullBool{Bool: *v, Valid: true}
+}
+
+func (r *opsRepository) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*service.OpsRetryAttempt, error) {
+ if r == nil || r.db == nil {
+ return nil, fmt.Errorf("nil ops repository")
+ }
+ if sourceErrorID <= 0 {
+ return nil, fmt.Errorf("invalid source_error_id")
+ }
+ if limit <= 0 {
+ limit = 50
+ }
+ if limit > 200 {
+ limit = 200
+ }
+
+ q := `
+SELECT
+ r.id,
+ r.created_at,
+ COALESCE(r.requested_by_user_id, 0),
+ r.source_error_id,
+ COALESCE(r.mode, ''),
+ r.pinned_account_id,
+ COALESCE(pa.name, ''),
+ COALESCE(r.status, ''),
+ r.started_at,
+ r.finished_at,
+ r.duration_ms,
+ r.success,
+ r.http_status_code,
+ r.upstream_request_id,
+ r.used_account_id,
+ COALESCE(ua.name, ''),
+ r.response_preview,
+ r.response_truncated,
+ r.result_request_id,
+ r.result_error_id,
+ r.error_message
+FROM ops_retry_attempts r
+LEFT JOIN accounts pa ON r.pinned_account_id = pa.id
+LEFT JOIN accounts ua ON r.used_account_id = ua.id
+WHERE r.source_error_id = $1
+ORDER BY r.created_at DESC
+LIMIT $2`
+
+ rows, err := r.db.QueryContext(ctx, q, sourceErrorID, limit)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ out := make([]*service.OpsRetryAttempt, 0, 16)
+ for rows.Next() {
+ var item service.OpsRetryAttempt
+ var pinnedAccountID sql.NullInt64
+ var pinnedAccountName string
+ var requestedBy sql.NullInt64
+ var startedAt sql.NullTime
+ var finishedAt sql.NullTime
+ var durationMs sql.NullInt64
+ var success sql.NullBool
+ var httpStatusCode sql.NullInt64
+ var upstreamRequestID sql.NullString
+ var usedAccountID sql.NullInt64
+ var usedAccountName string
+ var responsePreview sql.NullString
+ var responseTruncated sql.NullBool
+ var resultRequestID sql.NullString
+ var resultErrorID sql.NullInt64
+ var errorMessage sql.NullString
+
+ if err := rows.Scan(
+ &item.ID,
+ &item.CreatedAt,
+ &requestedBy,
+ &item.SourceErrorID,
+ &item.Mode,
+ &pinnedAccountID,
+ &pinnedAccountName,
+ &item.Status,
+ &startedAt,
+ &finishedAt,
+ &durationMs,
+ &success,
+ &httpStatusCode,
+ &upstreamRequestID,
+ &usedAccountID,
+ &usedAccountName,
+ &responsePreview,
+ &responseTruncated,
+ &resultRequestID,
+ &resultErrorID,
+ &errorMessage,
+ ); err != nil {
+ return nil, err
+ }
+
+ item.RequestedByUserID = requestedBy.Int64
+ if pinnedAccountID.Valid {
+ v := pinnedAccountID.Int64
+ item.PinnedAccountID = &v
+ }
+ item.PinnedAccountName = pinnedAccountName
+ if startedAt.Valid {
+ t := startedAt.Time
+ item.StartedAt = &t
+ }
+ if finishedAt.Valid {
+ t := finishedAt.Time
+ item.FinishedAt = &t
+ }
+ if durationMs.Valid {
+ v := durationMs.Int64
+ item.DurationMs = &v
+ }
+ if success.Valid {
+ v := success.Bool
+ item.Success = &v
+ }
+ if httpStatusCode.Valid {
+ v := int(httpStatusCode.Int64)
+ item.HTTPStatusCode = &v
+ }
+ if upstreamRequestID.Valid {
+ item.UpstreamRequestID = &upstreamRequestID.String
+ }
+ if usedAccountID.Valid {
+ v := usedAccountID.Int64
+ item.UsedAccountID = &v
+ }
+ item.UsedAccountName = usedAccountName
+ if responsePreview.Valid {
+ item.ResponsePreview = &responsePreview.String
+ }
+ if responseTruncated.Valid {
+ v := responseTruncated.Bool
+ item.ResponseTruncated = &v
+ }
+ if resultRequestID.Valid {
+ item.ResultRequestID = &resultRequestID.String
+ }
+ if resultErrorID.Valid {
+ v := resultErrorID.Int64
+ item.ResultErrorID = &v
+ }
+ if errorMessage.Valid {
+ item.ErrorMessage = &errorMessage.String
+ }
+ out = append(out, &item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
+
+func (r *opsRepository) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error {
+ if r == nil || r.db == nil {
+ return fmt.Errorf("nil ops repository")
+ }
+ if errorID <= 0 {
+ return fmt.Errorf("invalid error id")
+ }
+
+ q := `
+UPDATE ops_error_logs
+SET
+ resolved = $2,
+ resolved_at = $3,
+ resolved_by_user_id = $4,
+ resolved_retry_id = $5
+WHERE id = $1`
+
+ at := sql.NullTime{}
+ if resolvedAt != nil && !resolvedAt.IsZero() {
+ at = sql.NullTime{Time: resolvedAt.UTC(), Valid: true}
+ } else if resolved {
+ now := time.Now().UTC()
+ at = sql.NullTime{Time: now, Valid: true}
+ }
+
+ _, err := r.db.ExecContext(
+ ctx,
+ q,
+ errorID,
+ resolved,
+ at,
+ nullInt64(resolvedByUserID),
+ nullInt64(resolvedRetryID),
+ )
+ return err
+}
+
func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
- clauses := make([]string, 0, 8)
- args := make([]any, 0, 8)
+ clauses := make([]string, 0, 12)
+ args := make([]any, 0, 12)
clauses = append(clauses, "1=1")
phaseFilter := ""
if filter != nil {
phaseFilter = strings.TrimSpace(strings.ToLower(filter.Phase))
}
- // ops_error_logs primarily stores client-visible error requests (status>=400),
+ // ops_error_logs stores client-visible error requests (status>=400),
// but we also persist "recovered" upstream errors (status<400) for upstream health visibility.
- // By default, keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
+ // If Resolved is not specified, do not filter by resolved state (backward-compatible).
+ resolvedFilter := (*bool)(nil)
+ if filter != nil {
+ resolvedFilter = filter.Resolved
+ }
+ // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase.
if phaseFilter != "upstream" {
clauses = append(clauses, "COALESCE(status_code, 0) >= 400")
}
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, filter.StartTime.UTC())
- clauses = append(clauses, "created_at >= $"+itoa(len(args)))
+ clauses = append(clauses, "e.created_at >= $"+itoa(len(args)))
}
if filter.EndTime != nil && !filter.EndTime.IsZero() {
args = append(args, filter.EndTime.UTC())
// Keep time-window semantics consistent with other ops queries: [start, end)
- clauses = append(clauses, "created_at < $"+itoa(len(args)))
+ clauses = append(clauses, "e.created_at < $"+itoa(len(args)))
}
if p := strings.TrimSpace(filter.Platform); p != "" {
args = append(args, p)
@@ -643,10 +976,59 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
args = append(args, phase)
clauses = append(clauses, "error_phase = $"+itoa(len(args)))
}
+ if filter != nil {
+ if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" {
+ args = append(args, owner)
+ clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args)))
+ }
+ if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" {
+ args = append(args, source)
+ clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args)))
+ }
+ }
+ if resolvedFilter != nil {
+ args = append(args, *resolvedFilter)
+ clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args)))
+ }
+
+ // View filter: errors vs excluded vs all.
+ // Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
+ view := ""
+ if filter != nil {
+ view = strings.ToLower(strings.TrimSpace(filter.View))
+ }
+ switch view {
+ case "", "errors":
+ clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
+ clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
+ case "excluded":
+ clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
+ case "all":
+ // no-op
+ default:
+ // treat unknown as default 'errors'
+ clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
+ clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
+ }
if len(filter.StatusCodes) > 0 {
args = append(args, pq.Array(filter.StatusCodes))
clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")")
+ } else if filter.StatusCodesOther {
+ // "Other" means: status codes not in the common list.
+ known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529}
+ args = append(args, pq.Array(known))
+ clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))")
}
+ // Exact correlation keys (preferred for request↔upstream linkage).
+ if rid := strings.TrimSpace(filter.RequestID); rid != "" {
+ args = append(args, rid)
+ clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args)))
+ }
+ if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" {
+ args = append(args, crid)
+ clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args)))
+ }
+
if q := strings.TrimSpace(filter.Query); q != "" {
like := "%" + q + "%"
args = append(args, like)
@@ -654,6 +1036,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")")
}
+ if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" {
+ like := "%" + userQuery + "%"
+ args = append(args, like)
+ n := itoa(len(args))
+ clauses = append(clauses, "u.email ILIKE $"+n)
+ }
+
return "WHERE " + strings.Join(clauses, " AND "), args
}
diff --git a/backend/internal/repository/ops_repo_alerts.go b/backend/internal/repository/ops_repo_alerts.go
index f601c363..bd98b7e4 100644
--- a/backend/internal/repository/ops_repo_alerts.go
+++ b/backend/internal/repository/ops_repo_alerts.go
@@ -354,7 +354,7 @@ SELECT
created_at
FROM ops_alert_events
` + where + `
-ORDER BY fired_at DESC
+ORDER BY fired_at DESC, id DESC
LIMIT ` + limitArg
rows, err := r.db.QueryContext(ctx, q, args...)
@@ -413,6 +413,43 @@ LIMIT ` + limitArg
return out, nil
}
+// GetAlertEventByID fetches a single alert event by primary key.
+// It returns (nil, nil) when no row matches, so callers must check for a nil
+// event rather than relying on an error to signal "not found".
+func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
+	if r == nil || r.db == nil {
+		return nil, fmt.Errorf("nil ops repository")
+	}
+	if eventID <= 0 {
+		return nil, fmt.Errorf("invalid event id")
+	}
+
+	// Nullable text/number columns are COALESCEd so scanOpsAlertEvent can scan
+	// into plain (non-Null) fields.
+	q := `
+SELECT
+	id,
+	COALESCE(rule_id, 0),
+	COALESCE(severity, ''),
+	COALESCE(status, ''),
+	COALESCE(title, ''),
+	COALESCE(description, ''),
+	metric_value,
+	threshold_value,
+	dimensions,
+	fired_at,
+	resolved_at,
+	email_sent,
+	created_at
+FROM ops_alert_events
+WHERE id = $1`
+
+	row := r.db.QueryRowContext(ctx, q, eventID)
+	ev, err := scanOpsAlertEvent(row)
+	if err != nil {
+		if err == sql.ErrNoRows {
+			// Not found is not an error for this lookup.
+			return nil, nil
+		}
+		return nil, err
+	}
+	return ev, nil
+}
+
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
if r == nil || r.db == nil {
return nil, fmt.Errorf("nil ops repository")
@@ -591,6 +628,121 @@ type opsAlertEventRow interface {
Scan(dest ...any) error
}
+// CreateAlertSilence inserts a new silence window for an alert rule and
+// returns the stored row (with the server-assigned id and created_at).
+//
+// Required fields: RuleID (> 0), Platform (non-blank after trimming), and a
+// non-zero Until. GroupID, Region, Reason and CreatedBy are optional and are
+// stored as NULL when absent; Reason is read back via COALESCE as "".
+//
+// NOTE(review): Until is not validated to lie in the future — a past Until
+// creates an already-expired silence. Confirm the handler validates this.
+func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
+	if r == nil || r.db == nil {
+		return nil, fmt.Errorf("nil ops repository")
+	}
+	if input == nil {
+		return nil, fmt.Errorf("nil input")
+	}
+	if input.RuleID <= 0 {
+		return nil, fmt.Errorf("invalid rule_id")
+	}
+	platform := strings.TrimSpace(input.Platform)
+	if platform == "" {
+		return nil, fmt.Errorf("invalid platform")
+	}
+	if input.Until.IsZero() {
+		return nil, fmt.Errorf("invalid until")
+	}
+
+	q := `
+INSERT INTO ops_alert_silences (
+	rule_id,
+	platform,
+	group_id,
+	region,
+	until,
+	reason,
+	created_by,
+	created_at
+) VALUES (
+	$1,$2,$3,$4,$5,$6,$7,NOW()
+)
+RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
+
+	row := r.db.QueryRowContext(
+		ctx,
+		q,
+		input.RuleID,
+		platform,
+		opsNullInt64(input.GroupID),  // NULL when no group scope
+		opsNullString(input.Region),  // NULL when no region scope
+		input.Until,
+		opsNullString(input.Reason),
+		opsNullInt64(input.CreatedBy),
+	)
+
+	// Scan through Null* intermediaries, then copy valid values back into the
+	// pointer fields of the returned struct.
+	var out service.OpsAlertSilence
+	var groupID sql.NullInt64
+	var region sql.NullString
+	var createdBy sql.NullInt64
+	if err := row.Scan(
+		&out.ID,
+		&out.RuleID,
+		&out.Platform,
+		&groupID,
+		&region,
+		&out.Until,
+		&out.Reason,
+		&createdBy,
+		&out.CreatedAt,
+	); err != nil {
+		return nil, err
+	}
+	if groupID.Valid {
+		v := groupID.Int64
+		out.GroupID = &v
+	}
+	if region.Valid {
+		// A whitespace-only region is treated as absent.
+		v := strings.TrimSpace(region.String)
+		if v != "" {
+			out.Region = &v
+		}
+	}
+	if createdBy.Valid {
+		v := createdBy.Int64
+		out.CreatedBy = &v
+	}
+	return &out, nil
+}
+
+// IsAlertSilenced reports whether an active silence exists for the exact
+// (ruleID, platform, groupID, region) tuple at the given instant.
+//
+// Matching uses IS NOT DISTINCT FROM, i.e. NULL-safe equality: a silence with
+// NULL group_id matches only queries with a nil groupID — a platform-wide
+// silence does NOT automatically cover per-group lookups (and vice versa).
+// A blank platform is treated as "not silenced" (no error), and a zero `now`
+// falls back to time.Now().UTC().
+func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
+	if r == nil || r.db == nil {
+		return false, fmt.Errorf("nil ops repository")
+	}
+	if ruleID <= 0 {
+		return false, fmt.Errorf("invalid rule id")
+	}
+	platform = strings.TrimSpace(platform)
+	if platform == "" {
+		return false, nil
+	}
+	if now.IsZero() {
+		now = time.Now().UTC()
+	}
+
+	q := `
+SELECT 1
+FROM ops_alert_silences
+WHERE rule_id = $1
+  AND platform = $2
+  AND (group_id IS NOT DISTINCT FROM $3)
+  AND (region IS NOT DISTINCT FROM $4)
+  AND until > $5
+LIMIT 1`
+
+	// Existence check: any row means silenced; ErrNoRows means not silenced.
+	var dummy int
+	err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
+	if err != nil {
+		if err == sql.ErrNoRows {
+			return false, nil
+		}
+		return false, err
+	}
+	return true, nil
+}
+
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
var ev service.OpsAlertEvent
var metricValue sql.NullFloat64
@@ -652,6 +804,10 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
args = append(args, severity)
clauses = append(clauses, "severity = $"+itoa(len(args)))
}
+ if filter.EmailSent != nil {
+ args = append(args, *filter.EmailSent)
+ clauses = append(clauses, "email_sent = $"+itoa(len(args)))
+ }
if filter.StartTime != nil && !filter.StartTime.IsZero() {
args = append(args, *filter.StartTime)
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
@@ -661,6 +817,14 @@ func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []an
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
}
+ // Cursor pagination (descending by fired_at, then id)
+ if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
+ args = append(args, *filter.BeforeFiredAt)
+ tsArg := "$" + itoa(len(args))
+ args = append(args, *filter.BeforeID)
+ idArg := "$" + itoa(len(args))
+ clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
+ }
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
if platform := strings.TrimSpace(filter.Platform); platform != "" {
args = append(args, platform)
diff --git a/backend/internal/repository/ops_repo_metrics.go b/backend/internal/repository/ops_repo_metrics.go
index bc80ed6e..713e0eb9 100644
--- a/backend/internal/repository/ops_repo_metrics.go
+++ b/backend/internal/repository/ops_repo_metrics.go
@@ -296,9 +296,10 @@ INSERT INTO ops_job_heartbeats (
last_error_at,
last_error,
last_duration_ms,
+ last_result,
updated_at
) VALUES (
- $1,$2,$3,$4,$5,$6,NOW()
+ $1,$2,$3,$4,$5,$6,$7,NOW()
)
ON CONFLICT (job_name) DO UPDATE SET
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
@@ -312,6 +313,10 @@ ON CONFLICT (job_name) DO UPDATE SET
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
END,
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
+ last_result = CASE
+ WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
+ ELSE ops_job_heartbeats.last_result
+ END,
updated_at = NOW()`
_, err := r.db.ExecContext(
@@ -323,6 +328,7 @@ ON CONFLICT (job_name) DO UPDATE SET
opsNullTime(input.LastErrorAt),
opsNullString(input.LastError),
opsNullInt(input.LastDurationMs),
+ opsNullString(input.LastResult),
)
return err
}
@@ -340,6 +346,7 @@ SELECT
last_error_at,
last_error,
last_duration_ms,
+ last_result,
updated_at
FROM ops_job_heartbeats
ORDER BY job_name ASC`
@@ -359,6 +366,8 @@ ORDER BY job_name ASC`
var lastError sql.NullString
var lastDuration sql.NullInt64
+ var lastResult sql.NullString
+
if err := rows.Scan(
&item.JobName,
&lastRun,
@@ -366,6 +375,7 @@ ORDER BY job_name ASC`
&lastErrorAt,
&lastError,
&lastDuration,
+ &lastResult,
&item.UpdatedAt,
); err != nil {
return nil, err
@@ -391,6 +401,10 @@ ORDER BY job_name ASC`
v := lastDuration.Int64
item.LastDurationMs = &v
}
+ if lastResult.Valid {
+ v := lastResult.String
+ item.LastResult = &v
+ }
out = append(out, &item)
}
diff --git a/backend/internal/repository/proxy_latency_cache.go b/backend/internal/repository/proxy_latency_cache.go
new file mode 100644
index 00000000..4458b5e1
--- /dev/null
+++ b/backend/internal/repository/proxy_latency_cache.go
@@ -0,0 +1,74 @@
+package repository
+
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// proxyLatencyKeyPrefix namespaces per-proxy latency entries in Redis.
+const proxyLatencyKeyPrefix = "proxy:latency:"
+
+// proxyLatencyKey builds the Redis key for one proxy, e.g. "proxy:latency:42".
+func proxyLatencyKey(proxyID int64) string {
+	return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
+}
+
+// proxyLatencyCache is a Redis-backed implementation of
+// service.ProxyLatencyCache storing one JSON document per proxy.
+type proxyLatencyCache struct {
+	rdb *redis.Client
+}
+
+// NewProxyLatencyCache wires a Redis client into a ProxyLatencyCache.
+func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
+	return &proxyLatencyCache{rdb: rdb}
+}
+
+// GetProxyLatencies bulk-loads cached latency info for the given proxy IDs
+// with a single MGET. The result map only contains entries that were present
+// in Redis AND unmarshalled cleanly; missing or malformed entries are skipped
+// silently (best-effort read). On MGET failure the (possibly empty) map is
+// returned alongside the error.
+func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
+	results := make(map[int64]*service.ProxyLatencyInfo)
+	if len(proxyIDs) == 0 {
+		return results, nil
+	}
+
+	keys := make([]string, 0, len(proxyIDs))
+	for _, id := range proxyIDs {
+		keys = append(keys, proxyLatencyKey(id))
+	}
+
+	values, err := c.rdb.MGet(ctx, keys...).Result()
+	if err != nil {
+		return results, err
+	}
+
+	// MGET returns values positionally, so index i maps back to proxyIDs[i].
+	for i, raw := range values {
+		if raw == nil {
+			// Cache miss for this proxy.
+			continue
+		}
+		// go-redis may hand back either string or []byte payloads.
+		var payload []byte
+		switch v := raw.(type) {
+		case string:
+			payload = []byte(v)
+		case []byte:
+			payload = v
+		default:
+			continue
+		}
+		var info service.ProxyLatencyInfo
+		if err := json.Unmarshal(payload, &info); err != nil {
+			// Corrupt entry: skip rather than fail the whole batch.
+			continue
+		}
+		results[proxyIDs[i]] = &info
+	}
+
+	return results, nil
+}
+
+// SetProxyLatency stores one proxy's latency snapshot as JSON. A nil info is
+// a no-op. The key is written with expiration 0 (no TTL), so entries persist
+// until the next probe overwrites them.
+// NOTE(review): keys for deleted proxies are never expired — confirm cleanup
+// happens elsewhere or consider adding a TTL.
+func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
+	if info == nil {
+		return nil
+	}
+	payload, err := json.Marshal(info)
+	if err != nil {
+		return err
+	}
+	return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
+}
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index 5c42e4d1..fb6f405e 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -7,6 +7,7 @@ import (
"io"
"log"
"net/http"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -34,7 +35,10 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
}
}
-const defaultIPInfoURL = "https://ipinfo.io/json"
+const (
+ defaultIPInfoURL = "http://ip-api.com/json/?lang=zh-CN"
+ defaultProxyProbeTimeout = 30 * time.Second
+)
type proxyProbeService struct {
ipInfoURL string
@@ -46,7 +50,7 @@ type proxyProbeService struct {
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
client, err := httpclient.GetClient(httpclient.Options{
ProxyURL: proxyURL,
- Timeout: 15 * time.Second,
+ Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
@@ -75,10 +79,14 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
}
var ipInfo struct {
- IP string `json:"ip"`
- City string `json:"city"`
- Region string `json:"region"`
- Country string `json:"country"`
+ Status string `json:"status"`
+ Message string `json:"message"`
+ Query string `json:"query"`
+ City string `json:"city"`
+ Region string `json:"region"`
+ RegionName string `json:"regionName"`
+ Country string `json:"country"`
+ CountryCode string `json:"countryCode"`
}
body, err := io.ReadAll(resp.Body)
@@ -89,11 +97,22 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
if err := json.Unmarshal(body, &ipInfo); err != nil {
return nil, latencyMs, fmt.Errorf("failed to parse response: %w", err)
}
+ if strings.ToLower(ipInfo.Status) != "success" {
+ if ipInfo.Message == "" {
+ ipInfo.Message = "ip-api request failed"
+ }
+ return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
+ }
+ region := ipInfo.RegionName
+ if region == "" {
+ region = ipInfo.Region
+ }
return &service.ProxyExitInfo{
- IP: ipInfo.IP,
- City: ipInfo.City,
- Region: ipInfo.Region,
- Country: ipInfo.Country,
+ IP: ipInfo.Query,
+ City: ipInfo.City,
+ Region: region,
+ Country: ipInfo.Country,
+ CountryCode: ipInfo.CountryCode,
}, latencyMs, nil
}
diff --git a/backend/internal/repository/proxy_probe_service_test.go b/backend/internal/repository/proxy_probe_service_test.go
index fe45adbb..f1cd5721 100644
--- a/backend/internal/repository/proxy_probe_service_test.go
+++ b/backend/internal/repository/proxy_probe_service_test.go
@@ -21,7 +21,7 @@ type ProxyProbeServiceSuite struct {
func (s *ProxyProbeServiceSuite) SetupTest() {
s.ctx = context.Background()
s.prober = &proxyProbeService{
- ipInfoURL: "http://ipinfo.test/json",
+ ipInfoURL: "http://ip-api.test/json/?lang=zh-CN",
allowPrivateHosts: true,
}
}
@@ -54,7 +54,7 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seen <- r.RequestURI
w.Header().Set("Content-Type", "application/json")
- _, _ = io.WriteString(w, `{"ip":"1.2.3.4","city":"c","region":"r","country":"cc"}`)
+ _, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
}))
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
@@ -64,11 +64,12 @@ func (s *ProxyProbeServiceSuite) TestProbeProxy_Success() {
require.Equal(s.T(), "c", info.City)
require.Equal(s.T(), "r", info.Region)
require.Equal(s.T(), "cc", info.Country)
+ require.Equal(s.T(), "CC", info.CountryCode)
// Verify proxy received the request
select {
case uri := <-seen:
- require.Contains(s.T(), uri, "ipinfo.test", "expected request to go through proxy")
+ require.Contains(s.T(), uri, "ip-api.test", "expected request to go through proxy")
default:
require.Fail(s.T(), "expected proxy to receive request")
}
diff --git a/backend/internal/repository/proxy_repo.go b/backend/internal/repository/proxy_repo.go
index 622b0aeb..36965c05 100644
--- a/backend/internal/repository/proxy_repo.go
+++ b/backend/internal/repository/proxy_repo.go
@@ -219,12 +219,54 @@ func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string,
// CountAccountsByProxyID returns the number of accounts using a specific proxy
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
var count int64
- if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1", []any{proxyID}, &count); err != nil {
+ if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
return 0, err
}
return count, nil
}
+// ListAccountSummariesByProxyID returns lightweight summaries of every
+// non-deleted account bound to the given proxy, newest (highest id) first.
+// An empty (non-nil) slice is returned when the proxy has no accounts.
+func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
+	rows, err := r.sql.QueryContext(ctx, `
+		SELECT id, name, platform, type, notes
+		FROM accounts
+		WHERE proxy_id = $1 AND deleted_at IS NULL
+		ORDER BY id DESC
+	`, proxyID)
+	if err != nil {
+		return nil, err
+	}
+	defer func() { _ = rows.Close() }()
+
+	out := make([]service.ProxyAccountSummary, 0)
+	for rows.Next() {
+		var (
+			id       int64
+			name     string
+			platform string
+			accType  string
+			notes    sql.NullString // notes column is nullable
+		)
+		if err := rows.Scan(&id, &name, &platform, &accType, &notes); err != nil {
+			return nil, err
+		}
+		// Only surface notes when the column is non-NULL; notes is
+		// re-declared per iteration, so taking its address is safe.
+		var notesPtr *string
+		if notes.Valid {
+			notesPtr = &notes.String
+		}
+		out = append(out, service.ProxyAccountSummary{
+			ID:       id,
+			Name:     name,
+			Platform: platform,
+			Type:     accType,
+			Notes:    notesPtr,
+		})
+	}
+	if err := rows.Err(); err != nil {
+		return nil, err
+	}
+	return out, nil
+}
+
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
diff --git a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
index dede6014..e442a125 100644
--- a/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
+++ b/backend/internal/repository/scheduler_snapshot_outbox_integration_test.go
@@ -27,7 +27,7 @@ func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
- OutboxPollIntervalSeconds: 1,
+ OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},
diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go
new file mode 100644
index 00000000..16f2a69c
--- /dev/null
+++ b/backend/internal/repository/session_limit_cache.go
@@ -0,0 +1,321 @@
+package repository
+
+import (
+ "context"
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// Session-limit cache constants.
+//
+// Design notes:
+// Active sessions per account are tracked in a Redis sorted set:
+//   - Key:    session_limit:account:{accountID}
+//   - Member: sessionUUID (extracted from metadata.user_id)
+//   - Score:  Unix timestamp of the session's last activity
+//
+// Expired sessions are pruned via ZREMRANGEBYSCORE, so no per-member TTL
+// management is required.
+const (
+	// Key prefix for per-account session tracking.
+	// Format: session_limit:account:{accountID}
+	sessionLimitKeyPrefix = "session_limit:account:"
+
+	// Key prefix for the cached window cost.
+	// Format: window_cost:account:{accountID}
+	windowCostKeyPrefix = "window_cost:account:"
+
+	// TTL of the cached window cost (30 seconds).
+	windowCostCacheTTL = 30 * time.Second
+)
+
+var (
+	// registerSessionScript registers session activity for one account.
+	// It uses the Redis TIME command for server-side time so that multiple
+	// application instances cannot disagree due to clock skew.
+	// KEYS[1] = session_limit:account:{accountID}
+	// ARGV[1] = maxSessions
+	// ARGV[2] = idleTimeout (seconds)
+	// ARGV[3] = sessionUUID
+	// Returns: 1 = allowed, 0 = rejected
+	registerSessionScript = redis.NewScript(`
+		local key = KEYS[1]
+		local maxSessions = tonumber(ARGV[1])
+		local idleTimeout = tonumber(ARGV[2])
+		local sessionUUID = ARGV[3]
+
+		-- 使用 Redis 服务器时间,确保多实例时钟一致
+		local timeResult = redis.call('TIME')
+		local now = tonumber(timeResult[1])
+		local expireBefore = now - idleTimeout
+
+		-- 清理过期会话
+		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
+
+		-- 检查会话是否已存在(支持刷新时间戳)
+		local exists = redis.call('ZSCORE', key, sessionUUID)
+		if exists ~= false then
+			-- 会话已存在,刷新时间戳
+			redis.call('ZADD', key, now, sessionUUID)
+			redis.call('EXPIRE', key, idleTimeout + 60)
+			return 1
+		end
+
+		-- 检查是否达到会话数量上限
+		local count = redis.call('ZCARD', key)
+		if count < maxSessions then
+			-- 未达上限,添加新会话
+			redis.call('ZADD', key, now, sessionUUID)
+			redis.call('EXPIRE', key, idleTimeout + 60)
+			return 1
+		end
+
+		-- 达到上限,拒绝新会话
+		return 0
+	`)
+
+	// refreshSessionScript refreshes an existing session's activity timestamp;
+	// unknown sessions are left untouched (ZADD only runs when the member
+	// already exists).
+	// KEYS[1] = session_limit:account:{accountID}
+	// ARGV[1] = idleTimeout (seconds)
+	// ARGV[2] = sessionUUID
+	refreshSessionScript = redis.NewScript(`
+		local key = KEYS[1]
+		local idleTimeout = tonumber(ARGV[1])
+		local sessionUUID = ARGV[2]
+
+		local timeResult = redis.call('TIME')
+		local now = tonumber(timeResult[1])
+
+		-- 检查会话是否存在
+		local exists = redis.call('ZSCORE', key, sessionUUID)
+		if exists ~= false then
+			redis.call('ZADD', key, now, sessionUUID)
+			redis.call('EXPIRE', key, idleTimeout + 60)
+		end
+		return 1
+	`)
+
+	// getActiveSessionCountScript prunes expired sessions, then returns the
+	// number of remaining (active) sessions.
+	// KEYS[1] = session_limit:account:{accountID}
+	// ARGV[1] = idleTimeout (seconds)
+	getActiveSessionCountScript = redis.NewScript(`
+		local key = KEYS[1]
+		local idleTimeout = tonumber(ARGV[1])
+
+		local timeResult = redis.call('TIME')
+		local now = tonumber(timeResult[1])
+		local expireBefore = now - idleTimeout
+
+		-- 清理过期会话
+		redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
+
+		return redis.call('ZCARD', key)
+	`)
+
+	// isSessionActiveScript reports whether a session exists and has not been
+	// idle past the timeout.
+	// KEYS[1] = session_limit:account:{accountID}
+	// ARGV[1] = idleTimeout (seconds)
+	// ARGV[2] = sessionUUID
+	isSessionActiveScript = redis.NewScript(`
+		local key = KEYS[1]
+		local idleTimeout = tonumber(ARGV[1])
+		local sessionUUID = ARGV[2]
+
+		local timeResult = redis.call('TIME')
+		local now = tonumber(timeResult[1])
+		local expireBefore = now - idleTimeout
+
+		-- 获取会话的时间戳
+		local score = redis.call('ZSCORE', key, sessionUUID)
+		if score == false then
+			return 0
+		end
+
+		-- 检查是否过期
+		if tonumber(score) <= expireBefore then
+			return 0
+		end
+
+		return 1
+	`)
+)
+
+// sessionLimitCache is a Redis-backed implementation of
+// service.SessionLimitCache.
+type sessionLimitCache struct {
+	rdb                *redis.Client
+	defaultIdleTimeout time.Duration // default idle timeout, used by queries without an explicit timeout
+}
+
+// NewSessionLimitCache creates a session-limit cache.
+// defaultIdleTimeoutMinutes is the idle timeout (minutes) used by the
+// parameter-less queries; values <= 0 fall back to 5 minutes.
+func NewSessionLimitCache(rdb *redis.Client, defaultIdleTimeoutMinutes int) service.SessionLimitCache {
+	if defaultIdleTimeoutMinutes <= 0 {
+		defaultIdleTimeoutMinutes = 5 // default: 5 minutes
+	}
+	return &sessionLimitCache{
+		rdb:                rdb,
+		defaultIdleTimeout: time.Duration(defaultIdleTimeoutMinutes) * time.Minute,
+	}
+}
+
+// sessionLimitKey builds the Redis key tracking one account's sessions.
+func sessionLimitKey(accountID int64) string {
+	return fmt.Sprintf("%s%d", sessionLimitKeyPrefix, accountID)
+}
+
+// windowCostKey builds the Redis key caching one account's window cost.
+func windowCostKey(accountID int64) string {
+	return fmt.Sprintf("%s%d", windowCostKeyPrefix, accountID)
+}
+
+// RegisterSession records session activity and enforces the per-account
+// concurrent-session cap. It returns true when the session is allowed (either
+// already tracked, or under the maxSessions limit) and false when rejected.
+// Invalid parameters (empty UUID, non-positive cap) and Redis errors both
+// fail open: the request is allowed.
+func (c *sessionLimitCache) RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (bool, error) {
+	if sessionUUID == "" || maxSessions <= 0 {
+		return true, nil // invalid parameters: allow by default
+	}
+
+	key := sessionLimitKey(accountID)
+	idleTimeoutSeconds := int(idleTimeout.Seconds())
+	if idleTimeoutSeconds <= 0 {
+		// Caller passed no usable timeout; use the configured default.
+		idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
+	}
+
+	result, err := registerSessionScript.Run(ctx, c.rdb, []string{key}, maxSessions, idleTimeoutSeconds, sessionUUID).Int()
+	if err != nil {
+		return true, err // fail open: allow the request on cache errors
+	}
+	return result == 1, nil
+}
+
+// RefreshSession updates the activity timestamp of an already-registered
+// session; sessions unknown to Redis are left untouched (the script only
+// ZADDs when the member exists). An empty sessionUUID is a no-op.
+func (c *sessionLimitCache) RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error {
+	if sessionUUID == "" {
+		return nil
+	}
+
+	key := sessionLimitKey(accountID)
+	idleTimeoutSeconds := int(idleTimeout.Seconds())
+	if idleTimeoutSeconds <= 0 {
+		// Fall back to the configured default timeout.
+		idleTimeoutSeconds = int(c.defaultIdleTimeout.Seconds())
+	}
+
+	_, err := refreshSessionScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Result()
+	return err
+}
+
+// GetActiveSessionCount prunes expired sessions for the account and returns
+// the remaining count, using the cache's default idle timeout.
+func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) {
+	key := sessionLimitKey(accountID)
+	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+	result, err := getActiveSessionCountScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds).Int()
+	if err != nil {
+		return 0, err
+	}
+	return result, nil
+}
+
+// GetActiveSessionCountBatch fetches active-session counts for many accounts
+// in one Redis pipeline round-trip. Accounts whose command failed are simply
+// absent from the result map (best effort — the pipeline Exec error is
+// deliberately ignored).
+// NOTE(review): Script.Run inside a pipeline issues EVALSHA without the
+// usual NOSCRIPT fallback reload — confirm the scripts are preloaded or that
+// missing entries are acceptable on a cold Redis instance.
+func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+	if len(accountIDs) == 0 {
+		return make(map[int64]int), nil
+	}
+
+	results := make(map[int64]int, len(accountIDs))
+
+	// Queue one script invocation per account on a single pipeline.
+	pipe := c.rdb.Pipeline()
+	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+	cmds := make(map[int64]*redis.Cmd, len(accountIDs))
+	for _, accountID := range accountIDs {
+		key := sessionLimitKey(accountID)
+		cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds)
+	}
+
+	// Execute the pipeline; even on partial failure, harvest what succeeded.
+	_, _ = pipe.Exec(ctx)
+
+	for accountID, cmd := range cmds {
+		if result, err := cmd.Int(); err == nil {
+			results[accountID] = result
+		}
+	}
+
+	return results, nil
+}
+
+// IsSessionActive reports whether the given session UUID is tracked for the
+// account and has not been idle longer than the default idle timeout.
+// An empty UUID is reported as inactive without touching Redis.
+func (c *sessionLimitCache) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) {
+	if sessionUUID == "" {
+		return false, nil
+	}
+
+	key := sessionLimitKey(accountID)
+	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
+
+	result, err := isSessionActiveScript.Run(ctx, c.rdb, []string{key}, idleTimeoutSeconds, sessionUUID).Int()
+	if err != nil {
+		return false, err
+	}
+	return result == 1, nil
+}
+
+// ========== Window-cost cache implementation ==========
+
+// GetWindowCost returns the cached window cost for an account.
+// The second return value reports a cache hit; redis.Nil is mapped to
+// (0, false, nil) rather than an error.
+func (c *sessionLimitCache) GetWindowCost(ctx context.Context, accountID int64) (float64, bool, error) {
+	key := windowCostKey(accountID)
+	val, err := c.rdb.Get(ctx, key).Float64()
+	if err == redis.Nil {
+		return 0, false, nil // cache miss
+	}
+	if err != nil {
+		return 0, false, err
+	}
+	return val, true, nil
+}
+
+// SetWindowCost caches an account's window cost with the short
+// windowCostCacheTTL (30s) so stale costs age out quickly.
+func (c *sessionLimitCache) SetWindowCost(ctx context.Context, accountID int64, cost float64) error {
+	key := windowCostKey(accountID)
+	return c.rdb.Set(ctx, key, cost, windowCostCacheTTL).Err()
+}
+
+// GetWindowCostBatch fetches cached window costs for many accounts with a
+// single MGET. Cache misses and unparseable values are silently omitted from
+// the result map (best-effort read).
+func (c *sessionLimitCache) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) {
+	if len(accountIDs) == 0 {
+		return make(map[int64]float64), nil
+	}
+
+	// Build the positional key list for the batch lookup.
+	keys := make([]string, len(accountIDs))
+	for i, accountID := range accountIDs {
+		keys[i] = windowCostKey(accountID)
+	}
+
+	// One MGET round-trip for all accounts.
+	vals, err := c.rdb.MGet(ctx, keys...).Result()
+	if err != nil {
+		return nil, err
+	}
+
+	results := make(map[int64]float64, len(accountIDs))
+	for i, val := range vals {
+		if val == nil {
+			continue // cache miss
+		}
+		// MGET returns values positionally; parse each back to float64
+		// (go-redis typically yields strings here).
+		switch v := val.(type) {
+		case string:
+			if cost, err := strconv.ParseFloat(v, 64); err == nil {
+				results[accountIDs[i]] = cost
+			}
+		case float64:
+			results[accountIDs[i]] = v
+		}
+	}
+
+	return results, nil
+}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index e483f89f..4a2aaade 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -22,7 +22,7 @@ import (
"github.com/lib/pq"
)
-const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
+const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, created_at"
type usageLogRepository struct {
client *dbent.Client
@@ -105,6 +105,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
total_cost,
actual_cost,
rate_multiplier,
+ account_rate_multiplier,
billing_type,
stream,
duration_ms,
@@ -120,7 +121,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
$8, $9, $10, $11,
$12, $13,
$14, $15, $16, $17, $18, $19,
- $20, $21, $22, $23, $24, $25, $26, $27, $28, $29
+ $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30
)
ON CONFLICT (request_id, api_key_id) DO NOTHING
RETURNING id, created_at
@@ -160,6 +161,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
log.TotalCost,
log.ActualCost,
rateMultiplier,
+ log.AccountRateMultiplier,
log.BillingType,
log.Stream,
duration,
@@ -270,13 +272,13 @@ type DashboardStats = usagestats.DashboardStats
func (r *usageLogRepository) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
stats := &DashboardStats{}
- now := time.Now().UTC()
- todayUTC := truncateToDayUTC(now)
+ now := timezone.Now()
+ todayStart := timezone.Today()
- if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
+ if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
- if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayUTC, now); err != nil {
+ if err := r.fillDashboardUsageStatsAggregated(ctx, stats, todayStart, now); err != nil {
return nil, err
}
@@ -298,13 +300,13 @@ func (r *usageLogRepository) GetDashboardStatsWithRange(ctx context.Context, sta
}
stats := &DashboardStats{}
- now := time.Now().UTC()
- todayUTC := truncateToDayUTC(now)
+ now := timezone.Now()
+ todayStart := timezone.Today()
- if err := r.fillDashboardEntityStats(ctx, stats, todayUTC, now); err != nil {
+ if err := r.fillDashboardEntityStats(ctx, stats, todayStart, now); err != nil {
return nil, err
}
- if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayUTC, now); err != nil {
+ if err := r.fillDashboardUsageStatsFromUsageLogs(ctx, stats, startUTC, endUTC, todayStart, now); err != nil {
return nil, err
}
@@ -455,7 +457,7 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte
FROM usage_dashboard_hourly
WHERE bucket_start = $1
`
- hourStart := now.UTC().Truncate(time.Hour)
+ hourStart := now.In(timezone.Location()).Truncate(time.Hour)
if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart}, &stats.HourlyActiveUsers); err != nil {
if err != sql.ErrNoRows {
return err
@@ -835,7 +837,9 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(actual_cost), 0) as cost
+ COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(total_cost), 0) as standard_cost,
+ COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -849,6 +853,8 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID
&stats.Requests,
&stats.Tokens,
&stats.Cost,
+ &stats.StandardCost,
+ &stats.UserCost,
); err != nil {
return nil, err
}
@@ -861,7 +867,9 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
SELECT
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
- COALESCE(SUM(actual_cost), 0) as cost
+ COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost,
+ COALESCE(SUM(total_cost), 0) as standard_cost,
+ COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2
`
@@ -875,6 +883,8 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI
&stats.Requests,
&stats.Tokens,
&stats.Cost,
+ &stats.StandardCost,
+ &stats.UserCost,
); err != nil {
return nil, err
}
@@ -1400,8 +1410,8 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe
return result, nil
}
-// GetUsageTrendWithFilters returns usage trend data with optional user/api_key filters
-func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) (results []TrendDataPoint, err error) {
+// GetUsageTrendWithFilters returns usage trend data with optional filters
+func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
dateFormat := "YYYY-MM-DD"
if granularity == "hour" {
dateFormat = "YYYY-MM-DD HH24:00"
@@ -1430,6 +1440,22 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND api_key_id = $%d", len(args)+1)
args = append(args, apiKeyID)
}
+ if accountID > 0 {
+ query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
+ args = append(args, accountID)
+ }
+ if groupID > 0 {
+ query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
+ args = append(args, groupID)
+ }
+ if model != "" {
+ query += fmt.Sprintf(" AND model = $%d", len(args)+1)
+ args = append(args, model)
+ }
+ if stream != nil {
+ query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
+ args = append(args, *stream)
+ }
query += " GROUP BY date ORDER BY date ASC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1452,9 +1478,15 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
return results, nil
}
-// GetModelStatsWithFilters returns model statistics with optional user/api_key filters
-func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) (results []ModelStat, err error) {
- query := `
+// GetModelStatsWithFilters returns model statistics with optional filters
+func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
+ actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
+ // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。
+ if accountID > 0 && userID == 0 && apiKeyID == 0 {
+ actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost"
+ }
+
+ query := fmt.Sprintf(`
SELECT
model,
COUNT(*) as requests,
@@ -1462,10 +1494,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
COALESCE(SUM(output_tokens), 0) as output_tokens,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as total_tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(actual_cost), 0) as actual_cost
+ %s
FROM usage_logs
WHERE created_at >= $1 AND created_at < $2
- `
+ `, actualCostExpr)
args := []any{startTime, endTime}
if userID > 0 {
@@ -1480,6 +1512,14 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
query += fmt.Sprintf(" AND account_id = $%d", len(args)+1)
args = append(args, accountID)
}
+ if groupID > 0 {
+ query += fmt.Sprintf(" AND group_id = $%d", len(args)+1)
+ args = append(args, groupID)
+ }
+ if stream != nil {
+ query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
+ args = append(args, *stream)
+ }
query += " GROUP BY model ORDER BY total_tokens DESC"
rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1587,12 +1627,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens,
COALESCE(SUM(total_cost), 0) as total_cost,
COALESCE(SUM(actual_cost), 0) as total_actual_cost,
+ COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost,
COALESCE(AVG(duration_ms), 0) as avg_duration_ms
FROM usage_logs
%s
`, buildWhere(conditions))
stats := &UsageStats{}
+ var totalAccountCost float64
if err := scanSingleRow(
ctx,
r.sql,
@@ -1604,10 +1646,14 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us
&stats.TotalCacheTokens,
&stats.TotalCost,
&stats.TotalActualCost,
+ &totalAccountCost,
&stats.AverageDurationMs,
); err != nil {
return nil, err
}
+ if filters.AccountID > 0 {
+ stats.TotalAccountCost = &totalAccountCost
+ }
stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheTokens
return stats, nil
}
@@ -1634,7 +1680,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
COUNT(*) as requests,
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens,
COALESCE(SUM(total_cost), 0) as cost,
- COALESCE(SUM(actual_cost), 0) as actual_cost
+ COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost,
+ COALESCE(SUM(actual_cost), 0) as user_cost
FROM usage_logs
WHERE account_id = $1 AND created_at >= $2 AND created_at < $3
GROUP BY date
@@ -1661,7 +1708,8 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
var tokens int64
var cost float64
var actualCost float64
- if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost); err != nil {
+ var userCost float64
+ if err = rows.Scan(&date, &requests, &tokens, &cost, &actualCost, &userCost); err != nil {
return nil, err
}
t, _ := time.Parse("2006-01-02", date)
@@ -1672,19 +1720,21 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Tokens: tokens,
Cost: cost,
ActualCost: actualCost,
+ UserCost: userCost,
})
}
if err = rows.Err(); err != nil {
return nil, err
}
- var totalActualCost, totalStandardCost float64
+ var totalAccountCost, totalUserCost, totalStandardCost float64
var totalRequests, totalTokens int64
var highestCostDay, highestRequestDay *AccountUsageHistory
for i := range history {
h := &history[i]
- totalActualCost += h.ActualCost
+ totalAccountCost += h.ActualCost
+ totalUserCost += h.UserCost
totalStandardCost += h.Cost
totalRequests += h.Requests
totalTokens += h.Tokens
@@ -1711,11 +1761,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary := AccountUsageSummary{
Days: daysCount,
ActualDaysUsed: actualDaysUsed,
- TotalCost: totalActualCost,
+ TotalCost: totalAccountCost,
+ TotalUserCost: totalUserCost,
TotalStandardCost: totalStandardCost,
TotalRequests: totalRequests,
TotalTokens: totalTokens,
- AvgDailyCost: totalActualCost / float64(actualDaysUsed),
+ AvgDailyCost: totalAccountCost / float64(actualDaysUsed),
+ AvgDailyUserCost: totalUserCost / float64(actualDaysUsed),
AvgDailyRequests: float64(totalRequests) / float64(actualDaysUsed),
AvgDailyTokens: float64(totalTokens) / float64(actualDaysUsed),
AvgDurationMs: avgDuration,
@@ -1727,11 +1779,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
summary.Today = &struct {
Date string `json:"date"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
Tokens int64 `json:"tokens"`
}{
Date: history[i].Date,
Cost: history[i].ActualCost,
+ UserCost: history[i].UserCost,
Requests: history[i].Requests,
Tokens: history[i].Tokens,
}
@@ -1744,11 +1798,13 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Date string `json:"date"`
Label string `json:"label"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
Requests int64 `json:"requests"`
}{
Date: highestCostDay.Date,
Label: highestCostDay.Label,
Cost: highestCostDay.ActualCost,
+ UserCost: highestCostDay.UserCost,
Requests: highestCostDay.Requests,
}
}
@@ -1759,15 +1815,17 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID
Label string `json:"label"`
Requests int64 `json:"requests"`
Cost float64 `json:"cost"`
+ UserCost float64 `json:"user_cost"`
}{
Date: highestRequestDay.Date,
Label: highestRequestDay.Label,
Requests: highestRequestDay.Requests,
Cost: highestRequestDay.ActualCost,
+ UserCost: highestRequestDay.UserCost,
}
}
- models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID)
+ models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil)
if err != nil {
models = []ModelStat{}
}
@@ -1994,36 +2052,37 @@ func (r *usageLogRepository) loadSubscriptions(ctx context.Context, ids []int64)
func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, error) {
var (
- id int64
- userID int64
- apiKeyID int64
- accountID int64
- requestID sql.NullString
- model string
- groupID sql.NullInt64
- subscriptionID sql.NullInt64
- inputTokens int
- outputTokens int
- cacheCreationTokens int
- cacheReadTokens int
- cacheCreation5m int
- cacheCreation1h int
- inputCost float64
- outputCost float64
- cacheCreationCost float64
- cacheReadCost float64
- totalCost float64
- actualCost float64
- rateMultiplier float64
- billingType int16
- stream bool
- durationMs sql.NullInt64
- firstTokenMs sql.NullInt64
- userAgent sql.NullString
- ipAddress sql.NullString
- imageCount int
- imageSize sql.NullString
- createdAt time.Time
+ id int64
+ userID int64
+ apiKeyID int64
+ accountID int64
+ requestID sql.NullString
+ model string
+ groupID sql.NullInt64
+ subscriptionID sql.NullInt64
+ inputTokens int
+ outputTokens int
+ cacheCreationTokens int
+ cacheReadTokens int
+ cacheCreation5m int
+ cacheCreation1h int
+ inputCost float64
+ outputCost float64
+ cacheCreationCost float64
+ cacheReadCost float64
+ totalCost float64
+ actualCost float64
+ rateMultiplier float64
+ accountRateMultiplier sql.NullFloat64
+ billingType int16
+ stream bool
+ durationMs sql.NullInt64
+ firstTokenMs sql.NullInt64
+ userAgent sql.NullString
+ ipAddress sql.NullString
+ imageCount int
+ imageSize sql.NullString
+ createdAt time.Time
)
if err := scanner.Scan(
@@ -2048,6 +2107,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
&totalCost,
&actualCost,
&rateMultiplier,
+ &accountRateMultiplier,
&billingType,
&stream,
&durationMs,
@@ -2080,6 +2140,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
TotalCost: totalCost,
ActualCost: actualCost,
RateMultiplier: rateMultiplier,
+ AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier),
BillingType: int8(billingType),
Stream: stream,
ImageCount: imageCount,
@@ -2186,6 +2247,14 @@ func nullInt(v *int) sql.NullInt64 {
return sql.NullInt64{Int64: int64(*v), Valid: true}
}
+func nullFloat64Ptr(v sql.NullFloat64) *float64 {
+ if !v.Valid {
+ return nil
+ }
+ out := v.Float64
+ return &out
+}
+
func nullString(v *string) sql.NullString {
if v == nil || *v == "" {
return sql.NullString{}
diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go
index 3f90e49e..7174be18 100644
--- a/backend/internal/repository/usage_log_repo_integration_test.go
+++ b/backend/internal/repository/usage_log_repo_integration_test.go
@@ -11,6 +11,7 @@ import (
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/suite"
@@ -36,6 +37,12 @@ func TestUsageLogRepoSuite(t *testing.T) {
suite.Run(t, new(UsageLogRepoSuite))
}
+// truncateToDayUTC 截断到 UTC 日期边界(测试辅助函数)
+func truncateToDayUTC(t time.Time) time.Time {
+ t = t.UTC()
+ return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
+}
+
func (s *UsageLogRepoSuite) createUsageLog(user *service.User, apiKey *service.APIKey, account *service.Account, inputTokens, outputTokens int, cost float64, createdAt time.Time) *service.UsageLog {
log := &service.UsageLog{
UserID: user.ID,
@@ -95,6 +102,34 @@ func (s *UsageLogRepoSuite) TestGetByID_NotFound() {
s.Require().Error(err, "expected error for non-existent ID")
}
+func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() {
+ user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-mult@test.com"})
+ apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-mult", Name: "k"})
+ account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-mult"})
+
+ m := 0.5
+ log := &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "claude-3",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 1.0,
+ ActualCost: 2.0,
+ AccountRateMultiplier: &m,
+ CreatedAt: timezone.Today().Add(2 * time.Hour),
+ }
+ _, err := s.repo.Create(s.ctx, log)
+ s.Require().NoError(err)
+
+ got, err := s.repo.GetByID(s.ctx, log.ID)
+ s.Require().NoError(err)
+ s.Require().NotNil(got.AccountRateMultiplier)
+ s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001)
+}
+
// --- Delete ---
func (s *UsageLogRepoSuite) TestDelete() {
@@ -403,12 +438,49 @@ func (s *UsageLogRepoSuite) TestGetAccountTodayStats() {
apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-acctoday", Name: "k"})
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-today"})
- s.createUsageLog(user, apiKey, account, 10, 20, 0.5, time.Now())
+ createdAt := timezone.Today().Add(1 * time.Hour)
+
+ m1 := 1.5
+ m2 := 0.0
+ _, err := s.repo.Create(s.ctx, &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "claude-3",
+ InputTokens: 10,
+ OutputTokens: 20,
+ TotalCost: 1.0,
+ ActualCost: 2.0,
+ AccountRateMultiplier: &m1,
+ CreatedAt: createdAt,
+ })
+ s.Require().NoError(err)
+ _, err = s.repo.Create(s.ctx, &service.UsageLog{
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: uuid.New().String(),
+ Model: "claude-3",
+ InputTokens: 5,
+ OutputTokens: 5,
+ TotalCost: 0.5,
+ ActualCost: 1.0,
+ AccountRateMultiplier: &m2,
+ CreatedAt: createdAt,
+ })
+ s.Require().NoError(err)
stats, err := s.repo.GetAccountTodayStats(s.ctx, account.ID)
s.Require().NoError(err, "GetAccountTodayStats")
- s.Require().Equal(int64(1), stats.Requests)
- s.Require().Equal(int64(30), stats.Tokens)
+ s.Require().Equal(int64(2), stats.Requests)
+ s.Require().Equal(int64(40), stats.Tokens)
+ // account cost = SUM(total_cost * account_rate_multiplier)
+ s.Require().InEpsilon(1.5, stats.Cost, 0.0001)
+ // standard cost = SUM(total_cost)
+ s.Require().InEpsilon(1.5, stats.StandardCost, 0.0001)
+ // user cost = SUM(actual_cost)
+ s.Require().InEpsilon(3.0, stats.UserCost, 0.0001)
}
func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
@@ -416,8 +488,8 @@ func (s *UsageLogRepoSuite) TestDashboardAggregationConsistency() {
// 使用固定的时间偏移确保 hour1 和 hour2 在同一天且都在过去
// 选择当天 02:00 和 03:00 作为测试时间点(基于 now 的日期)
dayStart := truncateToDayUTC(now)
- hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
- hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
+ hour1 := dayStart.Add(2 * time.Hour) // 当天 02:00
+ hour2 := dayStart.Add(3 * time.Hour) // 当天 03:00
// 如果当前时间早于 hour2,则使用昨天的时间
if now.Before(hour2.Add(time.Hour)) {
dayStart = dayStart.Add(-24 * time.Hour)
@@ -872,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() {
endTime := base.Add(48 * time.Hour)
// Test with user filter
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters user filter")
s.Require().Len(trend, 2)
// Test with apiKey filter
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter")
s.Require().Len(trend, 2)
// Test with both filters
- trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID)
+ trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters both filters")
s.Require().Len(trend, 2)
}
@@ -899,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() {
startTime := base.Add(-1 * time.Hour)
endTime := base.Add(3 * time.Hour)
- trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0)
+ trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil)
s.Require().NoError(err, "GetUsageTrendWithFilters hourly")
s.Require().Len(trend, 2)
}
@@ -945,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() {
endTime := base.Add(2 * time.Hour)
// Test with user filter
- stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0)
+ stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters user filter")
s.Require().Len(stats, 2)
// Test with apiKey filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter")
s.Require().Len(stats, 2)
// Test with account filter
- stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID)
+ stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil)
s.Require().NoError(err, "GetModelStatsWithFilters account filter")
s.Require().Len(stats, 2)
}
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 45a8f182..77ed37e1 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -37,6 +37,16 @@ func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient
return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
+// ProvideSessionLimitCache 创建会话限制缓存
+// 用于 Anthropic OAuth/SetupToken 账号的并发会话数量控制
+func ProvideSessionLimitCache(rdb *redis.Client, cfg *config.Config) service.SessionLimitCache {
+ defaultIdleTimeoutMinutes := 5 // 默认 5 分钟空闲超时
+ if cfg != nil && cfg.Gateway.SessionIdleTimeoutMinutes > 0 {
+ defaultIdleTimeoutMinutes = cfg.Gateway.SessionIdleTimeoutMinutes
+ }
+ return NewSessionLimitCache(rdb, defaultIdleTimeoutMinutes)
+}
+
// ProviderSet is the Wire provider set for all repositories
var ProviderSet = wire.NewSet(
NewUserRepository,
@@ -61,6 +71,7 @@ var ProviderSet = wire.NewSet(
NewTempUnschedCache,
NewTimeoutCounterCache,
ProvideConcurrencyCache,
+ ProvideSessionLimitCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
@@ -69,6 +80,7 @@ var ProviderSet = wire.NewSet(
NewGeminiTokenCache,
NewSchedulerCache,
NewSchedulerOutboxRepository,
+ NewProxyLatencyCache,
// HTTP service ports (DI Strategy A: return interface directly)
NewTurnstileVerifier,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index d96732bd..356b4a4e 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -239,9 +239,10 @@ func TestAPIContracts(t *testing.T) {
"cache_creation_cost": 0,
"cache_read_cost": 0,
"total_cost": 0.5,
- "actual_cost": 0.5,
- "rate_multiplier": 1,
- "billing_type": 0,
+ "actual_cost": 0.5,
+ "rate_multiplier": 1,
+ "account_rate_multiplier": null,
+ "billing_type": 0,
"stream": true,
"duration_ms": 100,
"first_token_ms": 50,
@@ -262,11 +263,11 @@ func TestAPIContracts(t *testing.T) {
name: "GET /api/v1/admin/settings",
setup: func(t *testing.T, deps *contractDeps) {
t.Helper()
- deps.settingRepo.SetAll(map[string]string{
- service.SettingKeyRegistrationEnabled: "true",
- service.SettingKeyEmailVerifyEnabled: "false",
+ deps.settingRepo.SetAll(map[string]string{
+ service.SettingKeyRegistrationEnabled: "true",
+ service.SettingKeyEmailVerifyEnabled: "false",
- service.SettingKeySMTPHost: "smtp.example.com",
+ service.SettingKeySMTPHost: "smtp.example.com",
service.SettingKeySMTPPort: "587",
service.SettingKeySMTPUsername: "user",
service.SettingKeySMTPPassword: "secret",
@@ -285,15 +286,15 @@ func TestAPIContracts(t *testing.T) {
service.SettingKeyContactInfo: "support",
service.SettingKeyDocURL: "https://docs.example.com",
- service.SettingKeyDefaultConcurrency: "5",
- service.SettingKeyDefaultBalance: "1.25",
+ service.SettingKeyDefaultConcurrency: "5",
+ service.SettingKeyDefaultBalance: "1.25",
- service.SettingKeyOpsMonitoringEnabled: "false",
- service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
- service.SettingKeyOpsQueryModeDefault: "auto",
- service.SettingKeyOpsMetricsIntervalSeconds: "60",
- })
- },
+ service.SettingKeyOpsMonitoringEnabled: "false",
+ service.SettingKeyOpsRealtimeMonitoringEnabled: "true",
+ service.SettingKeyOpsQueryModeDefault: "auto",
+ service.SettingKeyOpsMetricsIntervalSeconds: "60",
+ })
+ },
method: http.MethodGet,
path: "/api/v1/admin/settings",
wantStatus: http.StatusOK,
@@ -435,12 +436,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil)
- adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -779,6 +780,10 @@ func (s *stubAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id
return errors.New("not implemented")
}
+func (s *stubAccountRepo) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+ return errors.New("not implemented")
+}
+
func (s *stubAccountRepo) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return errors.New("not implemented")
}
@@ -799,6 +804,10 @@ func (s *stubAccountRepo) ClearAntigravityQuotaScopes(ctx context.Context, id in
return errors.New("not implemented")
}
+func (s *stubAccountRepo) ClearModelRateLimits(ctx context.Context, id int64) error {
+ return errors.New("not implemented")
+}
+
func (s *stubAccountRepo) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return errors.New("not implemented")
}
@@ -858,6 +867,10 @@ func (stubProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64)
return 0, errors.New("not implemented")
}
+func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
+ return nil, errors.New("not implemented")
+}
+
type stubRedeemCodeRepo struct{}
func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error {
@@ -1229,11 +1242,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
+func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error) {
+func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
return nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go
index 9fca0cd3..9ce7f449 100644
--- a/backend/internal/server/middleware/security_headers.go
+++ b/backend/internal/server/middleware/security_headers.go
@@ -1,12 +1,40 @@
package middleware
import (
+ "crypto/rand"
+ "encoding/base64"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
)
+const (
+ // CSPNonceKey is the context key for storing the CSP nonce
+ CSPNonceKey = "csp_nonce"
+ // NonceTemplate is the placeholder in CSP policy for nonce
+ NonceTemplate = "__CSP_NONCE__"
+ // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics
+ CloudflareInsightsDomain = "https://static.cloudflareinsights.com"
+)
+
+// GenerateNonce generates a cryptographically secure random nonce
+func GenerateNonce() string {
+ b := make([]byte, 16)
+ _, _ = rand.Read(b)
+ return base64.StdEncoding.EncodeToString(b)
+}
+
+// GetNonceFromContext retrieves the CSP nonce from gin context
+func GetNonceFromContext(c *gin.Context) string {
+ if nonce, exists := c.Get(CSPNonceKey); exists {
+ if s, ok := nonce.(string); ok {
+ return s
+ }
+ }
+ return ""
+}
+
// SecurityHeaders sets baseline security headers for all responses.
func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy := strings.TrimSpace(cfg.Policy)
@@ -14,13 +42,75 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc {
policy = config.DefaultCSPPolicy
}
+ // Enhance policy with required directives (nonce placeholder and Cloudflare Insights)
+ policy = enhanceCSPPolicy(policy)
+
return func(c *gin.Context) {
c.Header("X-Content-Type-Options", "nosniff")
c.Header("X-Frame-Options", "DENY")
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
+
if cfg.Enabled {
- c.Header("Content-Security-Policy", policy)
+ // Generate nonce for this request
+ nonce := GenerateNonce()
+ c.Set(CSPNonceKey, nonce)
+
+ // Replace nonce placeholder in policy
+ finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'")
+ c.Header("Content-Security-Policy", finalPolicy)
}
c.Next()
}
}
+
+// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain.
+// This allows the application to work correctly even if the config file has an older CSP policy.
+func enhanceCSPPolicy(policy string) string {
+ // Add nonce placeholder to script-src if not present
+ if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") {
+ policy = addToDirective(policy, "script-src", NonceTemplate)
+ }
+
+ // Add Cloudflare Insights domain to script-src if not present
+ if !strings.Contains(policy, CloudflareInsightsDomain) {
+ policy = addToDirective(policy, "script-src", CloudflareInsightsDomain)
+ }
+
+ return policy
+}
+
+// addToDirective adds a value to a specific CSP directive.
+// If the directive doesn't exist, it will be added after default-src.
+func addToDirective(policy, directive, value string) string {
+ // Find the directive in the policy
+ directivePrefix := directive + " "
+ idx := strings.Index(policy, directivePrefix)
+
+ if idx == -1 {
+ // Directive not found, add it after default-src or at the beginning
+ defaultSrcIdx := strings.Index(policy, "default-src ")
+ if defaultSrcIdx != -1 {
+ // Find the end of default-src directive (next semicolon)
+ endIdx := strings.Index(policy[defaultSrcIdx:], ";")
+ if endIdx != -1 {
+ insertPos := defaultSrcIdx + endIdx + 1
+ // Insert new directive after default-src
+ return policy[:insertPos] + " " + directive + " 'self' " + value + ";" + policy[insertPos:]
+ }
+ }
+ // Fallback: prepend the directive
+ return directive + " 'self' " + value + "; " + policy
+ }
+
+ // Find the end of this directive (next semicolon or end of string)
+ endIdx := strings.Index(policy[idx:], ";")
+
+ if endIdx == -1 {
+ // No semicolon found, directive goes to end of string
+ return policy + " " + value
+ }
+
+ // Insert value before the semicolon
+ insertPos := idx + endIdx
+ return policy[:insertPos] + " " + value + policy[insertPos:]
+}
diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go
new file mode 100644
index 00000000..dc7a87d8
--- /dev/null
+++ b/backend/internal/server/middleware/security_headers_test.go
@@ -0,0 +1,365 @@
+package middleware
+
+import (
+ "encoding/base64"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
+func TestGenerateNonce(t *testing.T) {
+ t.Run("generates_valid_base64_string", func(t *testing.T) {
+ nonce := GenerateNonce()
+
+ // Should be valid base64
+ decoded, err := base64.StdEncoding.DecodeString(nonce)
+ require.NoError(t, err)
+
+ // Should decode to 16 bytes
+ assert.Len(t, decoded, 16)
+ })
+
+ t.Run("generates_unique_nonces", func(t *testing.T) {
+ nonces := make(map[string]bool)
+ for i := 0; i < 100; i++ {
+ nonce := GenerateNonce()
+ assert.False(t, nonces[nonce], "nonce should be unique")
+ nonces[nonce] = true
+ }
+ })
+
+ t.Run("nonce_has_expected_length", func(t *testing.T) {
+ nonce := GenerateNonce()
+ // 16 bytes -> 24 chars in base64 (with padding)
+ assert.Len(t, nonce, 24)
+ })
+}
+
+func TestGetNonceFromContext(t *testing.T) {
+ t.Run("returns_nonce_when_present", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ expectedNonce := "test-nonce-123"
+ c.Set(CSPNonceKey, expectedNonce)
+
+ nonce := GetNonceFromContext(c)
+ assert.Equal(t, expectedNonce, nonce)
+ })
+
+ t.Run("returns_empty_string_when_not_present", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ nonce := GetNonceFromContext(c)
+ assert.Empty(t, nonce)
+ })
+
+ t.Run("returns_empty_for_wrong_type", func(t *testing.T) {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+
+ // Set a non-string value
+ c.Set(CSPNonceKey, 12345)
+
+ // Should return empty string for wrong type (safe type assertion)
+ nonce := GetNonceFromContext(c)
+ assert.Empty(t, nonce)
+ })
+}
+
+func TestSecurityHeaders(t *testing.T) {
+ t.Run("sets_basic_security_headers", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: false}
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options"))
+ assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options"))
+ assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy"))
+ })
+
+ t.Run("csp_disabled_no_csp_header", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: false}
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ assert.Empty(t, w.Header().Get("Content-Security-Policy"))
+ })
+
+ t.Run("csp_enabled_sets_csp_header", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "default-src 'self'",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ // Policy is auto-enhanced with nonce and Cloudflare Insights domain
+ assert.Contains(t, csp, "default-src 'self'")
+ assert.Contains(t, csp, "'nonce-")
+ assert.Contains(t, csp, CloudflareInsightsDomain)
+ })
+
+ t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src 'self' __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ assert.NotContains(t, csp, "__CSP_NONCE__", "placeholder should be replaced")
+ assert.Contains(t, csp, "'nonce-", "should contain nonce directive")
+
+ // Verify nonce is stored in context
+ nonce := GetNonceFromContext(c)
+ assert.NotEmpty(t, nonce)
+ assert.Contains(t, csp, "'nonce-"+nonce+"'")
+ })
+
+ t.Run("uses_default_policy_when_empty", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ // Default policy should contain these elements
+ assert.Contains(t, csp, "default-src 'self'")
+ })
+
+ t.Run("uses_default_policy_when_whitespace_only", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: " \t\n ",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ assert.NotEmpty(t, csp)
+ assert.Contains(t, csp, "default-src 'self'")
+ })
+
+ t.Run("multiple_nonce_placeholders_replaced", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src __CSP_NONCE__; style-src __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ csp := w.Header().Get("Content-Security-Policy")
+ nonce := GetNonceFromContext(c)
+
+ // Count occurrences of the nonce
+ count := strings.Count(csp, "'nonce-"+nonce+"'")
+ assert.Equal(t, 2, count, "both placeholders should be replaced with same nonce")
+ })
+
+ t.Run("calls_next_handler", func(t *testing.T) {
+ cfg := config.CSPConfig{Enabled: true, Policy: "default-src 'self'"}
+ middleware := SecurityHeaders(cfg)
+
+ nextCalled := false
+ router := gin.New()
+ router.Use(middleware)
+ router.GET("/test", func(c *gin.Context) {
+ nextCalled = true
+ c.Status(http.StatusOK)
+ })
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ router.ServeHTTP(w, req)
+
+ assert.True(t, nextCalled, "next handler should be called")
+ assert.Equal(t, http.StatusOK, w.Code)
+ })
+
+ t.Run("nonce_unique_per_request", func(t *testing.T) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ nonces := make(map[string]bool)
+ for i := 0; i < 10; i++ {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+
+ middleware(c)
+
+ nonce := GetNonceFromContext(c)
+ assert.False(t, nonces[nonce], "nonce should be unique per request")
+ nonces[nonce] = true
+ }
+ })
+}
+
+func TestCSPNonceKey(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "csp_nonce", CSPNonceKey)
+ })
+}
+
+func TestNonceTemplate(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "__CSP_NONCE__", NonceTemplate)
+ })
+}
+
+func TestEnhanceCSPPolicy(t *testing.T) {
+ t.Run("adds_nonce_placeholder_if_missing", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ assert.Contains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, CloudflareInsightsDomain)
+ })
+
+ t.Run("does_not_duplicate_nonce_placeholder", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self' __CSP_NONCE__"
+ enhanced := enhanceCSPPolicy(policy)
+
+ // Should not duplicate
+ count := strings.Count(enhanced, NonceTemplate)
+ assert.Equal(t, 1, count)
+ })
+
+ t.Run("does_not_duplicate_cloudflare_domain", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self' https://static.cloudflareinsights.com"
+ enhanced := enhanceCSPPolicy(policy)
+
+ count := strings.Count(enhanced, CloudflareInsightsDomain)
+ assert.Equal(t, 1, count)
+ })
+
+ t.Run("handles_policy_without_script_src", func(t *testing.T) {
+ policy := "default-src 'self'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ assert.Contains(t, enhanced, "script-src")
+ assert.Contains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, CloudflareInsightsDomain)
+ })
+
+ t.Run("preserves_existing_nonce", func(t *testing.T) {
+ policy := "script-src 'self' 'nonce-existing'"
+ enhanced := enhanceCSPPolicy(policy)
+
+ // Should not add placeholder if nonce already exists
+ assert.NotContains(t, enhanced, NonceTemplate)
+ assert.Contains(t, enhanced, "'nonce-existing'")
+ })
+}
+
+func TestAddToDirective(t *testing.T) {
+ t.Run("adds_to_existing_directive", func(t *testing.T) {
+ policy := "script-src 'self'; style-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src 'self' https://example.com")
+ })
+
+ t.Run("creates_directive_if_not_exists", func(t *testing.T) {
+ policy := "default-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src")
+ assert.Contains(t, result, "https://example.com")
+ })
+
+ t.Run("handles_directive_at_end_without_semicolon", func(t *testing.T) {
+ policy := "default-src 'self'; script-src 'self'"
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "https://example.com")
+ })
+
+ t.Run("handles_empty_policy", func(t *testing.T) {
+ policy := ""
+ result := addToDirective(policy, "script-src", "https://example.com")
+
+ assert.Contains(t, result, "script-src")
+ assert.Contains(t, result, "https://example.com")
+ })
+}
+
+// Benchmark tests
+func BenchmarkGenerateNonce(b *testing.B) {
+ for i := 0; i < b.N; i++ {
+ GenerateNonce()
+ }
+}
+
+func BenchmarkSecurityHeadersMiddleware(b *testing.B) {
+ cfg := config.CSPConfig{
+ Enabled: true,
+ Policy: "script-src 'self' __CSP_NONCE__",
+ }
+ middleware := SecurityHeaders(cfg)
+
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
+ middleware(c)
+ }
+}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 9bb019bb..ff05b32a 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -81,6 +81,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.PUT("/alert-rules/:id", h.Admin.Ops.UpdateAlertRule)
ops.DELETE("/alert-rules/:id", h.Admin.Ops.DeleteAlertRule)
ops.GET("/alert-events", h.Admin.Ops.ListAlertEvents)
+ ops.GET("/alert-events/:id", h.Admin.Ops.GetAlertEvent)
+ ops.PUT("/alert-events/:id/status", h.Admin.Ops.UpdateAlertEventStatus)
+ ops.POST("/alert-silences", h.Admin.Ops.CreateAlertSilence)
// Email notification config (DB-backed)
ops.GET("/email-notification/config", h.Admin.Ops.GetEmailNotificationConfig)
@@ -110,10 +113,26 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ws.GET("/qps", h.Admin.Ops.QPSWSHandler)
}
- // Error logs (MVP-1)
+ // Error logs (legacy)
ops.GET("/errors", h.Admin.Ops.GetErrorLogs)
ops.GET("/errors/:id", h.Admin.Ops.GetErrorLogByID)
+ ops.GET("/errors/:id/retries", h.Admin.Ops.ListRetryAttempts)
ops.POST("/errors/:id/retry", h.Admin.Ops.RetryErrorRequest)
+ ops.PUT("/errors/:id/resolve", h.Admin.Ops.UpdateErrorResolution)
+
+ // Request errors (client-visible failures)
+ ops.GET("/request-errors", h.Admin.Ops.ListRequestErrors)
+ ops.GET("/request-errors/:id", h.Admin.Ops.GetRequestError)
+ ops.GET("/request-errors/:id/upstream-errors", h.Admin.Ops.ListRequestErrorUpstreamErrors)
+ ops.POST("/request-errors/:id/retry-client", h.Admin.Ops.RetryRequestErrorClient)
+ ops.POST("/request-errors/:id/upstream-errors/:idx/retry", h.Admin.Ops.RetryRequestErrorUpstreamEvent)
+ ops.PUT("/request-errors/:id/resolve", h.Admin.Ops.ResolveRequestError)
+
+ // Upstream errors (independent upstream failures)
+ ops.GET("/upstream-errors", h.Admin.Ops.ListUpstreamErrors)
+ ops.GET("/upstream-errors/:id", h.Admin.Ops.GetUpstreamError)
+ ops.POST("/upstream-errors/:id/retry", h.Admin.Ops.RetryUpstreamError)
+ ops.PUT("/upstream-errors/:id/resolve", h.Admin.Ops.ResolveUpstreamError)
// Request drilldown (success + error)
ops.GET("/requests", h.Admin.Ops.ListRequestDetails)
@@ -250,6 +269,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
proxies.POST("/:id/test", h.Admin.Proxy.Test)
proxies.GET("/:id/stats", h.Admin.Proxy.GetStats)
proxies.GET("/:id/accounts", h.Admin.Proxy.GetProxyAccounts)
+ proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete)
proxies.POST("/batch", h.Admin.Proxy.BatchCreate)
}
}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 435eecd9..4fda300e 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -9,16 +9,19 @@ import (
)
type Account struct {
- ID int64
- Name string
- Notes *string
- Platform string
- Type string
- Credentials map[string]any
- Extra map[string]any
- ProxyID *int64
- Concurrency int
- Priority int
+ ID int64
+ Name string
+ Notes *string
+ Platform string
+ Type string
+ Credentials map[string]any
+ Extra map[string]any
+ ProxyID *int64
+ Concurrency int
+ Priority int
+ // RateMultiplier 账号计费倍率(>=0,允许 0 表示该账号计费为 0)。
+ // 使用指针用于兼容旧版本调度缓存(Redis)中缺字段的情况:nil 表示按 1.0 处理。
+ RateMultiplier *float64
Status string
ErrorMessage string
LastUsedAt *time.Time
@@ -57,6 +60,20 @@ func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
+// BillingRateMultiplier 返回账号计费倍率。
+// - nil 表示未配置/旧缓存缺字段,按 1.0 处理
+// - 允许 0,表示该账号计费为 0
+// - 负数属于非法数据,出于安全考虑按 1.0 处理
+func (a *Account) BillingRateMultiplier() float64 {
+ if a == nil || a.RateMultiplier == nil {
+ return 1.0
+ }
+ if *a.RateMultiplier < 0 {
+ return 1.0
+ }
+ return *a.RateMultiplier
+}
+
func (a *Account) IsSchedulable() bool {
if !a.IsActive() || !a.Schedulable {
return false
@@ -556,3 +573,141 @@ func (a *Account) IsMixedSchedulingEnabled() bool {
}
return false
}
+
+// WindowCostSchedulability 窗口费用调度状态
+type WindowCostSchedulability int
+
+const (
+ // WindowCostSchedulable 可正常调度
+ WindowCostSchedulable WindowCostSchedulability = iota
+ // WindowCostStickyOnly 仅允许粘性会话
+ WindowCostStickyOnly
+ // WindowCostNotSchedulable 完全不可调度
+ WindowCostNotSchedulable
+)
+
+// IsAnthropicOAuthOrSetupToken 判断是否为 Anthropic OAuth 或 SetupToken 类型账号
+// 仅这两类账号支持 5h 窗口额度控制和会话数量控制
+func (a *Account) IsAnthropicOAuthOrSetupToken() bool {
+ return a.Platform == PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken)
+}
+
+// GetWindowCostLimit 获取 5h 窗口费用阈值(美元)
+// 返回 0 表示未启用
+func (a *Account) GetWindowCostLimit() float64 {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["window_cost_limit"]; ok {
+ return parseExtraFloat64(v)
+ }
+ return 0
+}
+
+// GetWindowCostStickyReserve 获取粘性会话预留额度(美元)
+// 默认值为 10(未配置或配置值 <= 0 时均回退为 10)
+func (a *Account) GetWindowCostStickyReserve() float64 {
+ if a.Extra == nil {
+ return 10.0
+ }
+ if v, ok := a.Extra["window_cost_sticky_reserve"]; ok {
+ val := parseExtraFloat64(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 10.0
+}
+
+// GetMaxSessions 获取最大并发会话数
+// 返回 0 表示未启用
+func (a *Account) GetMaxSessions() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["max_sessions"]; ok {
+ return parseExtraInt(v)
+ }
+ return 0
+}
+
+// GetSessionIdleTimeoutMinutes 获取会话空闲超时分钟数
+// 默认值为 5 分钟(未配置或配置值 <= 0 时均回退为 5)
+func (a *Account) GetSessionIdleTimeoutMinutes() int {
+ if a.Extra == nil {
+ return 5
+ }
+ if v, ok := a.Extra["session_idle_timeout_minutes"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 5
+}
+
+// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
+// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
+// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
+// - 费用 >= 阈值+预留: WindowCostNotSchedulable(不可调度)
+func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) WindowCostSchedulability {
+ limit := a.GetWindowCostLimit()
+ if limit <= 0 {
+ return WindowCostSchedulable
+ }
+
+ if currentWindowCost < limit {
+ return WindowCostSchedulable
+ }
+
+ stickyReserve := a.GetWindowCostStickyReserve()
+ if currentWindowCost < limit+stickyReserve {
+ return WindowCostStickyOnly
+ }
+
+ return WindowCostNotSchedulable
+}
+
+// parseExtraFloat64 从 extra 字段解析 float64 值(不支持的类型或解析失败时返回 0)
+func parseExtraFloat64(value any) float64 {
+ switch v := value.(type) {
+ case float64:
+ return v
+ case float32:
+ return float64(v)
+ case int:
+ return float64(v)
+ case int64:
+ return float64(v)
+ case json.Number:
+ if f, err := v.Float64(); err == nil {
+ return f
+ }
+ case string:
+ if f, err := strconv.ParseFloat(strings.TrimSpace(v), 64); err == nil {
+ return f
+ }
+ }
+ return 0
+}
+
+// parseExtraInt 从 extra 字段解析 int 值(不支持的类型或解析失败时返回 0)
+func parseExtraInt(value any) int {
+ switch v := value.(type) {
+ case int:
+ return v
+ case int64:
+ return int(v)
+ case float64:
+ return int(v)
+ case json.Number:
+ if i, err := v.Int64(); err == nil {
+ return int(i)
+ }
+ case string:
+ if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
+ return i
+ }
+ }
+ return 0
+}
diff --git a/backend/internal/service/account_billing_rate_multiplier_test.go b/backend/internal/service/account_billing_rate_multiplier_test.go
new file mode 100644
index 00000000..731cfa7a
--- /dev/null
+++ b/backend/internal/service/account_billing_rate_multiplier_test.go
@@ -0,0 +1,27 @@
+package service
+
+import (
+ "encoding/json"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
+ var a Account
+ require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
+ require.Nil(t, a.RateMultiplier)
+ require.Equal(t, 1.0, a.BillingRateMultiplier())
+}
+
+func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
+ v := 0.0
+ a := Account{RateMultiplier: &v}
+ require.Equal(t, 0.0, a.BillingRateMultiplier())
+}
+
+func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
+ v := -1.0
+ a := Account{RateMultiplier: &v}
+ require.Equal(t, 1.0, a.BillingRateMultiplier())
+}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 2f138b81..ede5b12f 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -50,11 +50,13 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error
+ SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
+ ClearModelRateLimits(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
@@ -63,14 +65,15 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
- Name *string
- ProxyID *int64
- Concurrency *int
- Priority *int
- Status *string
- Schedulable *bool
- Credentials map[string]any
- Extra map[string]any
+ Name *string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64
+ Status *string
+ Schedulable *bool
+ Credentials map[string]any
+ Extra map[string]any
}
// CreateAccountRequest 创建账号请求
diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go
index 6923067d..36af719c 100644
--- a/backend/internal/service/account_service_delete_test.go
+++ b/backend/internal/service/account_service_delete_test.go
@@ -143,6 +143,10 @@ func (s *accountRepoStub) SetAntigravityQuotaScopeLimit(ctx context.Context, id
panic("unexpected SetAntigravityQuotaScopeLimit call")
}
+func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+ panic("unexpected SetModelRateLimit call")
+}
+
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call")
}
@@ -163,6 +167,10 @@ func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id in
panic("unexpected ClearAntigravityQuotaScopes call")
}
+func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
+ panic("unexpected ClearModelRateLimits call")
+}
+
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call")
}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index f1ee43d2..6f012385 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -32,8 +32,8 @@ type UsageLogRepository interface {
// Admin dashboard stats
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
- GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
- GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
+ GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
+ GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -96,10 +96,16 @@ func NewUsageCache() *UsageCache {
}
// WindowStats 窗口期统计
+//
+// cost: 账号口径费用(total_cost * account_rate_multiplier)
+// standard_cost: 标准费用(total_cost,不含倍率)
+// user_cost: 用户/API Key 口径费用(actual_cost,受分组倍率影响)
type WindowStats struct {
- Requests int64 `json:"requests"`
- Tokens int64 `json:"tokens"`
- Cost float64 `json:"cost"`
+ Requests int64 `json:"requests"`
+ Tokens int64 `json:"tokens"`
+ Cost float64 `json:"cost"`
+ StandardCost float64 `json:"standard_cost"`
+ UserCost float64 `json:"user_cost"`
}
// UsageProgress 使用量进度
@@ -266,7 +272,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
}
dayStart := geminiDailyWindowStart(now)
- stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
+ stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
@@ -288,7 +294,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
- minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
+ minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
@@ -377,9 +383,11 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
}
windowStats = &WindowStats{
- Requests: stats.Requests,
- Tokens: stats.Tokens,
- Cost: stats.Cost,
+ Requests: stats.Requests,
+ Tokens: stats.Tokens,
+ Cost: stats.Cost,
+ StandardCost: stats.StandardCost,
+ UserCost: stats.UserCost,
}
// 缓存窗口统计(1 分钟)
@@ -403,9 +411,11 @@ func (s *AccountUsageService) GetTodayStats(ctx context.Context, accountID int64
}
return &WindowStats{
- Requests: stats.Requests,
- Tokens: stats.Tokens,
- Cost: stats.Cost,
+ Requests: stats.Requests,
+ Tokens: stats.Tokens,
+ Cost: stats.Cost,
+ StandardCost: stats.StandardCost,
+ UserCost: stats.UserCost,
}, nil
}
@@ -565,3 +575,9 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
},
}
}
+
+// GetAccountWindowStats 获取账号在指定时间窗口内的使用统计
+// 用于账号列表页面显示当前窗口费用
+func (s *AccountUsageService) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) {
+ return s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime)
+}
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 1874c5c1..c0694e4e 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -54,7 +54,8 @@ type AdminService interface {
CreateProxy(ctx context.Context, input *CreateProxyInput) (*Proxy, error)
UpdateProxy(ctx context.Context, id int64, input *UpdateProxyInput) (*Proxy, error)
DeleteProxy(ctx context.Context, id int64) error
- GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error)
+ BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error)
+ GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error)
TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error)
@@ -105,6 +106,9 @@ type CreateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
+ // 模型路由配置(仅 anthropic 平台使用)
+ ModelRouting map[string][]int64
+ ModelRoutingEnabled bool // 是否启用模型路由
}
type UpdateGroupInput struct {
@@ -124,6 +128,9 @@ type UpdateGroupInput struct {
ImagePrice4K *float64
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
FallbackGroupID *int64 // 降级分组 ID
+ // 模型路由配置(仅 anthropic 平台使用)
+ ModelRouting map[string][]int64
+ ModelRoutingEnabled *bool // 是否启用模型路由
}
type CreateAccountInput struct {
@@ -136,6 +143,7 @@ type CreateAccountInput struct {
ProxyID *int64
Concurrency int
Priority int
+ RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
GroupIDs []int64
ExpiresAt *int64
AutoPauseOnExpired *bool
@@ -151,8 +159,9 @@ type UpdateAccountInput struct {
Credentials map[string]any
Extra map[string]any
ProxyID *int64
- Concurrency *int // 使用指针区分"未提供"和"设置为0"
- Priority *int // 使用指针区分"未提供"和"设置为0"
+ Concurrency *int // 使用指针区分"未提供"和"设置为0"
+ Priority *int // 使用指针区分"未提供"和"设置为0"
+ RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
Status string
GroupIDs *[]int64
ExpiresAt *int64
@@ -162,16 +171,17 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
- AccountIDs []int64
- Name string
- ProxyID *int64
- Concurrency *int
- Priority *int
- Status string
- Schedulable *bool
- GroupIDs *[]int64
- Credentials map[string]any
- Extra map[string]any
+ AccountIDs []int64
+ Name string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
+ Status string
+ Schedulable *bool
+ GroupIDs *[]int64
+ Credentials map[string]any
+ Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -220,23 +230,35 @@ type GenerateRedeemCodesInput struct {
ValidityDays int // 订阅类型专用:有效天数
}
-// ProxyTestResult represents the result of testing a proxy
-type ProxyTestResult struct {
- Success bool `json:"success"`
- Message string `json:"message"`
- LatencyMs int64 `json:"latency_ms,omitempty"`
- IPAddress string `json:"ip_address,omitempty"`
- City string `json:"city,omitempty"`
- Region string `json:"region,omitempty"`
- Country string `json:"country,omitempty"`
+type ProxyBatchDeleteResult struct {
+ DeletedIDs []int64 `json:"deleted_ids"`
+ Skipped []ProxyBatchDeleteSkipped `json:"skipped"`
}
-// ProxyExitInfo represents proxy exit information from ipinfo.io
+type ProxyBatchDeleteSkipped struct {
+ ID int64 `json:"id"`
+ Reason string `json:"reason"`
+}
+
+// ProxyTestResult represents the result of testing a proxy
+type ProxyTestResult struct {
+ Success bool `json:"success"`
+ Message string `json:"message"`
+ LatencyMs int64 `json:"latency_ms,omitempty"`
+ IPAddress string `json:"ip_address,omitempty"`
+ City string `json:"city,omitempty"`
+ Region string `json:"region,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+}
+
+// ProxyExitInfo represents proxy exit information from ip-api.com
type ProxyExitInfo struct {
- IP string
- City string
- Region string
- Country string
+ IP string
+ City string
+ Region string
+ Country string
+ CountryCode string
}
// ProxyExitInfoProber tests proxy connectivity and retrieves exit information
@@ -254,6 +276,7 @@ type adminServiceImpl struct {
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
+ proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
}
@@ -267,6 +290,7 @@ func NewAdminService(
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
+ proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
) AdminService {
return &adminServiceImpl{
@@ -278,6 +302,7 @@ func NewAdminService(
redeemCodeRepo: redeemCodeRepo,
billingCacheService: billingCacheService,
proxyProber: proxyProber,
+ proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
}
}
@@ -562,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
ImagePrice4K: imagePrice4K,
ClaudeCodeOnly: input.ClaudeCodeOnly,
FallbackGroupID: input.FallbackGroupID,
+ ModelRouting: input.ModelRouting,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -690,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
}
}
+ // 模型路由配置
+ if input.ModelRouting != nil {
+ group.ModelRouting = input.ModelRouting
+ }
+ if input.ModelRoutingEnabled != nil {
+ group.ModelRoutingEnabled = *input.ModelRoutingEnabled
+ }
+
if err := s.groupRepo.Update(ctx, group); err != nil {
return nil, err
}
@@ -817,6 +851,12 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
} else {
account.AutoPauseOnExpired = true
}
+ if input.RateMultiplier != nil {
+ if *input.RateMultiplier < 0 {
+ return nil, errors.New("rate_multiplier must be >= 0")
+ }
+ account.RateMultiplier = input.RateMultiplier
+ }
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, err
}
@@ -869,6 +909,12 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
if input.Priority != nil {
account.Priority = *input.Priority
}
+ if input.RateMultiplier != nil {
+ if *input.RateMultiplier < 0 {
+ return nil, errors.New("rate_multiplier must be >= 0")
+ }
+ account.RateMultiplier = input.RateMultiplier
+ }
if input.Status != "" {
account.Status = input.Status
}
@@ -942,6 +988,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
}
+ if input.RateMultiplier != nil {
+ if *input.RateMultiplier < 0 {
+ return nil, errors.New("rate_multiplier must be >= 0")
+ }
+ }
+
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -959,6 +1011,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Priority != nil {
repoUpdates.Priority = input.Priority
}
+ if input.RateMultiplier != nil {
+ repoUpdates.RateMultiplier = input.RateMultiplier
+ }
if input.Status != "" {
repoUpdates.Status = &input.Status
}
@@ -1069,6 +1124,7 @@ func (s *adminServiceImpl) ListProxiesWithAccountCount(ctx context.Context, page
if err != nil {
return nil, 0, err
}
+ s.attachProxyLatency(ctx, proxies)
return proxies, result.Total, nil
}
@@ -1077,7 +1133,12 @@ func (s *adminServiceImpl) GetAllProxies(ctx context.Context) ([]Proxy, error) {
}
func (s *adminServiceImpl) GetAllProxiesWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
- return s.proxyRepo.ListActiveWithAccountCount(ctx)
+ proxies, err := s.proxyRepo.ListActiveWithAccountCount(ctx)
+ if err != nil {
+ return nil, err
+ }
+ s.attachProxyLatency(ctx, proxies)
+ return proxies, nil
}
func (s *adminServiceImpl) GetProxy(ctx context.Context, id int64) (*Proxy, error) {
@@ -1097,6 +1158,8 @@ func (s *adminServiceImpl) CreateProxy(ctx context.Context, input *CreateProxyIn
if err := s.proxyRepo.Create(ctx, proxy); err != nil {
return nil, err
}
+ // Probe latency asynchronously so creation isn't blocked by network timeout.
+ go s.probeProxyLatency(context.Background(), proxy)
return proxy, nil
}
@@ -1135,12 +1198,53 @@ func (s *adminServiceImpl) UpdateProxy(ctx context.Context, id int64, input *Upd
}
func (s *adminServiceImpl) DeleteProxy(ctx context.Context, id int64) error {
+ count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
+ if err != nil {
+ return err
+ }
+ if count > 0 {
+ return ErrProxyInUse
+ }
return s.proxyRepo.Delete(ctx, id)
}
-func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64, page, pageSize int) ([]Account, int64, error) {
- // Return mock data for now - would need a dedicated repository method
- return []Account{}, 0, nil
+func (s *adminServiceImpl) BatchDeleteProxies(ctx context.Context, ids []int64) (*ProxyBatchDeleteResult, error) {
+ result := &ProxyBatchDeleteResult{}
+ if len(ids) == 0 {
+ return result, nil
+ }
+
+ for _, id := range ids {
+ count, err := s.proxyRepo.CountAccountsByProxyID(ctx, id)
+ if err != nil {
+ result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+ ID: id,
+ Reason: err.Error(),
+ })
+ continue
+ }
+ if count > 0 {
+ result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+ ID: id,
+ Reason: ErrProxyInUse.Error(),
+ })
+ continue
+ }
+ if err := s.proxyRepo.Delete(ctx, id); err != nil {
+ result.Skipped = append(result.Skipped, ProxyBatchDeleteSkipped{
+ ID: id,
+ Reason: err.Error(),
+ })
+ continue
+ }
+ result.DeletedIDs = append(result.DeletedIDs, id)
+ }
+
+ return result, nil
+}
+
+func (s *adminServiceImpl) GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
+ return s.proxyRepo.ListAccountSummariesByProxyID(ctx, proxyID)
}
func (s *adminServiceImpl) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) {
@@ -1240,23 +1344,69 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
proxyURL := proxy.URL()
exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL)
if err != nil {
+ s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
+ Success: false,
+ Message: err.Error(),
+ UpdatedAt: time.Now(),
+ })
return &ProxyTestResult{
Success: false,
Message: err.Error(),
}, nil
}
+ latency := latencyMs
+ s.saveProxyLatency(ctx, id, &ProxyLatencyInfo{
+ Success: true,
+ LatencyMs: &latency,
+ Message: "Proxy is accessible",
+ IPAddress: exitInfo.IP,
+ Country: exitInfo.Country,
+ CountryCode: exitInfo.CountryCode,
+ Region: exitInfo.Region,
+ City: exitInfo.City,
+ UpdatedAt: time.Now(),
+ })
return &ProxyTestResult{
- Success: true,
- Message: "Proxy is accessible",
- LatencyMs: latencyMs,
- IPAddress: exitInfo.IP,
- City: exitInfo.City,
- Region: exitInfo.Region,
- Country: exitInfo.Country,
+ Success: true,
+ Message: "Proxy is accessible",
+ LatencyMs: latencyMs,
+ IPAddress: exitInfo.IP,
+ City: exitInfo.City,
+ Region: exitInfo.Region,
+ Country: exitInfo.Country,
+ CountryCode: exitInfo.CountryCode,
}, nil
}
+func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) {
+ if s.proxyProber == nil || proxy == nil {
+ return
+ }
+ exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxy.URL())
+ if err != nil {
+ s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
+ Success: false,
+ Message: err.Error(),
+ UpdatedAt: time.Now(),
+ })
+ return
+ }
+
+ latency := latencyMs
+ s.saveProxyLatency(ctx, proxy.ID, &ProxyLatencyInfo{
+ Success: true,
+ LatencyMs: &latency,
+ Message: "Proxy is accessible",
+ IPAddress: exitInfo.IP,
+ Country: exitInfo.Country,
+ CountryCode: exitInfo.CountryCode,
+ Region: exitInfo.Region,
+ City: exitInfo.City,
+ UpdatedAt: time.Now(),
+ })
+}
+
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -1306,6 +1456,51 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
+func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []ProxyWithAccountCount) {
+ if s.proxyLatencyCache == nil || len(proxies) == 0 {
+ return
+ }
+
+ ids := make([]int64, 0, len(proxies))
+ for i := range proxies {
+ ids = append(ids, proxies[i].ID)
+ }
+
+ latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids)
+ if err != nil {
+ log.Printf("Warning: load proxy latency cache failed: %v", err)
+ return
+ }
+
+ for i := range proxies {
+ info := latencies[proxies[i].ID]
+ if info == nil {
+ continue
+ }
+ if info.Success {
+ proxies[i].LatencyStatus = "success"
+ proxies[i].LatencyMs = info.LatencyMs
+ } else {
+ proxies[i].LatencyStatus = "failed"
+ }
+ proxies[i].LatencyMessage = info.Message
+ proxies[i].IPAddress = info.IPAddress
+ proxies[i].Country = info.Country
+ proxies[i].CountryCode = info.CountryCode
+ proxies[i].Region = info.Region
+ proxies[i].City = info.City
+ }
+}
+
+func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) {
+ if s.proxyLatencyCache == nil || info == nil {
+ return
+ }
+ if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil {
+ log.Printf("Warning: store proxy latency cache failed: %v", err)
+ }
+}
+
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index ef621213..662b95fb 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -12,9 +12,9 @@ import (
type accountRepoStubForBulkUpdate struct {
accountRepoStub
- bulkUpdateErr error
- bulkUpdateIDs []int64
- bindGroupErrByID map[int64]error
+ bulkUpdateErr error
+ bulkUpdateIDs []int64
+ bindGroupErrByID map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index 31639472..afa433af 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -153,8 +153,10 @@ func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupI
}
type proxyRepoStub struct {
- deleteErr error
- deletedIDs []int64
+ deleteErr error
+ countErr error
+ accountCount int64
+ deletedIDs []int64
}
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
@@ -199,7 +201,14 @@ func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, p
}
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
- panic("unexpected CountAccountsByProxyID call")
+ if s.countErr != nil {
+ return 0, s.countErr
+ }
+ return s.accountCount, nil
+}
+
+func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
+ panic("unexpected ListAccountSummariesByProxyID call")
}
type redeemRepoStub struct {
@@ -409,6 +418,15 @@ func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
require.Equal(t, []int64{404}, repo.deletedIDs)
}
+func TestAdminService_DeleteProxy_InUse(t *testing.T) {
+ repo := &proxyRepoStub{accountCount: 2}
+ svc := &adminServiceImpl{proxyRepo: repo}
+
+ err := svc.DeleteProxy(context.Background(), 77)
+ require.ErrorIs(t, err, ErrProxyInUse)
+ require.Empty(t, repo.deletedIDs)
+}
+
func TestAdminService_DeleteProxy_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &proxyRepoStub{deleteErr: deleteErr}
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go
index 60567434..7f3e97a2 100644
--- a/backend/internal/service/antigravity_gateway_service.go
+++ b/backend/internal/service/antigravity_gateway_service.go
@@ -564,6 +564,10 @@ urlFallbackLoop:
}
upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
+ // Capture upstream request body for ops retry of this attempt.
+ if c != nil {
+ c.Set(OpsUpstreamRequestBodyKey, string(geminiBody))
+ }
if err != nil {
return nil, err
}
@@ -574,6 +578,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -615,6 +620,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -645,6 +651,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -697,6 +704,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
@@ -740,6 +748,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr.Error()),
@@ -770,6 +779,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: kind,
@@ -817,6 +827,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
@@ -1371,6 +1382,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1412,6 +1424,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -1442,6 +1455,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -1543,6 +1557,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1559,6 +1574,7 @@ urlFallbackLoop:
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
@@ -2039,6 +2055,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
diff --git a/backend/internal/service/antigravity_quota_scope.go b/backend/internal/service/antigravity_quota_scope.go
index e9f7184b..a3b2ec66 100644
--- a/backend/internal/service/antigravity_quota_scope.go
+++ b/backend/internal/service/antigravity_quota_scope.go
@@ -49,6 +49,9 @@ func (a *Account) IsSchedulableForModel(requestedModel string) bool {
if !a.IsSchedulable() {
return false
}
+ if a.isModelRateLimited(requestedModel) {
+ return false
+ }
if a.Platform != PlatformAntigravity {
return true
}
diff --git a/backend/internal/service/antigravity_token_provider.go b/backend/internal/service/antigravity_token_provider.go
index cbd1bef4..c5dc55db 100644
--- a/backend/internal/service/antigravity_token_provider.go
+++ b/backend/internal/service/antigravity_token_provider.go
@@ -45,7 +45,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return "", errors.New("not an antigravity oauth account")
}
- cacheKey := antigravityTokenCacheKey(account)
+ cacheKey := AntigravityTokenCacheKey(account)
// 1. 先尝试缓存
if p.tokenCache != nil {
@@ -121,7 +121,7 @@ func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *
return accessToken, nil
}
-func antigravityTokenCacheKey(account *Account) string {
+func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go
index 7ce9a8a2..5b476dbc 100644
--- a/backend/internal/service/api_key_auth_cache.go
+++ b/backend/internal/service/api_key_auth_cache.go
@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
+
+ // Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
+ // Only anthropic groups use these fields; others may leave them empty.
+ ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
+ ModelRoutingEnabled bool `json:"model_routing_enabled"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go
index dfc55eeb..521f1da5 100644
--- a/backend/internal/service/api_key_auth_cache_impl.go
+++ b/backend/internal/service/api_key_auth_cache_impl.go
@@ -207,20 +207,22 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
- ID: apiKey.Group.ID,
- Name: apiKey.Group.Name,
- Platform: apiKey.Group.Platform,
- Status: apiKey.Group.Status,
- SubscriptionType: apiKey.Group.SubscriptionType,
- RateMultiplier: apiKey.Group.RateMultiplier,
- DailyLimitUSD: apiKey.Group.DailyLimitUSD,
- WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
- MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
- ImagePrice1K: apiKey.Group.ImagePrice1K,
- ImagePrice2K: apiKey.Group.ImagePrice2K,
- ImagePrice4K: apiKey.Group.ImagePrice4K,
- ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
- FallbackGroupID: apiKey.Group.FallbackGroupID,
+ ID: apiKey.Group.ID,
+ Name: apiKey.Group.Name,
+ Platform: apiKey.Group.Platform,
+ Status: apiKey.Group.Status,
+ SubscriptionType: apiKey.Group.SubscriptionType,
+ RateMultiplier: apiKey.Group.RateMultiplier,
+ DailyLimitUSD: apiKey.Group.DailyLimitUSD,
+ WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
+ MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
+ ImagePrice1K: apiKey.Group.ImagePrice1K,
+ ImagePrice2K: apiKey.Group.ImagePrice2K,
+ ImagePrice4K: apiKey.Group.ImagePrice4K,
+ ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
+ FallbackGroupID: apiKey.Group.FallbackGroupID,
+ ModelRouting: apiKey.Group.ModelRouting,
+ ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
}
}
return snapshot
@@ -248,21 +250,23 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
}
if snapshot.Group != nil {
apiKey.Group = &Group{
- ID: snapshot.Group.ID,
- Name: snapshot.Group.Name,
- Platform: snapshot.Group.Platform,
- Status: snapshot.Group.Status,
- Hydrated: true,
- SubscriptionType: snapshot.Group.SubscriptionType,
- RateMultiplier: snapshot.Group.RateMultiplier,
- DailyLimitUSD: snapshot.Group.DailyLimitUSD,
- WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
- MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
- ImagePrice1K: snapshot.Group.ImagePrice1K,
- ImagePrice2K: snapshot.Group.ImagePrice2K,
- ImagePrice4K: snapshot.Group.ImagePrice4K,
- ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
- FallbackGroupID: snapshot.Group.FallbackGroupID,
+ ID: snapshot.Group.ID,
+ Name: snapshot.Group.Name,
+ Platform: snapshot.Group.Platform,
+ Status: snapshot.Group.Status,
+ Hydrated: true,
+ SubscriptionType: snapshot.Group.SubscriptionType,
+ RateMultiplier: snapshot.Group.RateMultiplier,
+ DailyLimitUSD: snapshot.Group.DailyLimitUSD,
+ WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
+ MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
+ ImagePrice1K: snapshot.Group.ImagePrice1K,
+ ImagePrice2K: snapshot.Group.ImagePrice2K,
+ ImagePrice4K: snapshot.Group.ImagePrice4K,
+ ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
+ FallbackGroupID: snapshot.Group.FallbackGroupID,
+ ModelRouting: snapshot.Group.ModelRouting,
+ ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
}
}
return apiKey
diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go
index 3314ca8d..5f2d69c4 100644
--- a/backend/internal/service/api_key_service_cache_test.go
+++ b/backend/internal/service/api_key_service_cache_test.go
@@ -172,12 +172,16 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
- ID: groupID,
- Name: "g",
- Platform: PlatformAnthropic,
- Status: StatusActive,
- SubscriptionType: SubscriptionTypeStandard,
- RateMultiplier: 1,
+ ID: groupID,
+ Name: "g",
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ SubscriptionType: SubscriptionTypeStandard,
+ RateMultiplier: 1,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-opus-*": {1, 2},
+ },
},
},
}
@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID)
+ require.True(t, apiKey.Group.ModelRoutingEnabled)
+ require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go
new file mode 100644
index 00000000..c7c6e42d
--- /dev/null
+++ b/backend/internal/service/claude_token_provider.go
@@ -0,0 +1,208 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "strconv"
+ "strings"
+ "time"
+)
+
+const (
+ claudeTokenRefreshSkew = 3 * time.Minute
+ claudeTokenCacheSkew = 5 * time.Minute
+ claudeLockWaitTime = 200 * time.Millisecond
+)
+
+// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
+type ClaudeTokenCache = GeminiTokenCache
+
+// ClaudeTokenProvider 管理 Claude (Anthropic) OAuth 账户的 access_token
+type ClaudeTokenProvider struct {
+ accountRepo AccountRepository
+ tokenCache ClaudeTokenCache
+ oauthService *OAuthService
+}
+
+func NewClaudeTokenProvider(
+ accountRepo AccountRepository,
+ tokenCache ClaudeTokenCache,
+ oauthService *OAuthService,
+) *ClaudeTokenProvider {
+ return &ClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: tokenCache,
+ oauthService: oauthService,
+ }
+}
+
+// GetAccessToken 获取有效的 access_token
+func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
+ return "", errors.New("not an anthropic oauth account")
+ }
+
+ cacheKey := ClaudeTokenCacheKey(account)
+
+ // 1. 先尝试缓存
+ if p.tokenCache != nil {
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ slog.Debug("claude_token_cache_hit", "account_id", account.ID)
+ return token, nil
+ } else if err != nil {
+ slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
+ }
+ }
+
+ slog.Debug("claude_token_cache_miss", "account_id", account.ID)
+
+ // 2. 如果即将过期则刷新
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
+ refreshFailed := false
+ if needsRefresh && p.tokenCache != nil {
+ locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if lockErr == nil && locked {
+ defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+
+ // 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+
+ // 从数据库获取最新账户信息
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
+ if p.oauthService == nil {
+ slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
+ refreshFailed = true // 无法刷新,标记失败
+ } else {
+ tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
+ slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
+ refreshFailed = true // 刷新失败,标记以使用短 TTL
+ } else {
+ // 构建新 credentials,保留原有字段
+ newCredentials := make(map[string]any)
+ for k, v := range account.Credentials {
+ newCredentials[k] = v
+ }
+ newCredentials["access_token"] = tokenInfo.AccessToken
+ newCredentials["token_type"] = tokenInfo.TokenType
+ newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
+ newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
+ if tokenInfo.RefreshToken != "" {
+ newCredentials["refresh_token"] = tokenInfo.RefreshToken
+ }
+ if tokenInfo.Scope != "" {
+ newCredentials["scope"] = tokenInfo.Scope
+ }
+ account.Credentials = newCredentials
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ }
+ }
+ } else if lockErr != nil {
+ // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
+ slog.Warn("claude_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
+
+ // 检查 ctx 是否已取消
+ if ctx.Err() != nil {
+ return "", ctx.Err()
+ }
+
+ // 从数据库获取最新账户信息
+ if p.accountRepo != nil {
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+
+ // 仅在 expires_at 已过期/接近过期时才执行无锁刷新
+ if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
+ if p.oauthService == nil {
+ slog.Warn("claude_oauth_service_not_configured", "account_id", account.ID)
+ refreshFailed = true
+ } else {
+ tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ slog.Warn("claude_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
+ refreshFailed = true
+ } else {
+ // 构建新 credentials,保留原有字段
+ newCredentials := make(map[string]any)
+ for k, v := range account.Credentials {
+ newCredentials[k] = v
+ }
+ newCredentials["access_token"] = tokenInfo.AccessToken
+ newCredentials["token_type"] = tokenInfo.TokenType
+ newCredentials["expires_in"] = strconv.FormatInt(tokenInfo.ExpiresIn, 10)
+ newCredentials["expires_at"] = strconv.FormatInt(tokenInfo.ExpiresAt, 10)
+ if tokenInfo.RefreshToken != "" {
+ newCredentials["refresh_token"] = tokenInfo.RefreshToken
+ }
+ if tokenInfo.Scope != "" {
+ newCredentials["scope"] = tokenInfo.Scope
+ }
+ account.Credentials = newCredentials
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("claude_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ }
+ }
+ } else {
+ // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
+ time.Sleep(claudeLockWaitTime)
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
+ return token, nil
+ }
+ }
+ }
+
+ accessToken := account.GetCredential("access_token")
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("access_token not found in credentials")
+ }
+
+ // 3. 存入缓存
+ if p.tokenCache != nil {
+ ttl := 30 * time.Minute
+ if refreshFailed {
+ // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
+ ttl = time.Minute
+ slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
+ } else if expiresAt != nil {
+ until := time.Until(*expiresAt)
+ switch {
+ case until > claudeTokenCacheSkew:
+ ttl = until - claudeTokenCacheSkew
+ case until > 0:
+ ttl = until
+ default:
+ ttl = time.Minute
+ }
+ }
+ if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
+ slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
+ }
+ }
+
+ return accessToken, nil
+}
diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go
new file mode 100644
index 00000000..3e21f6f4
--- /dev/null
+++ b/backend/internal/service/claude_token_provider_test.go
@@ -0,0 +1,939 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+// claudeTokenCacheStub implements ClaudeTokenCache for testing
+// The *Err fields force the corresponding method to fail, the *Called
+// counters record invocation counts (read with atomic.LoadInt32), and
+// simulateLockRace makes AcquireRefreshLock always report contention.
+type claudeTokenCacheStub struct {
+ mu sync.Mutex
+ tokens map[string]string
+ getErr error
+ setErr error
+ deleteErr error
+ lockAcquired bool
+ lockErr error
+ releaseLockErr error
+ getCalled int32
+ setCalled int32
+ lockCalled int32
+ unlockCalled int32
+ simulateLockRace bool
+}
+
+// newClaudeTokenCacheStub returns a stub with an empty token map whose
+// lock acquisition succeeds by default.
+func newClaudeTokenCacheStub() *claudeTokenCacheStub {
+ return &claudeTokenCacheStub{
+ tokens: make(map[string]string),
+ lockAcquired: true,
+ }
+}
+
+// GetAccessToken counts the call, then returns getErr if configured,
+// otherwise the cached token for cacheKey ("" when absent).
+func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
+ atomic.AddInt32(&s.getCalled, 1)
+ if s.getErr != nil {
+ return "", s.getErr
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.tokens[cacheKey], nil
+}
+
+// SetAccessToken counts the call, then returns setErr if configured,
+// otherwise stores the token under cacheKey (the ttl is ignored by the stub).
+func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
+ atomic.AddInt32(&s.setCalled, 1)
+ if s.setErr != nil {
+ return s.setErr
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.tokens[cacheKey] = token
+ return nil
+}
+
+// DeleteAccessToken returns deleteErr if configured, otherwise removes the
+// cached entry. Note: unlike the other methods it has no call counter.
+func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
+ if s.deleteErr != nil {
+ return s.deleteErr
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ delete(s.tokens, cacheKey)
+ return nil
+}
+
+// AcquireRefreshLock counts the call; it fails with lockErr when set,
+// reports contention (false, nil) when simulateLockRace is on, and
+// otherwise returns the configured lockAcquired flag.
+func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
+ atomic.AddInt32(&s.lockCalled, 1)
+ if s.lockErr != nil {
+ return false, s.lockErr
+ }
+ if s.simulateLockRace {
+ return false, nil
+ }
+ return s.lockAcquired, nil
+}
+
+// ReleaseRefreshLock counts the call and returns the configured release error (nil by default).
+func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
+ atomic.AddInt32(&s.unlockCalled, 1)
+ return s.releaseLockErr
+}
+
+// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
+// getErr/updateErr force the corresponding method to fail; getCalled/updateCalled
+// record invocation counts.
+type claudeAccountRepoStub struct {
+ account *Account
+ getErr error
+ updateErr error
+ getCalled int32
+ updateCalled int32
+}
+
+// GetByID counts the call and returns the single configured account
+// (the id argument is ignored by the stub), or getErr when set.
+func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
+ atomic.AddInt32(&r.getCalled, 1)
+ if r.getErr != nil {
+ return nil, r.getErr
+ }
+ return r.account, nil
+}
+
+// Update counts the call and replaces the stored account, or returns
+// updateErr when set.
+func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
+ atomic.AddInt32(&r.updateCalled, 1)
+ if r.updateErr != nil {
+ return r.updateErr
+ }
+ r.account = account
+ return nil
+}
+
+// claudeOAuthServiceStub implements OAuthService methods for testing
+// refreshErr forces RefreshAccountToken to fail; refreshCalled counts invocations.
+type claudeOAuthServiceStub struct {
+ tokenInfo *TokenInfo
+ refreshErr error
+ refreshCalled int32
+}
+
+// RefreshAccountToken counts the call and returns the canned tokenInfo,
+// or refreshErr when set.
+func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
+ atomic.AddInt32(&s.refreshCalled, 1)
+ if s.refreshErr != nil {
+ return nil, s.refreshErr
+ }
+ return s.tokenInfo, nil
+}
+
+// testClaudeTokenProvider is a test version that uses the stub OAuth service
+// It wires the three stubs together so the provider flow can be exercised
+// without Redis, a database, or a real OAuth endpoint.
+type testClaudeTokenProvider struct {
+ accountRepo *claudeAccountRepoStub
+ tokenCache *claudeTokenCacheStub
+ oauthService *claudeOAuthServiceStub
+}
+
+// GetAccessToken mirrors the production ClaudeTokenProvider flow against the
+// stubs: cache lookup -> refresh-lock -> double-checked cache read -> DB
+// reload -> OAuth refresh -> credential fallback -> cache write-back.
+// NOTE(review): this re-implementation must be kept in sync with the real
+// provider by hand; behavioral drift between the two would go unnoticed here.
+func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
+ return "", errors.New("not an anthropic oauth account")
+ }
+
+ cacheKey := ClaudeTokenCacheKey(account)
+
+ // 1. Check cache
+ if p.tokenCache != nil {
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
+ return token, nil
+ }
+ }
+
+ // 2. Check if refresh needed
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
+ refreshFailed := false
+ if needsRefresh && p.tokenCache != nil {
+ locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if err == nil && locked {
+ defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+
+ // Check cache again after acquiring lock
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
+ return token, nil
+ }
+
+ // Get fresh account from DB
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
+ if p.oauthService == nil {
+ refreshFailed = true // cannot refresh: no OAuth service configured; mark failed
+ } else {
+ tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ refreshFailed = true // refresh failed; flag so a short TTL is used below
+ } else {
+ // Build new credentials
+ newCredentials := make(map[string]any)
+ for k, v := range account.Credentials {
+ newCredentials[k] = v
+ }
+ newCredentials["access_token"] = tokenInfo.AccessToken
+ newCredentials["token_type"] = tokenInfo.TokenType
+ newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
+ if tokenInfo.RefreshToken != "" {
+ newCredentials["refresh_token"] = tokenInfo.RefreshToken
+ }
+ account.Credentials = newCredentials
+ _ = p.accountRepo.Update(ctx, account)
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ }
+ }
+ } else if p.tokenCache.simulateLockRace {
+ // Wait and retry cache
+ time.Sleep(10 * time.Millisecond)
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
+ return token, nil
+ }
+ }
+ }
+
+ accessToken := account.GetCredential("access_token")
+ if accessToken == "" {
+ return "", errors.New("access_token not found in credentials")
+ }
+
+ // 3. Store in cache
+ if p.tokenCache != nil {
+ ttl := 30 * time.Minute
+ if refreshFailed {
+ ttl = time.Minute // short TTL after a failed refresh to avoid caching a stale token
+ } else if expiresAt != nil {
+ until := time.Until(*expiresAt)
+ if until > claudeTokenCacheSkew {
+ ttl = until - claudeTokenCacheSkew
+ } else if until > 0 {
+ ttl = until
+ } else {
+ ttl = time.Minute
+ }
+ }
+ _ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
+ }
+
+ return accessToken, nil
+}
+
+// TestClaudeTokenProvider_CacheHit verifies a cached token is returned
+// directly: one cache read, no cache write, credentials untouched.
+func TestClaudeTokenProvider_CacheHit(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ account := &Account{
+ ID: 100,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "db-token",
+ },
+ }
+ cacheKey := ClaudeTokenCacheKey(account)
+ cache.tokens[cacheKey] = "cached-token"
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "cached-token", token)
+ require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
+ require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
+}
+
+// TestClaudeTokenProvider_CacheMiss_FromCredentials verifies that on a cache
+// miss with a far-future expiry the credential token is returned and written
+// back to the cache.
+func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ // Token expires in far future, no refresh needed
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 101,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "credential-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "credential-token", token)
+
+ // Should have stored in cache
+ cacheKey := ClaudeTokenCacheKey(account)
+ require.Equal(t, "credential-token", cache.tokens[cacheKey])
+}
+
+// TestClaudeTokenProvider_TokenRefresh verifies that a near-expiry token
+// triggers exactly one OAuth refresh and the refreshed token is returned.
+// Exercises the testClaudeTokenProvider re-implementation, not the real provider.
+func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{}
+ oauthService := &claudeOAuthServiceStub{
+ tokenInfo: &TokenInfo{
+ AccessToken: "refreshed-token",
+ RefreshToken: "new-refresh-token",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ ExpiresAt: time.Now().Add(time.Hour).Unix(),
+ },
+ }
+
+ // Token expires soon (within refresh skew)
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 102,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "refresh_token": "old-refresh-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "refreshed-token", token)
+ require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
+}
+
+// TestClaudeTokenProvider_LockRaceCondition verifies that when the refresh
+// lock is contended the provider waits and can pick up a token cached by the
+// winning worker; any non-empty token is accepted since timing is racy.
+func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.simulateLockRace = true
+ accountRepo := &claudeAccountRepoStub{}
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 103,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "race-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+
+ // Simulate another worker already refreshed and cached
+ cacheKey := ClaudeTokenCacheKey(account)
+ go func() {
+ time.Sleep(5 * time.Millisecond)
+ cache.mu.Lock()
+ cache.tokens[cacheKey] = "winner-token"
+ cache.mu.Unlock()
+ }()
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+}
+
+// TestClaudeTokenProvider_NilAccount verifies a nil account is rejected.
+func TestClaudeTokenProvider_NilAccount(t *testing.T) {
+ provider := NewClaudeTokenProvider(nil, nil, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), nil)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "account is nil")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_WrongPlatform verifies a non-Anthropic account is rejected.
+func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
+ provider := NewClaudeTokenProvider(nil, nil, nil)
+ account := &Account{
+ ID: 104,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_WrongAccountType verifies an API-key account is rejected.
+func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
+ provider := NewClaudeTokenProvider(nil, nil, nil)
+ account := &Account{
+ ID: 105,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_SetupTokenType verifies a setup-token account is rejected.
+func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
+ provider := NewClaudeTokenProvider(nil, nil, nil)
+ account := &Account{
+ ID: 106,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeSetupToken,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "not an anthropic oauth account")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_NilCache verifies the provider works with no cache
+// configured, serving the token straight from credentials.
+func TestClaudeTokenProvider_NilCache(t *testing.T) {
+ // Token doesn't need refresh
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 107,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "nocache-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, nil, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "nocache-token", token)
+}
+
+// TestClaudeTokenProvider_CacheGetError verifies a cache read failure degrades
+// gracefully to the credential token instead of surfacing an error.
+func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.getErr = errors.New("redis connection failed")
+
+ // Token doesn't need refresh
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 108,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "fallback-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ // Should gracefully degrade and return from credentials
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "fallback-token", token)
+}
+
+// TestClaudeTokenProvider_CacheSetError verifies a cache write failure does
+// not prevent the token from being returned.
+func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.setErr = errors.New("redis write failed")
+
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 109,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "still-works-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ // Should still work even if cache set fails
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "still-works-token", token)
+}
+
+// TestClaudeTokenProvider_MissingAccessToken verifies that credentials
+// lacking an access_token produce a clear error.
+func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 110,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "expires_at": expiresAt,
+ // missing access_token
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "access_token not found")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_RefreshError verifies that a failed OAuth refresh
+// falls back to the existing (possibly soon-to-expire) credential token.
+func TestClaudeTokenProvider_RefreshError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{}
+ oauthService := &claudeOAuthServiceStub{
+ refreshErr: errors.New("oauth refresh failed"),
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 111,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "refresh_token": "old-refresh-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ // Now with fallback behavior, should return existing token even if refresh fails
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "old-token", token) // Fallback to existing token
+}
+
+// TestClaudeTokenProvider_OAuthServiceNotConfigured verifies the fallback to
+// the existing credential token when no OAuth service is wired in.
+func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{}
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 112,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: nil, // not configured
+ }
+
+ // Now with fallback behavior, should return existing token even if oauth service not configured
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "old-token", token) // Fallback to existing token
+}
+
+// TestClaudeTokenProvider_TTLCalculation runs the cache write-back across
+// several expiry horizons. It only asserts the token was cached — the stub
+// discards the TTL, so the actual TTL arithmetic is not verified here.
+func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
+ tests := []struct {
+ name string
+ expiresIn time.Duration
+ }{
+ {
+ name: "far_future_expiry",
+ expiresIn: 1 * time.Hour,
+ },
+ {
+ name: "medium_expiry",
+ expiresIn: 10 * time.Minute,
+ },
+ {
+ name: "near_expiry",
+ expiresIn: 6 * time.Minute,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
+ account := &Account{
+ ID: 200,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "test-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+
+ _, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+
+ // Verify token was cached
+ cacheKey := ClaudeTokenCacheKey(account)
+ require.Equal(t, "test-token", cache.tokens[cacheKey])
+ })
+ }
+}
+
+// TestClaudeTokenProvider_AccountRepoGetError verifies that a DB read failure
+// during refresh is tolerated: the passed-in account is used and the refresh
+// still succeeds.
+func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{
+ getErr: errors.New("db connection failed"),
+ }
+ oauthService := &claudeOAuthServiceStub{
+ tokenInfo: &TokenInfo{
+ AccessToken: "refreshed-token",
+ RefreshToken: "new-refresh",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ },
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 113,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "refresh_token": "old-refresh",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ // Should still work, just using the passed-in account
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "refreshed-token", token)
+}
+
+// TestClaudeTokenProvider_AccountUpdateError verifies that a DB write failure
+// after refresh does not block returning the refreshed token.
+func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{
+ updateErr: errors.New("db write failed"),
+ }
+ oauthService := &claudeOAuthServiceStub{
+ tokenInfo: &TokenInfo{
+ AccessToken: "refreshed-token",
+ RefreshToken: "new-refresh",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ },
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 114,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "refresh_token": "old-refresh",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ // Should still return token even if update fails
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "refreshed-token", token)
+}
+
+// TestClaudeTokenProvider_RefreshPreservesExistingCredentials verifies that a
+// refresh merges new token fields into the credential map without dropping
+// unrelated fields (custom_field, organization).
+func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{}
+ oauthService := &claudeOAuthServiceStub{
+ tokenInfo: &TokenInfo{
+ AccessToken: "new-access-token",
+ RefreshToken: "new-refresh-token",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ },
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 115,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-access-token",
+ "refresh_token": "old-refresh-token",
+ "expires_at": expiresAt,
+ "custom_field": "should-be-preserved",
+ "organization": "test-org",
+ },
+ }
+ accountRepo.account = account
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "new-access-token", token)
+
+ // Verify existing fields are preserved
+ require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
+ require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
+ // Verify new fields are updated
+ require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
+ require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
+}
+
+// TestClaudeTokenProvider_DoubleCheckCacheAfterLock exercises the
+// double-checked cache read after lock acquisition while a background
+// goroutine populates the cache; timing-dependent, so only a non-empty
+// result is asserted.
+func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ accountRepo := &claudeAccountRepoStub{}
+ oauthService := &claudeOAuthServiceStub{
+ tokenInfo: &TokenInfo{
+ AccessToken: "refreshed-token",
+ RefreshToken: "new-refresh",
+ TokenType: "Bearer",
+ ExpiresIn: 3600,
+ },
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 116,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+ cacheKey := ClaudeTokenCacheKey(account)
+
+ // After lock is acquired, cache should have the token (simulating another worker)
+ go func() {
+ time.Sleep(5 * time.Millisecond)
+ cache.mu.Lock()
+ cache.tokens[cacheKey] = "cached-by-other-worker"
+ cache.mu.Unlock()
+ }()
+
+ provider := &testClaudeTokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+}
+
+// Tests for real provider - to increase coverage
+
+// TestClaudeTokenProvider_Real_LockFailedWait verifies that when lock
+// acquisition fails the real provider waits, then either serves a token
+// cached by another worker or falls back to credentials.
+func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockAcquired = false // Lock acquisition fails
+
+ // Token expires soon (within refresh skew) to trigger lock attempt
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 300,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "fallback-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ // Set token in cache after lock wait period (simulate other worker refreshing)
+ cacheKey := ClaudeTokenCacheKey(account)
+ go func() {
+ time.Sleep(100 * time.Millisecond)
+ cache.mu.Lock()
+ cache.tokens[cacheKey] = "refreshed-by-other"
+ cache.mu.Unlock()
+ }()
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+}
+
+// TestClaudeTokenProvider_Real_CacheHitAfterWait verifies the post-wait cache
+// re-read path when another worker wins the refresh lock.
+func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockAcquired = false // Lock acquisition fails
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 301,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "original-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ cacheKey := ClaudeTokenCacheKey(account)
+ // Set token in cache immediately after wait starts
+ go func() {
+ time.Sleep(50 * time.Millisecond)
+ cache.mu.Lock()
+ cache.tokens[cacheKey] = "winner-token"
+ cache.mu.Unlock()
+ }()
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.NotEmpty(t, token)
+}
+
+// TestClaudeTokenProvider_Real_NoExpiresAt verifies credentials without an
+// expires_at still yield the stored access token (lock denied, so no refresh).
+func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockAcquired = false // Prevent entering refresh logic
+
+ // Token with nil expires_at (no expiry set)
+ account := &Account{
+ ID: 302,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "no-expiry-token",
+ },
+ }
+
+ // After lock wait, return token from credentials
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "no-expiry-token", token)
+}
+
+// TestClaudeTokenProvider_Real_WhitespaceToken verifies a whitespace-only
+// cached value is treated as a miss and the credential token is served.
+func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cacheKey := "claude:account:303"
+ cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
+
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 303,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "real-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "real-token", token)
+}
+
+// TestClaudeTokenProvider_Real_EmptyCredentialToken verifies a
+// whitespace-only credential access_token is rejected as missing.
+func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 304,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": " ", // Whitespace only
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "access_token not found")
+ require.Empty(t, token)
+}
+
+// TestClaudeTokenProvider_Real_LockError verifies a lock-acquisition error
+// degrades to serving the existing credential token.
+func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockErr = errors.New("redis lock failed")
+
+ // Token expires soon (within refresh skew)
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 305,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "fallback-on-lock-error",
+ "expires_at": expiresAt,
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "fallback-on-lock-error", token)
+}
+
+// TestClaudeTokenProvider_Real_NilCredentials verifies credentials without an
+// access_token entry produce the expected error.
+func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+
+ expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
+ account := &Account{
+ ID: 306,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "expires_at": expiresAt,
+ // No access_token
+ },
+ }
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "access_token not found")
+ require.Empty(t, token)
+}
diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go
index 9bc56c54..a9811919 100644
--- a/backend/internal/service/dashboard_service.go
+++ b/backend/internal/service/dashboard_service.go
@@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D
return stats, nil
}
-func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error) {
- trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID)
+// GetUsageTrendWithFilters returns usage trend data points, now additionally
+// filterable by account, group, model and stream flag (0 / "" / nil
+// presumably mean "no filter" — confirm against the repo implementation).
+func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) {
+ trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
-func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID int64) ([]usagestats.ModelStat, error) {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, 0)
+// GetModelStatsWithFilters returns per-model usage stats; the previously
+// hard-coded 0 account filter is replaced by explicit accountID/groupID/stream
+// parameters passed through to the repo.
+func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) {
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go
index c2dbf7c9..f543ef1a 100644
--- a/backend/internal/service/gateway_multiplatform_test.go
+++ b/backend/internal/service/gateway_multiplatform_test.go
@@ -142,6 +142,9 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
+// SetModelRateLimit is a no-op mock satisfying the extended account repo interface.
+func (m *mockAccountRepoForPlatform) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+ return nil
+}
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -157,6 +160,9 @@ func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
+// ClearModelRateLimits is a no-op mock satisfying the extended account repo interface.
+func (m *mockAccountRepoForPlatform) ClearModelRateLimits(ctx context.Context, id int64) error {
+ return nil
+}
func (m *mockAccountRepoForPlatform) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
@@ -1046,13 +1052,67 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil, // No concurrency service
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
+ // Subtest: model routing must redirect account selection and update the
+ // sticky session binding even when no ConcurrencyService is configured
+ // (legacy selection path).
+ t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
+ groupID := int64(1)
+ sessionHash := "sticky"
+
+ repo := &mockAccountRepoForPlatform{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ {ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
+ },
+ accountsByID: map[int64]*Account{},
+ }
+ for i := range repo.accounts {
+ repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
+ }
+
+ // Session is initially stuck to account 1.
+ cache := &mockGatewayCacheForPlatform{
+ sessionBindings: map[string]int64{sessionHash: 1},
+ }
+
+ // Routing table pins claude-a to account 1 and claude-b to account 2.
+ groupRepo := &mockGroupRepoForGateway{
+ groups: map[int64]*Group{
+ groupID: {
+ ID: groupID,
+ Platform: PlatformAnthropic,
+ Status: StatusActive,
+ Hydrated: true,
+ ModelRoutingEnabled: true,
+ ModelRouting: map[string][]int64{
+ "claude-a": {1},
+ "claude-b": {2},
+ },
+ },
+ },
+ }
+
+ cfg := testConfig()
+ cfg.Gateway.Scheduling.LoadBatchEnabled = true
+
+ svc := &GatewayService{
+ accountRepo: repo,
+ groupRepo: groupRepo,
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: nil, // legacy path
+ }
+
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil, "")
+ require.NoError(t, err)
+ require.NotNil(t, result)
+ require.NotNil(t, result.Account)
+ require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
+ require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
+ })
+
+
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -1077,7 +1137,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1109,7 +1169,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
}
excludedIDs := map[int64]struct{}{1: {}}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1143,7 +1203,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1179,7 +1239,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: NewConcurrencyService(concurrencyCache),
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sticky", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1206,7 +1266,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
@@ -1238,7 +1298,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1271,7 +1331,7 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
concurrencyService: nil,
}
- result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
+ result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil, "")
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
@@ -1341,6 +1401,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
ID: groupID,
Platform: PlatformAnthropic,
Status: StatusActive,
+ Hydrated: true,
}
groupRepo := &mockGroupRepoForGateway{
groups: map[int64]*Group{groupID: group},
@@ -1398,6 +1459,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
ID: fallbackID,
Platform: PlatformAnthropic,
Status: StatusActive,
+ Hydrated: true,
}
ctx = context.WithValue(ctx, ctxkey.Group, group)
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index ff143eee..1f67f07d 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -12,6 +12,7 @@ import (
"io"
"log"
"net/http"
+ "os"
"regexp"
"sort"
"strings"
@@ -40,6 +41,21 @@ const (
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
)
+func (s *GatewayService) debugModelRoutingEnabled() bool {
+ v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
+ return v == "1" || v == "true" || v == "yes" || v == "on"
+}
+
// shortSessionHash truncates sessionHash to at most its first 8 characters for
// compact debug logging. An empty input yields an empty string.
func shortSessionHash(sessionHash string) string {
	const maxLen = 8
	if len(sessionHash) > maxLen {
		return sessionHash[:maxLen]
	}
	return sessionHash
}
+
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var (
@@ -196,6 +212,8 @@ type GatewayService struct {
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
+ claudeTokenProvider *ClaudeTokenProvider
+ sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
}
// NewGatewayService creates a new GatewayService
@@ -215,6 +233,8 @@ func NewGatewayService(
identityService *IdentityService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
+ claudeTokenProvider *ClaudeTokenProvider,
+ sessionLimitCache SessionLimitCache,
) *GatewayService {
return &GatewayService{
accountRepo: accountRepo,
@@ -232,6 +252,8 @@ func NewGatewayService(
identityService: identityService,
httpUpstream: httpUpstream,
deferredService: deferredService,
+ claudeTokenProvider: claudeTokenProvider,
+ sessionLimitCache: sessionLimitCache,
}
}
@@ -797,8 +819,12 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
}
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
-func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
+// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制)
+func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) {
cfg := s.schedulingConfig()
+ // 提取会话 UUID(用于会话数量限制)
+ sessionUUID := extractSessionUUID(metadataUserID)
+
var stickyAccountID int64
if sessionHash != "" && s.cache != nil {
if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil {
@@ -813,6 +839,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
}
ctx = s.withGroupContext(ctx, group)
+ if s.debugModelRoutingEnabled() && requestedModel != "" {
+ groupPlatform := ""
+ if group != nil {
+ groupPlatform = group.Platform
+ }
+ log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
+ derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
+ }
+
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
if err != nil {
@@ -856,6 +891,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, err
}
preferOAuth := platform == PlatformGemini
+ if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
+ log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
+ }
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
if err != nil {
@@ -873,28 +911,242 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return excluded
}
- // ============ Layer 1: 粘性会话优先 ============
- if sessionHash != "" && s.cache != nil {
+ // 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
+ accountByID := make(map[int64]*Account, len(accounts))
+ for i := range accounts {
+ accountByID[accounts[i].ID] = &accounts[i]
+ }
+
+ // 获取模型路由配置(仅 anthropic 平台)
+ var routingAccountIDs []int64
+ if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
+ routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
+ group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
+ if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
+ keys := make([]string, 0, len(group.ModelRouting))
+ for k := range group.ModelRouting {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ const maxKeys = 20
+ if len(keys) > maxKeys {
+ keys = keys[:maxKeys]
+ }
+ log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
+ }
+ }
+ }
+
+ // ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
+ if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
+ // 1. 过滤出路由列表中可调度的账号
+ var routingCandidates []*Account
+ var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost int
+ for _, routingAccountID := range routingAccountIDs {
+ if isExcluded(routingAccountID) {
+ filteredExcluded++
+ continue
+ }
+ account, ok := accountByID[routingAccountID]
+ if !ok || !account.IsSchedulable() {
+ if !ok {
+ filteredMissing++
+ } else {
+ filteredUnsched++
+ }
+ continue
+ }
+ if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
+ filteredPlatform++
+ continue
+ }
+ if !account.IsSchedulableForModel(requestedModel) {
+ filteredModelScope++
+ continue
+ }
+ if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
+ filteredModelMapping++
+ continue
+ }
+ // 窗口费用检查(非粘性会话路径)
+ if !s.isAccountSchedulableForWindowCost(ctx, account, false) {
+ filteredWindowCost++
+ continue
+ }
+ routingCandidates = append(routingCandidates, account)
+ }
+
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)",
+ derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
+ filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost)
+ }
+
+ if len(routingCandidates) > 0 {
+ // 1.5. 在路由账号范围内检查粘性会话
+ if sessionHash != "" && s.cache != nil {
+ stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
+ // 粘性账号在路由列表中,优先使用
+ if stickyAccount, ok := accountByID[stickyAccountID]; ok {
+ if stickyAccount.IsSchedulable() &&
+ s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
+ stickyAccount.IsSchedulableForModel(requestedModel) &&
+ (requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) &&
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
+ result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
+ if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位
+ // 继续到负载感知选择
+ } else {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
+ }
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
+ }
+
+ waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
+ if waitingCount < cfg.StickySessionMaxWaiting {
+ return &AccountSelectionResult{
+ Account: stickyAccount,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: stickyAccountID,
+ MaxConcurrency: stickyAccount.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
+ // 粘性账号槽位满且等待队列已满,继续使用负载感知选择
+ }
+ }
+ }
+ }
+
+ // 2. 批量获取负载信息
+ routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
+ for _, acc := range routingCandidates {
+ routingLoads = append(routingLoads, AccountWithConcurrency{
+ ID: acc.ID,
+ MaxConcurrency: acc.Concurrency,
+ })
+ }
+ routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
+
+ // 3. 按负载感知排序
+ type accountWithLoad struct {
+ account *Account
+ loadInfo *AccountLoadInfo
+ }
+ var routingAvailable []accountWithLoad
+ for _, acc := range routingCandidates {
+ loadInfo := routingLoadMap[acc.ID]
+ if loadInfo == nil {
+ loadInfo = &AccountLoadInfo{AccountID: acc.ID}
+ }
+ if loadInfo.LoadRate < 100 {
+ routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
+ }
+ }
+
+ if len(routingAvailable) > 0 {
+ // 排序:优先级 > 负载率 > 最后使用时间
+ sort.SliceStable(routingAvailable, func(i, j int) bool {
+ a, b := routingAvailable[i], routingAvailable[j]
+ if a.account.Priority != b.account.Priority {
+ return a.account.Priority < b.account.Priority
+ }
+ if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
+ return a.loadInfo.LoadRate < b.loadInfo.LoadRate
+ }
+ switch {
+ case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
+ return true
+ case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
+ return false
+ case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
+ return false
+ default:
+ return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
+ }
+ })
+
+ // 4. 尝试获取槽位
+ for _, item := range routingAvailable {
+ result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
+ if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
+ if sessionHash != "" && s.cache != nil {
+ _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
+ }
+ return &AccountSelectionResult{
+ Account: item.account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
+ }
+
+ // 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
+ acc := routingAvailable[0].account
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
+ }
+ return &AccountSelectionResult{
+ Account: acc,
+ WaitPlan: &AccountWaitPlan{
+ AccountID: acc.ID,
+ MaxConcurrency: acc.Concurrency,
+ Timeout: cfg.StickySessionWaitTimeout,
+ MaxWaiting: cfg.StickySessionMaxWaiting,
+ },
+ }, nil
+ }
+ // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
+ log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
+ }
+ }
+
+ // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
+ if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
if err == nil && accountID > 0 && !isExcluded(accountID) {
- // 粘性命中仅在当前可调度候选集中生效。
- accountByID := make(map[int64]*Account, len(accounts))
- for i := range accounts {
- accountByID[accounts[i].ID] = &accounts[i]
- }
account, ok := accountByID[accountID]
if ok && s.isAccountInGroup(account, groupID) &&
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
account.IsSchedulableForModel(requestedModel) &&
- (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) &&
+ s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
- _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
- return &AccountSelectionResult{
- Account: account,
- Acquired: true,
- ReleaseFunc: result.ReleaseFunc,
- }, nil
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续到 Layer 2
+ } else {
+ _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
+ return &AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: result.ReleaseFunc,
+ }, nil
+ }
}
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
@@ -935,6 +1187,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
+ // 窗口费用检查(非粘性会话路径)
+ if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
+ continue
+ }
candidates = append(candidates, acc)
}
@@ -952,7 +1208,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
if err != nil {
- if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
+ if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok {
return result, nil
}
} else {
@@ -1001,6 +1257,11 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
}
@@ -1030,13 +1291,18 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
-func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
+func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) {
ordered := append([]*Account(nil), candidates...)
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
+ // 会话数量限制检查
+ if !s.checkAndRegisterSession(ctx, acc, sessionUUID) {
+ result.ReleaseFunc() // 释放槽位,继续尝试下一个账号
+ continue
+ }
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
}
@@ -1093,6 +1359,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
return group, nil
}
+func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
+ if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
+ return nil
+ }
+ group, err := s.resolveGroupByID(ctx, *groupID)
+ if err != nil || group == nil {
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
+ }
+ return nil
+ }
+ // Preserve existing behavior: model routing only applies to anthropic groups.
+ if group.Platform != PlatformAnthropic {
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
+ }
+ return nil
+ }
+ ids := group.GetRoutingAccountIDs(requestedModel)
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
+ group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
+ }
+ return ids
+}
+
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
if groupID == nil {
return nil, nil, nil
@@ -1242,6 +1534,107 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in
return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
}
+// isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度
+// 仅适用于 Anthropic OAuth/SetupToken 账号
+// 返回 true 表示可调度,false 表示不可调度
+func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, account *Account, isSticky bool) bool {
+ // 只检查 Anthropic OAuth/SetupToken 账号
+ if !account.IsAnthropicOAuthOrSetupToken() {
+ return true
+ }
+
+ limit := account.GetWindowCostLimit()
+ if limit <= 0 {
+ return true // 未启用窗口费用限制
+ }
+
+ // 尝试从缓存获取窗口费用
+ var currentCost float64
+ if s.sessionLimitCache != nil {
+ if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit {
+ currentCost = cost
+ goto checkSchedulability
+ }
+ }
+
+ // 缓存未命中,从数据库查询
+ {
+ var startTime time.Time
+ if account.SessionWindowStart != nil {
+ startTime = *account.SessionWindowStart
+ } else {
+ startTime = time.Now().Add(-5 * time.Hour)
+ }
+
+ stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
+ if err != nil {
+ // 失败开放:查询失败时允许调度
+ return true
+ }
+
+ // 使用标准费用(不含账号倍率)
+ currentCost = stats.StandardCost
+
+ // 设置缓存(忽略错误)
+ if s.sessionLimitCache != nil {
+ _ = s.sessionLimitCache.SetWindowCost(ctx, account.ID, currentCost)
+ }
+ }
+
+checkSchedulability:
+ schedulability := account.CheckWindowCostSchedulability(currentCost)
+
+ switch schedulability {
+ case WindowCostSchedulable:
+ return true
+ case WindowCostStickyOnly:
+ return isSticky
+ case WindowCostNotSchedulable:
+ return false
+ }
+ return true
+}
+
+// checkAndRegisterSession 检查并注册会话,用于会话数量限制
+// 仅适用于 Anthropic OAuth/SetupToken 账号
+// 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
+func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool {
+ // 只检查 Anthropic OAuth/SetupToken 账号
+ if !account.IsAnthropicOAuthOrSetupToken() {
+ return true
+ }
+
+ maxSessions := account.GetMaxSessions()
+ if maxSessions <= 0 || sessionUUID == "" {
+ return true // 未启用会话限制或无会话ID
+ }
+
+ if s.sessionLimitCache == nil {
+ return true // 缓存不可用时允许通过
+ }
+
+ idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute
+
+ allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout)
+ if err != nil {
+ // 失败开放:缓存错误时允许通过
+ return true
+ }
+ return allowed
+}
+
+// extractSessionUUID 从 metadata.user_id 中提取会话 UUID
+// 格式: user_{64位hex}_account__session_{uuid}
+func extractSessionUUID(metadataUserID string) string {
+ if metadataUserID == "" {
+ return ""
+ }
+ if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 {
+ return match[1]
+ }
+ return ""
+}
+
func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) {
if s.schedulerSnapshot != nil {
return s.schedulerSnapshot.GetAccount(ctx, accountID)
@@ -1274,6 +1667,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
+ routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
+
+ var accounts []Account
+ accountsLoaded := false
+
+ // ============ Model Routing (legacy path): apply before sticky session ============
+ // When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
+ // so switching model can switch upstream account within the same sticky session.
+ if len(routingAccountIDs) > 0 {
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
+ derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
+ }
+ // 1) Sticky session only applies if the bound account is within the routing set.
+ if sessionHash != "" && s.cache != nil {
+ accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
+ if _, excluded := excludedIDs[accountID]; !excluded {
+ account, err := s.getSchedulableAccount(ctx, accountID)
+ // 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
+ if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+ }
+ return account, nil
+ }
+ }
+ }
+ }
+
+ // 2) Select an account from the routed candidates.
+ forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
+ if hasForcePlatform && forcePlatform == "" {
+ hasForcePlatform = false
+ }
+ var err error
+ accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+ if err != nil {
+ return nil, fmt.Errorf("query accounts failed: %w", err)
+ }
+ accountsLoaded = true
+
+ routingSet := make(map[int64]struct{}, len(routingAccountIDs))
+ for _, id := range routingAccountIDs {
+ if id > 0 {
+ routingSet[id] = struct{}{}
+ }
+ }
+
+ var selected *Account
+ for i := range accounts {
+ acc := &accounts[i]
+ if _, ok := routingSet[acc.ID]; !ok {
+ continue
+ }
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ continue
+ }
+ // Scheduler snapshots can be temporarily stale; re-check schedulability here to
+ // avoid selecting accounts that were recently rate-limited/overloaded.
+ if !acc.IsSchedulable() {
+ continue
+ }
+ if !acc.IsSchedulableForModel(requestedModel) {
+ continue
+ }
+ if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
+ continue
+ }
+ if selected == nil {
+ selected = acc
+ continue
+ }
+ if acc.Priority < selected.Priority {
+ selected = acc
+ } else if acc.Priority == selected.Priority {
+ switch {
+ case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
+ selected = acc
+ case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
+ // keep selected (never used is preferred)
+ case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
+ if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
+ selected = acc
+ }
+ default:
+ if acc.LastUsedAt.Before(*selected.LastUsedAt) {
+ selected = acc
+ }
+ }
+ }
+ }
+
+ if selected != nil {
+ if sessionHash != "" && s.cache != nil {
+ if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
+ log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
+ }
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
+ }
+ return selected, nil
+ }
+ log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
+ }
+
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
@@ -1292,13 +1795,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 2. 获取可调度账号列表(单平台)
- forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
- if hasForcePlatform && forcePlatform == "" {
- hasForcePlatform = false
- }
- accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
- if err != nil {
- return nil, fmt.Errorf("query accounts failed: %w", err)
+ if !accountsLoaded {
+ forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
+ if hasForcePlatform && forcePlatform == "" {
+ hasForcePlatform = false
+ }
+ var err error
+ accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
+ if err != nil {
+ return nil, fmt.Errorf("query accounts failed: %w", err)
+ }
}
// 3. 按优先级+最久未用选择(考虑模型支持)
@@ -1364,6 +1870,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
preferOAuth := nativePlatform == PlatformGemini
+ routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
+
+ var accounts []Account
+ accountsLoaded := false
+
+ // ============ Model Routing (legacy path): apply before sticky session ============
+ if len(routingAccountIDs) > 0 {
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
+ derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
+ }
+ // 1) Sticky session only applies if the bound account is within the routing set.
+ if sessionHash != "" && s.cache != nil {
+ accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
+ if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
+ if _, excluded := excludedIDs[accountID]; !excluded {
+ account, err := s.getSchedulableAccount(ctx, accountID)
+ // 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
+ if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
+ if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
+ if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
+ log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
+ }
+ return account, nil
+ }
+ }
+ }
+ }
+ }
+
+ // 2) Select an account from the routed candidates.
+ var err error
+ accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
+ if err != nil {
+ return nil, fmt.Errorf("query accounts failed: %w", err)
+ }
+ accountsLoaded = true
+
+ routingSet := make(map[int64]struct{}, len(routingAccountIDs))
+ for _, id := range routingAccountIDs {
+ if id > 0 {
+ routingSet[id] = struct{}{}
+ }
+ }
+
+ var selected *Account
+ for i := range accounts {
+ acc := &accounts[i]
+ if _, ok := routingSet[acc.ID]; !ok {
+ continue
+ }
+ if _, excluded := excludedIDs[acc.ID]; excluded {
+ continue
+ }
+ // Scheduler snapshots can be temporarily stale; re-check schedulability here to
+ // avoid selecting accounts that were recently rate-limited/overloaded.
+ if !acc.IsSchedulable() {
+ continue
+ }
+ // 过滤:原生平台直接通过,antigravity 需要启用混合调度
+ if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
+ continue
+ }
+ if !acc.IsSchedulableForModel(requestedModel) {
+ continue
+ }
+ if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
+ continue
+ }
+ if selected == nil {
+ selected = acc
+ continue
+ }
+ if acc.Priority < selected.Priority {
+ selected = acc
+ } else if acc.Priority == selected.Priority {
+ switch {
+ case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
+ selected = acc
+ case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
+ // keep selected (never used is preferred)
+ case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
+ if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
+ selected = acc
+ }
+ default:
+ if acc.LastUsedAt.Before(*selected.LastUsedAt) {
+ selected = acc
+ }
+ }
+ }
+ }
+
+ if selected != nil {
+ if sessionHash != "" && s.cache != nil {
+ if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
+ log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
+ }
+ }
+ if s.debugModelRoutingEnabled() {
+ log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
+ }
+ return selected, nil
+ }
+ log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
+ }
// 1. 查询粘性会话
if sessionHash != "" && s.cache != nil {
@@ -1385,9 +2000,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 2. 获取可调度账号列表
- accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
- if err != nil {
- return nil, fmt.Errorf("query accounts failed: %w", err)
+ if !accountsLoaded {
+ var err error
+ accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
+ if err != nil {
+ return nil, fmt.Errorf("query accounts failed: %w", err)
+ }
}
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
@@ -1488,6 +2106,16 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
}
func (s *GatewayService) getOAuthToken(ctx context.Context, account *Account) (string, string, error) {
+ // 对于 Anthropic OAuth 账号,使用 ClaudeTokenProvider 获取缓存的 token
+ if account.Platform == PlatformAnthropic && account.Type == AccountTypeOAuth && s.claudeTokenProvider != nil {
+ accessToken, err := s.claudeTokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return "", "", err
+ }
+ return accessToken, "oauth", nil
+ }
+
+ // 其他情况(Gemini 有自己的 TokenProvider,setup-token 类型等)直接从账号读取
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
@@ -1901,6 +2529,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
retryStart := time.Now()
for attempt := 1; attempt <= maxRetryAttempts; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
+ // Capture upstream request body for ops retry of this attempt. NOTE(review): the Gemini/OpenAI services nil-check c before Set — confirm this Forward is never invoked with a nil gin.Context, or add the same guard.
+ c.Set(OpsUpstreamRequestBodyKey, string(body))
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode)
if err != nil {
return nil, err
@@ -1918,6 +2548,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1942,6 +2573,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "signature_error",
@@ -1993,6 +2625,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: retryResp.StatusCode,
UpstreamRequestID: retryResp.Header.Get("x-request-id"),
Kind: "signature_retry_thinking",
@@ -2021,6 +2654,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "signature_retry_tools_request_error",
Message: sanitizeUpstreamErrorMessage(retryErr2.Error()),
@@ -2079,6 +2713,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry",
@@ -2127,6 +2762,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "retry_exhausted_failover",
@@ -2193,6 +2829,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover_on_400",
@@ -3283,30 +3920,32 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
if result.ImageSize != "" {
imageSize = &result.ImageSize
}
+ accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
- UserID: user.ID,
- APIKeyID: apiKey.ID,
- AccountID: account.ID,
- RequestID: result.RequestID,
- Model: result.Model,
- InputTokens: result.Usage.InputTokens,
- OutputTokens: result.Usage.OutputTokens,
- CacheCreationTokens: result.Usage.CacheCreationInputTokens,
- CacheReadTokens: result.Usage.CacheReadInputTokens,
- InputCost: cost.InputCost,
- OutputCost: cost.OutputCost,
- CacheCreationCost: cost.CacheCreationCost,
- CacheReadCost: cost.CacheReadCost,
- TotalCost: cost.TotalCost,
- ActualCost: cost.ActualCost,
- RateMultiplier: multiplier,
- BillingType: billingType,
- Stream: result.Stream,
- DurationMs: &durationMs,
- FirstTokenMs: result.FirstTokenMs,
- ImageCount: result.ImageCount,
- ImageSize: imageSize,
- CreatedAt: time.Now(),
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: result.RequestID,
+ Model: result.Model,
+ InputTokens: result.Usage.InputTokens,
+ OutputTokens: result.Usage.OutputTokens,
+ CacheCreationTokens: result.Usage.CacheCreationInputTokens,
+ CacheReadTokens: result.Usage.CacheReadInputTokens,
+ InputCost: cost.InputCost,
+ OutputCost: cost.OutputCost,
+ CacheCreationCost: cost.CacheCreationCost,
+ CacheReadCost: cost.CacheReadCost,
+ TotalCost: cost.TotalCost,
+ ActualCost: cost.ActualCost,
+ RateMultiplier: multiplier,
+ AccountRateMultiplier: &accountRateMultiplier,
+ BillingType: billingType,
+ Stream: result.Stream,
+ DurationMs: &durationMs,
+ FirstTokenMs: result.FirstTokenMs,
+ ImageCount: result.ImageCount,
+ ImageSize: imageSize,
+ CreatedAt: time.Now(),
}
// 添加 UserAgent
diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go
index 190e6afc..75de90f2 100644
--- a/backend/internal/service/gemini_messages_compat_service.go
+++ b/backend/internal/service/gemini_messages_compat_service.go
@@ -545,12 +545,19 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
requestIDHeader = idHeader
+ // Capture upstream request body for ops retry of this attempt.
+ if c != nil {
+ // In this code path `body` is already the JSON sent to upstream.
+ c.Set(OpsUpstreamRequestBodyKey, string(body))
+ }
+
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -588,6 +595,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "signature_error",
@@ -662,6 +670,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
@@ -711,6 +720,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
@@ -737,6 +747,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "failover",
@@ -972,12 +983,19 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
requestIDHeader = idHeader
+ // Capture upstream request body for ops retry of this attempt.
+ if c != nil {
+ // In this code path `body` is already the JSON sent to upstream.
+ c.Set(OpsUpstreamRequestBodyKey, string(body))
+ }
+
resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -1036,6 +1054,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: upstreamReqID,
Kind: "retry",
@@ -1120,6 +1139,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1143,6 +1163,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "failover",
@@ -1168,6 +1189,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: requestID,
Kind: "http_error",
@@ -1300,6 +1322,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: upstreamStatus,
UpstreamRequestID: upstreamRequestID,
Kind: "http_error",
diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go
index c99cb87d..03f5d757 100644
--- a/backend/internal/service/gemini_multiplatform_test.go
+++ b/backend/internal/service/gemini_multiplatform_test.go
@@ -125,6 +125,9 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error {
return nil
}
+func (m *mockAccountRepoForGemini) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
+ return nil
+}
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
@@ -138,6 +141,9 @@ func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64)
func (m *mockAccountRepoForGemini) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
return nil
}
+func (m *mockAccountRepoForGemini) ClearModelRateLimits(ctx context.Context, id int64) error {
+ return nil
+}
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
}
diff --git a/backend/internal/service/gemini_token_cache.go b/backend/internal/service/gemini_token_cache.go
index d5e64f9a..70f246da 100644
--- a/backend/internal/service/gemini_token_cache.go
+++ b/backend/internal/service/gemini_token_cache.go
@@ -10,6 +10,7 @@ type GeminiTokenCache interface {
// cacheKey should be stable for the token scope; for GeminiCli OAuth we primarily use project_id.
GetAccessToken(ctx context.Context, cacheKey string) (string, error)
SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error
+ DeleteAccessToken(ctx context.Context, cacheKey string) error
AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error)
ReleaseRefreshLock(ctx context.Context, cacheKey string) error
diff --git a/backend/internal/service/gemini_token_provider.go b/backend/internal/service/gemini_token_provider.go
index 0257d19f..f13ae169 100644
--- a/backend/internal/service/gemini_token_provider.go
+++ b/backend/internal/service/gemini_token_provider.go
@@ -40,7 +40,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return "", errors.New("not a gemini oauth account")
}
- cacheKey := geminiTokenCacheKey(account)
+ cacheKey := GeminiTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
@@ -151,10 +151,10 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
-func geminiTokenCacheKey(account *Account) string {
+func GeminiTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
- return projectID
+ return "gemini:" + projectID
}
- return "account:" + strconv.FormatInt(account.ID, 10)
+ return "gemini:account:" + strconv.FormatInt(account.ID, 10)
}
diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go
index 8e8d47d6..d6d1269b 100644
--- a/backend/internal/service/group.go
+++ b/backend/internal/service/group.go
@@ -1,6 +1,9 @@
package service
-import "time"
+import (
+ "strings"
+ "time"
+)
type Group struct {
ID int64
@@ -27,6 +30,12 @@ type Group struct {
ClaudeCodeOnly bool
FallbackGroupID *int64
+ // Model-routing configuration.
+ // key: model match pattern (supports a trailing * wildcard, e.g. "claude-opus-*")
+ // value: list of preferred account IDs
+ ModelRouting map[string][]int64
+ ModelRoutingEnabled bool
+
CreatedAt time.Time
UpdatedAt time.Time
@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
}
return true
}
+
+// GetRoutingAccountIDs returns the preferred account IDs routed for the requested
+// model: exact pattern match first, then wildcard patterns; nil when routing is disabled, the table is empty, or no rule matches.
+func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
+ if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
+ return nil
+ }
+
+ // 1. 精确匹配优先
+ if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
+ return accountIDs
+ }
+
+ // 2. Wildcard (prefix) match. NOTE(review): ranging over a map is order-randomized, so when several patterns match the chosen rule is nondeterministic — consider a deterministic tie-break such as longest-prefix-wins.
+ for pattern, accountIDs := range g.ModelRouting {
+ if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
+ return accountIDs
+ }
+ }
+
+ return nil
+}
+
+// matchModelPattern reports whether model matches pattern.
+// Only a trailing * wildcard is supported, e.g. "claude-opus-*" matches "claude-opus-4-20250514".
+func matchModelPattern(pattern, model string) bool {
+ if pattern == model {
+ return true
+ }
+
+ // 处理 * 通配符(仅支持末尾通配符)
+ if strings.HasSuffix(pattern, "*") {
+ prefix := strings.TrimSuffix(pattern, "*")
+ return strings.HasPrefix(model, prefix)
+ }
+
+ return false
+}
diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go
new file mode 100644
index 00000000..49354a7f
--- /dev/null
+++ b/backend/internal/service/model_rate_limit.go
@@ -0,0 +1,56 @@
+package service
+
+import (
+ "strings"
+ "time"
+)
+
+const modelRateLimitsKey = "model_rate_limits"
+const modelRateLimitScopeClaudeSonnet = "claude_sonnet"
+
+func resolveModelRateLimitScope(requestedModel string) (string, bool) {
+ model := strings.ToLower(strings.TrimSpace(requestedModel))
+ if model == "" {
+ return "", false
+ }
+ model = strings.TrimPrefix(model, "models/")
+ if strings.Contains(model, "sonnet") {
+ return modelRateLimitScopeClaudeSonnet, true
+ }
+ return "", false
+}
+
+func (a *Account) isModelRateLimited(requestedModel string) bool {
+ scope, ok := resolveModelRateLimitScope(requestedModel)
+ if !ok {
+ return false
+ }
+ resetAt := a.modelRateLimitResetAt(scope)
+ if resetAt == nil {
+ return false
+ }
+ return time.Now().Before(*resetAt)
+}
+
+func (a *Account) modelRateLimitResetAt(scope string) *time.Time {
+ if a == nil || a.Extra == nil || scope == "" {
+ return nil
+ }
+ rawLimits, ok := a.Extra[modelRateLimitsKey].(map[string]any)
+ if !ok {
+ return nil
+ }
+ rawLimit, ok := rawLimits[scope].(map[string]any)
+ if !ok {
+ return nil
+ }
+ resetAtRaw, ok := rawLimit["rate_limit_reset_at"].(string)
+ if !ok || strings.TrimSpace(resetAtRaw) == "" {
+ return nil
+ }
+ resetAt, err := time.Parse(time.RFC3339, resetAtRaw)
+ if err != nil {
+ return nil
+ }
+ return &resetAt
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index fb811e9e..87ad37a6 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -93,6 +93,8 @@ type OpenAIGatewayService struct {
billingCacheService *BillingCacheService
httpUpstream HTTPUpstream
deferredService *DeferredService
+ openAITokenProvider *OpenAITokenProvider
+ toolCorrector *CodexToolCorrector
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -110,6 +112,7 @@ func NewOpenAIGatewayService(
billingCacheService *BillingCacheService,
httpUpstream HTTPUpstream,
deferredService *DeferredService,
+ openAITokenProvider *OpenAITokenProvider,
) *OpenAIGatewayService {
return &OpenAIGatewayService{
accountRepo: accountRepo,
@@ -125,6 +128,8 @@ func NewOpenAIGatewayService(
billingCacheService: billingCacheService,
httpUpstream: httpUpstream,
deferredService: deferredService,
+ openAITokenProvider: openAITokenProvider,
+ toolCorrector: NewCodexToolCorrector(),
}
}
@@ -503,6 +508,15 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {
case AccountTypeOAuth:
+ // 使用 TokenProvider 获取缓存的 token
+ if s.openAITokenProvider != nil {
+ accessToken, err := s.openAITokenProvider.GetAccessToken(ctx, account)
+ if err != nil {
+ return "", "", err
+ }
+ return accessToken, "oauth", nil
+ }
+ // 降级:TokenProvider 未配置时直接从账号读取
accessToken := account.GetOpenAIAccessToken()
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
@@ -664,6 +678,11 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
proxyURL = account.Proxy.URL()
}
+ // Capture upstream request body for ops retry of this attempt.
+ if c != nil {
+ c.Set(OpsUpstreamRequestBodyKey, string(body))
+ }
+
// Send request
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
if err != nil {
@@ -673,6 +692,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: 0,
Kind: "request_error",
Message: safeErr,
@@ -707,6 +727,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "failover",
@@ -864,6 +885,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: "http_error",
@@ -894,6 +916,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Platform: account.Platform,
AccountID: account.ID,
+ AccountName: account.Name,
UpstreamStatusCode: resp.StatusCode,
UpstreamRequestID: resp.Header.Get("x-request-id"),
Kind: kind,
@@ -1097,6 +1120,12 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
+ // Correct Codex tool calls (apply_patch -> edit, etc.). correctedData derives from the pre-replacement payload `data`, so re-run the model-name replacement on the rebuilt line (no-op when mappedModel == originalModel) instead of clobbering it.
+ if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected {
+ data = correctedData
+ line = s.replaceModelInSSELine("data: "+correctedData, mappedModel, originalModel)
+ }
+
// 写入客户端(客户端断开后继续 drain 上游)
if !clientDisconnected {
if _, err := fmt.Fprintf(w, "%s\n", line); err != nil {
@@ -1199,6 +1228,20 @@ func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel st
return line
}
+// correctToolCallsInResponseBody rewrites tool-call names in a non-streaming response body (e.g. apply_patch -> edit); it returns the input unchanged when no correction applies.
+func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byte {
+ if len(body) == 0 {
+ return body
+ }
+
+ bodyStr := string(body)
+ corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr)
+ if changed {
+ return []byte(corrected)
+ }
+ return body
+}
+
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
@@ -1302,6 +1345,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
if originalModel != mappedModel {
body = s.replaceModelInResponseBody(body, mappedModel, originalModel)
}
+ // Correct tool calls in final response
+ body = s.correctToolCallsInResponseBody(body)
} else {
usage = s.parseSSEUsageFromBody(bodyText)
if originalModel != mappedModel {
@@ -1470,28 +1515,30 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
+ accountRateMultiplier := account.BillingRateMultiplier()
usageLog := &UsageLog{
- UserID: user.ID,
- APIKeyID: apiKey.ID,
- AccountID: account.ID,
- RequestID: result.RequestID,
- Model: result.Model,
- InputTokens: actualInputTokens,
- OutputTokens: result.Usage.OutputTokens,
- CacheCreationTokens: result.Usage.CacheCreationInputTokens,
- CacheReadTokens: result.Usage.CacheReadInputTokens,
- InputCost: cost.InputCost,
- OutputCost: cost.OutputCost,
- CacheCreationCost: cost.CacheCreationCost,
- CacheReadCost: cost.CacheReadCost,
- TotalCost: cost.TotalCost,
- ActualCost: cost.ActualCost,
- RateMultiplier: multiplier,
- BillingType: billingType,
- Stream: result.Stream,
- DurationMs: &durationMs,
- FirstTokenMs: result.FirstTokenMs,
- CreatedAt: time.Now(),
+ UserID: user.ID,
+ APIKeyID: apiKey.ID,
+ AccountID: account.ID,
+ RequestID: result.RequestID,
+ Model: result.Model,
+ InputTokens: actualInputTokens,
+ OutputTokens: result.Usage.OutputTokens,
+ CacheCreationTokens: result.Usage.CacheCreationInputTokens,
+ CacheReadTokens: result.Usage.CacheReadInputTokens,
+ InputCost: cost.InputCost,
+ OutputCost: cost.OutputCost,
+ CacheCreationCost: cost.CacheCreationCost,
+ CacheReadCost: cost.CacheReadCost,
+ TotalCost: cost.TotalCost,
+ ActualCost: cost.ActualCost,
+ RateMultiplier: multiplier,
+ AccountRateMultiplier: &accountRateMultiplier,
+ BillingType: billingType,
+ Stream: result.Stream,
+ DurationMs: &durationMs,
+ FirstTokenMs: result.FirstTokenMs,
+ CreatedAt: time.Now(),
}
// 添加 UserAgent
diff --git a/backend/internal/service/openai_gateway_service_tool_correction_test.go b/backend/internal/service/openai_gateway_service_tool_correction_test.go
new file mode 100644
index 00000000..d4491cfe
--- /dev/null
+++ b/backend/internal/service/openai_gateway_service_tool_correction_test.go
@@ -0,0 +1,133 @@
+package service
+
+import (
+ "strings"
+ "testing"
+)
+
+// TestOpenAIGatewayService_ToolCorrection 测试 OpenAIGatewayService 中的工具修正集成
+func TestOpenAIGatewayService_ToolCorrection(t *testing.T) {
+ // 创建一个简单的 service 实例来测试工具修正
+ service := &OpenAIGatewayService{
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ tests := []struct {
+ name string
+ input []byte
+ expected string
+ changed bool
+ }{
+ {
+ name: "correct apply_patch in response body",
+ input: []byte(`{
+ "choices": [{
+ "message": {
+ "tool_calls": [{
+ "function": {"name": "apply_patch"}
+ }]
+ }
+ }]
+ }`),
+ expected: "edit",
+ changed: true,
+ },
+ {
+ name: "correct update_plan in response body",
+ input: []byte(`{
+ "tool_calls": [{
+ "function": {"name": "update_plan"}
+ }]
+ }`),
+ expected: "todowrite",
+ changed: true,
+ },
+ {
+ name: "no change for correct tool name",
+ input: []byte(`{
+ "tool_calls": [{
+ "function": {"name": "edit"}
+ }]
+ }`),
+ expected: "edit",
+ changed: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := service.correctToolCallsInResponseBody(tt.input)
+ resultStr := string(result)
+
+ // 检查是否包含期望的工具名称
+ if !strings.Contains(resultStr, tt.expected) {
+ t.Errorf("expected result to contain %q, got %q", tt.expected, resultStr)
+ }
+
+ // 对于预期有变化的情况,验证结果与输入不同
+ if tt.changed && string(result) == string(tt.input) {
+ t.Error("expected result to be different from input, but they are the same")
+ }
+
+ // 对于预期无变化的情况,验证结果与输入相同
+ if !tt.changed && string(result) != string(tt.input) {
+ t.Error("expected result to be same as input, but they are different")
+ }
+ })
+ }
+}
+
+// TestOpenAIGatewayService_ToolCorrectorInitialization 测试工具修正器是否正确初始化
+func TestOpenAIGatewayService_ToolCorrectorInitialization(t *testing.T) {
+ service := &OpenAIGatewayService{
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ if service.toolCorrector == nil {
+ t.Fatal("toolCorrector should not be nil")
+ }
+
+ // 测试修正器可以正常工作
+ data := `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`
+ corrected, changed := service.toolCorrector.CorrectToolCallsInSSEData(data)
+
+ if !changed {
+ t.Error("expected tool call to be corrected")
+ }
+
+ if !strings.Contains(corrected, "edit") {
+ t.Errorf("expected corrected data to contain 'edit', got %q", corrected)
+ }
+}
+
+// TestToolCorrectionStats 测试工具修正统计功能
+func TestToolCorrectionStats(t *testing.T) {
+ service := &OpenAIGatewayService{
+ toolCorrector: NewCodexToolCorrector(),
+ }
+
+ // 执行几次修正
+ testData := []string{
+ `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
+ `{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
+ `{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
+ }
+
+ for _, data := range testData {
+ service.toolCorrector.CorrectToolCallsInSSEData(data)
+ }
+
+ stats := service.toolCorrector.GetStats()
+
+ if stats.TotalCorrected != 3 {
+ t.Errorf("expected 3 corrections, got %d", stats.TotalCorrected)
+ }
+
+ if stats.CorrectionsByTool["apply_patch->edit"] != 2 {
+ t.Errorf("expected 2 apply_patch->edit corrections, got %d", stats.CorrectionsByTool["apply_patch->edit"])
+ }
+
+ if stats.CorrectionsByTool["update_plan->todowrite"] != 1 {
+ t.Errorf("expected 1 update_plan->todowrite correction, got %d", stats.CorrectionsByTool["update_plan->todowrite"])
+ }
+}
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
new file mode 100644
index 00000000..82a0866f
--- /dev/null
+++ b/backend/internal/service/openai_token_provider.go
@@ -0,0 +1,189 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "strings"
+ "time"
+)
+
const (
	// openAITokenRefreshSkew: refresh the access token when it expires
	// within this window.
	openAITokenRefreshSkew = 3 * time.Minute
	// openAITokenCacheSkew: safety margin subtracted from the credential
	// expiry when computing the cache TTL.
	openAITokenCacheSkew = 5 * time.Minute
	// openAILockWaitTime: how long to wait before re-reading the cache when
	// another worker holds the refresh lock.
	openAILockWaitTime = 200 * time.Millisecond
)

// OpenAITokenCache is the token cache interface (reuses the GeminiTokenCache
// interface definition).
type OpenAITokenCache = GeminiTokenCache

// OpenAITokenProvider manages access tokens for OpenAI OAuth accounts.
type OpenAITokenProvider struct {
	accountRepo        AccountRepository
	tokenCache         OpenAITokenCache
	openAIOAuthService *OpenAIOAuthService
}
+
+func NewOpenAITokenProvider(
+ accountRepo AccountRepository,
+ tokenCache OpenAITokenCache,
+ openAIOAuthService *OpenAIOAuthService,
+) *OpenAITokenProvider {
+ return &OpenAITokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: tokenCache,
+ openAIOAuthService: openAIOAuthService,
+ }
+}
+
+// GetAccessToken 获取有效的 access_token
+func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
+ if account == nil {
+ return "", errors.New("account is nil")
+ }
+ if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
+ return "", errors.New("not an openai oauth account")
+ }
+
+ cacheKey := OpenAITokenCacheKey(account)
+
+ // 1. 先尝试缓存
+ if p.tokenCache != nil {
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ slog.Debug("openai_token_cache_hit", "account_id", account.ID)
+ return token, nil
+ } else if err != nil {
+ slog.Warn("openai_token_cache_get_failed", "account_id", account.ID, "error", err)
+ }
+ }
+
+ slog.Debug("openai_token_cache_miss", "account_id", account.ID)
+
+ // 2. 如果即将过期则刷新
+ expiresAt := account.GetCredentialAsTime("expires_at")
+ needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
+ refreshFailed := false
+ if needsRefresh && p.tokenCache != nil {
+ locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
+ if lockErr == nil && locked {
+ defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
+
+ // 拿到锁后再次检查缓存(另一个 worker 可能已刷新)
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ return token, nil
+ }
+
+ // 从数据库获取最新账户信息
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
+ if p.openAIOAuthService == nil {
+ slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
+ refreshFailed = true // 无法刷新,标记失败
+ } else {
+ tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
+ slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
+ refreshFailed = true // 刷新失败,标记以使用短 TTL
+ } else {
+ newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
+ for k, v := range account.Credentials {
+ if _, exists := newCredentials[k]; !exists {
+ newCredentials[k] = v
+ }
+ }
+ account.Credentials = newCredentials
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ }
+ }
+ } else if lockErr != nil {
+ // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
+ slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
+
+ // 检查 ctx 是否已取消
+ if ctx.Err() != nil {
+ return "", ctx.Err()
+ }
+
+ // 从数据库获取最新账户信息
+ if p.accountRepo != nil {
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+
+ // 仅在 expires_at 已过期/接近过期时才执行无锁刷新
+ if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
+ if p.openAIOAuthService == nil {
+ slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
+ refreshFailed = true
+ } else {
+ tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
+ if err != nil {
+ slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
+ refreshFailed = true
+ } else {
+ newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
+ for k, v := range account.Credentials {
+ if _, exists := newCredentials[k]; !exists {
+ newCredentials[k] = v
+ }
+ }
+ account.Credentials = newCredentials
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
+ expiresAt = account.GetCredentialAsTime("expires_at")
+ }
+ }
+ }
+ } else {
+ // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
+ time.Sleep(openAILockWaitTime)
+ if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
+ slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
+ return token, nil
+ }
+ }
+ }
+
+ accessToken := account.GetOpenAIAccessToken()
+ if strings.TrimSpace(accessToken) == "" {
+ return "", errors.New("access_token not found in credentials")
+ }
+
+ // 3. 存入缓存
+ if p.tokenCache != nil {
+ ttl := 30 * time.Minute
+ if refreshFailed {
+ // 刷新失败时使用短 TTL,避免失效 token 长时间缓存导致 401 抖动
+ ttl = time.Minute
+ slog.Debug("openai_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
+ } else if expiresAt != nil {
+ until := time.Until(*expiresAt)
+ switch {
+ case until > openAITokenCacheSkew:
+ ttl = until - openAITokenCacheSkew
+ case until > 0:
+ ttl = until
+ default:
+ ttl = time.Minute
+ }
+ }
+ if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
+ slog.Warn("openai_token_cache_set_failed", "account_id", account.ID, "error", err)
+ }
+ }
+
+ return accessToken, nil
+}
diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go
new file mode 100644
index 00000000..c2e3dbb0
--- /dev/null
+++ b/backend/internal/service/openai_token_provider_test.go
@@ -0,0 +1,810 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
// openAITokenCacheStub implements OpenAITokenCache for testing.
// The tokens map is guarded by mu; call counters are updated atomically so
// tests can read them while helper goroutines are still running.
type openAITokenCacheStub struct {
	mu               sync.Mutex
	tokens           map[string]string
	getErr           error
	setErr           error
	deleteErr        error
	lockAcquired     bool
	lockErr          error
	releaseLockErr   error
	getCalled        int32
	setCalled        int32
	lockCalled       int32
	unlockCalled     int32
	simulateLockRace bool
}

// newOpenAITokenCacheStub returns a stub with an empty token map whose
// refresh lock is acquirable by default.
func newOpenAITokenCacheStub() *openAITokenCacheStub {
	return &openAITokenCacheStub{
		tokens:       make(map[string]string),
		lockAcquired: true,
	}
}

// GetAccessToken returns the stubbed token for cacheKey, or getErr if set.
func (s *openAITokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
	atomic.AddInt32(&s.getCalled, 1)
	if s.getErr != nil {
		return "", s.getErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.tokens[cacheKey], nil
}

// SetAccessToken stores token under cacheKey, or fails with setErr; the ttl
// argument is ignored by the stub.
func (s *openAITokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
	atomic.AddInt32(&s.setCalled, 1)
	if s.setErr != nil {
		return s.setErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.tokens[cacheKey] = token
	return nil
}

// DeleteAccessToken removes cacheKey from the stub, or fails with deleteErr.
func (s *openAITokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
	if s.deleteErr != nil {
		return s.deleteErr
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.tokens, cacheKey)
	return nil
}

// AcquireRefreshLock reports lockAcquired, unless lockErr is set or a lock
// race is being simulated (lock appears held by another worker).
func (s *openAITokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
	atomic.AddInt32(&s.lockCalled, 1)
	if s.lockErr != nil {
		return false, s.lockErr
	}
	if s.simulateLockRace {
		return false, nil
	}
	return s.lockAcquired, nil
}

// ReleaseRefreshLock records the call and returns releaseLockErr.
func (s *openAITokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
	atomic.AddInt32(&s.unlockCalled, 1)
	return s.releaseLockErr
}
+
// openAIAccountRepoStub is a minimal stub implementing only the methods used
// by OpenAITokenProvider.
type openAIAccountRepoStub struct {
	account      *Account
	getErr       error
	updateErr    error
	getCalled    int32
	updateCalled int32
}

// GetByID returns the fixed stub account (or getErr); the id is ignored.
func (r *openAIAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
	atomic.AddInt32(&r.getCalled, 1)
	if r.getErr != nil {
		return nil, r.getErr
	}
	return r.account, nil
}

// Update replaces the stored stub account, or fails with updateErr.
func (r *openAIAccountRepoStub) Update(ctx context.Context, account *Account) error {
	atomic.AddInt32(&r.updateCalled, 1)
	if r.updateErr != nil {
		return r.updateErr
	}
	r.account = account
	return nil
}
+
// openAIOAuthServiceStub implements OpenAIOAuthService methods for testing.
type openAIOAuthServiceStub struct {
	tokenInfo     *OpenAITokenInfo
	refreshErr    error
	refreshCalled int32
}

// RefreshAccountToken returns the canned tokenInfo, or refreshErr if set.
func (s *openAIOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	if s.refreshErr != nil {
		return nil, s.refreshErr
	}
	return s.tokenInfo, nil
}

// BuildAccountCredentials converts tokenInfo into a credentials map with an
// RFC3339 expires_at derived from ExpiresIn seconds from now.
func (s *openAIOAuthServiceStub) BuildAccountCredentials(info *OpenAITokenInfo) map[string]any {
	now := time.Now()
	return map[string]any{
		"access_token":  info.AccessToken,
		"refresh_token": info.RefreshToken,
		"expires_at":    now.Add(time.Duration(info.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
}
+
// TestOpenAITokenProvider_CacheHit verifies that a cached token is returned
// directly, with no write-back to the cache.
func TestOpenAITokenProvider_CacheHit(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	account := &Account{
		ID:       100,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "db-token",
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	cache.tokens[cacheKey] = "cached-token"

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "cached-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}

// TestOpenAITokenProvider_CacheMiss_FromCredentials verifies that on a cache
// miss the token is served from the account credentials and written to cache.
func TestOpenAITokenProvider_CacheMiss_FromCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	// Token expires in far future, no refresh needed
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       101,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "credential-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "credential-token", token)

	// Should have stored in cache
	cacheKey := OpenAITokenCacheKey(account)
	require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
+
// TestOpenAITokenProvider_TokenRefresh verifies that a near-expiry token is
// refreshed through the OAuth service exactly once.
func TestOpenAITokenProvider_TokenRefresh(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		tokenInfo: &OpenAITokenInfo{
			AccessToken:  "refreshed-token",
			RefreshToken: "new-refresh-token",
			ExpiresIn:    3600,
		},
	}

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       102,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	// We need to directly test with the stub - create a custom provider
	customProvider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	token, err := customProvider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "refreshed-token", token)
	require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
+
// testOpenAITokenProvider is a test version that uses the stub OAuth service.
// NOTE(review): this double re-implements OpenAITokenProvider.GetAccessToken;
// the two must be kept in sync when the production logic changes.
type testOpenAITokenProvider struct {
	accountRepo  *openAIAccountRepoStub
	tokenCache   *openAITokenCacheStub
	oauthService *openAIOAuthServiceStub
}

// GetAccessToken mirrors the production flow: cache lookup, lock-guarded
// refresh when the credential is near expiry, then fallback to the
// access_token stored in the account credentials.
func (p *testOpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth {
		return "", errors.New("not an openai oauth account")
	}

	cacheKey := OpenAITokenCacheKey(account)

	// 1. Check cache
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
			return token, nil
		}
	}

	// 2. Check if refresh needed
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
	refreshFailed := false
	if needsRefresh && p.tokenCache != nil {
		locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if err == nil && locked {
			defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()

			// Check cache again after acquiring lock
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}

			// Get fresh account from DB
			fresh, err := p.accountRepo.GetByID(ctx, account.ID)
			if err == nil && fresh != nil {
				account = fresh
			}
			expiresAt = account.GetCredentialAsTime("expires_at")
			if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
				if p.oauthService == nil {
					refreshFailed = true // cannot refresh: mark as failed
				} else {
					tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
					if err != nil {
						refreshFailed = true // refresh failed: mark so a short TTL is used
					} else {
						newCredentials := p.oauthService.BuildAccountCredentials(tokenInfo)
						for k, v := range account.Credentials {
							if _, exists := newCredentials[k]; !exists {
								newCredentials[k] = v
							}
						}
						account.Credentials = newCredentials
						_ = p.accountRepo.Update(ctx, account)
						expiresAt = account.GetCredentialAsTime("expires_at")
					}
				}
			}
		} else if p.tokenCache.simulateLockRace {
			// Wait and retry cache
			time.Sleep(10 * time.Millisecond) // Short wait for test
			if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
				return token, nil
			}
		}
	}

	accessToken := account.GetOpenAIAccessToken()
	if accessToken == "" {
		return "", errors.New("access_token not found in credentials")
	}

	// 3. Store in cache
	if p.tokenCache != nil {
		ttl := 30 * time.Minute
		if refreshFailed {
			ttl = time.Minute // short TTL after a failed refresh
		} else if expiresAt != nil {
			until := time.Until(*expiresAt)
			if until > openAITokenCacheSkew {
				ttl = until - openAITokenCacheSkew
			} else if until > 0 {
				ttl = until
			} else {
				ttl = time.Minute
			}
		}
		_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
	}

	return accessToken, nil
}
+
// TestOpenAITokenProvider_LockRaceCondition verifies the behavior when the
// refresh lock is held by another worker: wait briefly, then re-read cache.
func TestOpenAITokenProvider_LockRaceCondition(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.simulateLockRace = true
	accountRepo := &openAIAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       103,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "race-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	// Simulate another worker already refreshed and cached
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := &testOpenAITokenProvider{
		accountRepo: accountRepo,
		tokenCache:  cache,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get the token set by the "winner" or the original
	require.NotEmpty(t, token)
}
+
// TestOpenAITokenProvider_NilAccount verifies that a nil account is rejected.
func TestOpenAITokenProvider_NilAccount(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), nil)
	require.Error(t, err)
	require.Contains(t, err.Error(), "account is nil")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongPlatform verifies that a non-OpenAI account
// is rejected.
func TestOpenAITokenProvider_WrongPlatform(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       104,
		Platform: PlatformGemini,
		Type:     AccountTypeOAuth,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_WrongAccountType verifies that a non-OAuth account
// is rejected.
func TestOpenAITokenProvider_WrongAccountType(t *testing.T) {
	provider := NewOpenAITokenProvider(nil, nil, nil)
	account := &Account{
		ID:       105,
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "not an openai oauth account")
	require.Empty(t, token)
}
+
// TestOpenAITokenProvider_NilCache verifies the provider works without a
// configured token cache, serving straight from credentials.
func TestOpenAITokenProvider_NilCache(t *testing.T) {
	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       106,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "nocache-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, nil, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "nocache-token", token)
}

// TestOpenAITokenProvider_CacheGetError verifies graceful degradation to the
// credential token when the cache read fails.
func TestOpenAITokenProvider_CacheGetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.getErr = errors.New("redis connection failed")

	// Token doesn't need refresh
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       107,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	// Should gracefully degrade and return from credentials
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-token", token)
}
+
// TestOpenAITokenProvider_CacheSetError verifies that a failed cache write
// does not prevent returning a valid token.
func TestOpenAITokenProvider_CacheSetError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.setErr = errors.New("redis write failed")

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       108,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "still-works-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	// Should still work even if cache set fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "still-works-token", token)
}

// TestOpenAITokenProvider_MissingAccessToken verifies that an account with no
// access_token in its credentials yields an error.
func TestOpenAITokenProvider_MissingAccessToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       109,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// missing access_token
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)

	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
+
// TestOpenAITokenProvider_RefreshError verifies that a failed OAuth refresh
// falls back to the existing credential token instead of erroring.
func TestOpenAITokenProvider_RefreshError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}
	oauthService := &openAIOAuthServiceStub{
		refreshErr: errors.New("oauth refresh failed"),
	}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       110,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token":  "old-token",
			"refresh_token": "old-refresh-token",
			"expires_at":    expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: oauthService,
	}

	// Now with fallback behavior, should return existing token even if refresh fails
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}

// TestOpenAITokenProvider_OAuthServiceNotConfigured verifies the fallback to
// the existing token when no OAuth service is wired in.
func TestOpenAITokenProvider_OAuthServiceNotConfigured(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	accountRepo := &openAIAccountRepoStub{}

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       111,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "old-token",
			"expires_at":   expiresAt,
		},
	}
	accountRepo.account = account

	provider := &testOpenAITokenProvider{
		accountRepo:  accountRepo,
		tokenCache:   cache,
		oauthService: nil, // not configured
	}

	// Now with fallback behavior, should return existing token even if oauth service not configured
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "old-token", token) // Fallback to existing token
}
+
// TestOpenAITokenProvider_TTLCalculation verifies that tokens with various
// expiry horizons are all written to the cache.
func TestOpenAITokenProvider_TTLCalculation(t *testing.T) {
	tests := []struct {
		name      string
		expiresIn time.Duration
	}{
		{
			name:      "far_future_expiry",
			expiresIn: 1 * time.Hour,
		},
		{
			name:      "medium_expiry",
			expiresIn: 10 * time.Minute,
		},
		{
			name:      "near_expiry",
			expiresIn: 6 * time.Minute,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			cache := newOpenAITokenCacheStub()
			expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
			account := &Account{
				ID:       200,
				Platform: PlatformOpenAI,
				Type:     AccountTypeOAuth,
				Credentials: map[string]any{
					"access_token": "test-token",
					"expires_at":   expiresAt,
				},
			}

			provider := NewOpenAITokenProvider(nil, cache, nil)

			_, err := provider.GetAccessToken(context.Background(), account)
			require.NoError(t, err)

			// Verify token was cached
			cacheKey := OpenAITokenCacheKey(account)
			require.Equal(t, "test-token", cache.tokens[cacheKey])
		})
	}
}
+
+func TestOpenAITokenProvider_DoubleCheckAfterLock(t *testing.T) {
+ cache := newOpenAITokenCacheStub()
+ accountRepo := &openAIAccountRepoStub{}
+ oauthService := &openAIOAuthServiceStub{
+ tokenInfo: &OpenAITokenInfo{
+ AccessToken: "refreshed-token",
+ RefreshToken: "new-refresh",
+ ExpiresIn: 3600,
+ },
+ }
+
+ // Token expires soon
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 112,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "old-token",
+ "expires_at": expiresAt,
+ },
+ }
+ accountRepo.account = account
+ cacheKey := OpenAITokenCacheKey(account)
+
+ // Simulate: first GetAccessToken returns empty, but after lock acquired, cache has token
+ originalGet := int32(0)
+ cache.tokens[cacheKey] = "" // Empty initially
+
+ provider := &testOpenAITokenProvider{
+ accountRepo: accountRepo,
+ tokenCache: cache,
+ oauthService: oauthService,
+ }
+
+ // In a goroutine, set the cached token after a small delay (simulating race)
+ go func() {
+ time.Sleep(5 * time.Millisecond)
+ cache.mu.Lock()
+ cache.tokens[cacheKey] = "cached-by-other"
+ cache.mu.Unlock()
+ }()
+
+ token, err := provider.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ // Should get either the refreshed token or the cached one
+ require.NotEmpty(t, token)
+ _ = originalGet // Suppress unused warning
+}
+
// Tests for real provider - to increase coverage

// TestOpenAITokenProvider_Real_LockFailedWait verifies the real provider's
// wait-then-reread behavior when the refresh lock cannot be acquired.
func TestOpenAITokenProvider_Real_LockFailedWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon (within refresh skew) to trigger lock attempt
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       200,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}

	// Set token in cache after lock wait period (simulate other worker refreshing)
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(100 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "refreshed-by-other"
		cache.mu.Unlock()
	}()

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	// Should get either the fallback token or the refreshed one
	require.NotEmpty(t, token)
}

// TestOpenAITokenProvider_Real_CacheHitAfterWait verifies that a token cached
// by another worker during the lock wait is picked up on the re-read.
func TestOpenAITokenProvider_Real_CacheHitAfterWait(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Lock acquisition fails

	// Token expires soon
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       201,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "original-token",
			"expires_at":   expiresAt,
		},
	}

	cacheKey := OpenAITokenCacheKey(account)
	// Set token in cache immediately after wait starts
	go func() {
		time.Sleep(50 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.NotEmpty(t, token)
}
+
// TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken verifies that an
// account with no expires_at still serves its credential token.
func TestOpenAITokenProvider_Real_ExpiredWithoutRefreshToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // Prevent entering refresh logic

	// Token with nil expires_at (no expiry set) - should use credentials
	account := &Account{
		ID:       202,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "no-expiry-token",
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	// Without OAuth service, refresh will fail but token should be returned from credentials
	require.NoError(t, err)
	require.Equal(t, "no-expiry-token", token)
}

// TestOpenAITokenProvider_Real_WhitespaceToken verifies that a cached value
// of only whitespace is treated as a miss and credentials are used instead.
func TestOpenAITokenProvider_Real_WhitespaceToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cacheKey := "openai:account:203"
	cache.tokens[cacheKey] = "   " // Whitespace only - should be treated as empty

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       203,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "real-token",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "real-token", token) // Should fall back to credentials
}
+
// TestOpenAITokenProvider_Real_LockError verifies the degraded (lock-free)
// path: a lock backend error still yields the credential token.
func TestOpenAITokenProvider_Real_LockError(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock failed")

	// Token expires soon (within refresh skew)
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       204,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-on-lock-error",
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "fallback-on-lock-error", token)
}

// TestOpenAITokenProvider_Real_WhitespaceCredentialToken verifies that a
// whitespace-only credential token is rejected as missing.
func TestOpenAITokenProvider_Real_WhitespaceCredentialToken(t *testing.T) {
	cache := newOpenAITokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       205,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "   ", // Whitespace only
			"expires_at":   expiresAt,
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}

// TestOpenAITokenProvider_Real_NilCredentials verifies that an account whose
// credentials lack an access_token yields an error.
func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
	cache := newOpenAITokenCacheStub()

	expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
	account := &Account{
		ID:       206,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"expires_at": expiresAt,
			// No access_token
		},
	}

	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.Error(t, err)
	require.Contains(t, err.Error(), "access_token not found")
	require.Empty(t, token)
}
diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go
new file mode 100644
index 00000000..9c9eab84
--- /dev/null
+++ b/backend/internal/service/openai_tool_corrector.go
@@ -0,0 +1,307 @@
+package service
+
+import (
+ "encoding/json"
+ "fmt"
+ "log"
+ "sync"
+)
+
+// codexToolNameMapping maps Codex-native tool names to the OpenCode tool names
+// expected downstream. Both snake_case and camelCase spellings of each Codex
+// tool are listed so either wire format is corrected.
+var codexToolNameMapping = map[string]string{
+	"apply_patch":  "edit",
+	"applyPatch":   "edit",
+	"update_plan":  "todowrite",
+	"updatePlan":   "todowrite",
+	"read_plan":    "todoread",
+	"readPlan":     "todoread",
+	"search_files": "grep",
+	"searchFiles":  "grep",
+	"list_files":   "glob",
+	"listFiles":    "glob",
+	"read_file":    "read",
+	"readFile":     "read",
+	"write_file":   "write",
+	"writeFile":    "write",
+	"execute_bash": "bash",
+	"executeBash":  "bash",
+	"exec_bash":    "bash",
+	"execBash":     "bash",
+}
+
+// ToolCorrectionStats holds counters describing how many tool calls were
+// corrected. It is exported (with JSON tags) so it can be serialized, e.g.
+// for status/diagnostics endpoints.
+type ToolCorrectionStats struct {
+	TotalCorrected    int            `json:"total_corrected"`
+	CorrectionsByTool map[string]int `json:"corrections_by_tool"` // keyed "from->to", see recordCorrection
+}
+
+// CodexToolCorrector rewrites Codex tool calls embedded in SSE payloads and
+// tracks correction statistics. The zero value is not usable because
+// CorrectionsByTool must be initialized — construct via NewCodexToolCorrector.
+// It embeds a sync.RWMutex and therefore must not be copied after first use.
+type CodexToolCorrector struct {
+	stats ToolCorrectionStats // guarded by mu
+	mu    sync.RWMutex
+}
+
+// NewCodexToolCorrector returns a corrector with zeroed statistics, ready for use.
+func NewCodexToolCorrector() *CodexToolCorrector {
+	c := &CodexToolCorrector{}
+	c.stats.CorrectionsByTool = make(map[string]int)
+	return c
+}
+
+// CorrectToolCallsInSSEData corrects Codex tool-call names (and, where needed,
+// their arguments) inside one SSE JSON payload.
+//
+// It returns the payload (re-serialized when modified) and whether any
+// correction was applied. Non-JSON input — including "", "\n" and sentinel
+// lines such as "[DONE]" — is passed through unchanged.
+//
+// Tool calls are searched in every location the OpenAI chat schema uses:
+// top-level "tool_calls"/"function_call", top-level "delta", and each
+// choices[i].message / choices[i].delta.
+//
+// NOTE: json.Marshal does not preserve the original key order; consumers must
+// not depend on field ordering of the rewritten event.
+func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, bool) {
+	if data == "" || data == "\n" {
+		return data, false
+	}
+
+	var payload map[string]any
+	if err := json.Unmarshal([]byte(data), &payload); err != nil {
+		// Not a JSON object; leave untouched.
+		return data, false
+	}
+
+	// Top-level tool_calls / function_call, then the top-level delta.
+	corrected := c.correctContainer(payload)
+	if delta, ok := payload["delta"].(map[string]any); ok {
+		if c.correctContainer(delta) {
+			corrected = true
+		}
+	}
+
+	// choices[].message and choices[].delta carry the same two shapes.
+	if choices, ok := payload["choices"].([]any); ok {
+		for _, rawChoice := range choices {
+			choiceMap, ok := rawChoice.(map[string]any)
+			if !ok {
+				continue
+			}
+			if message, ok := choiceMap["message"].(map[string]any); ok {
+				if c.correctContainer(message) {
+					corrected = true
+				}
+			}
+			if delta, ok := choiceMap["delta"].(map[string]any); ok {
+				if c.correctContainer(delta) {
+					corrected = true
+				}
+			}
+		}
+	}
+
+	if !corrected {
+		return data, false
+	}
+
+	correctedBytes, err := json.Marshal(payload)
+	if err != nil {
+		log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err)
+		return data, false
+	}
+	return string(correctedBytes), true
+}
+
+// correctContainer rewrites the "tool_calls" array and/or "function_call"
+// object found directly under obj. Reports whether anything changed.
+func (c *CodexToolCorrector) correctContainer(obj map[string]any) bool {
+	changed := false
+	if toolCalls, ok := obj["tool_calls"].([]any); ok {
+		if c.correctToolCallsArray(toolCalls) {
+			changed = true
+		}
+	}
+	if functionCall, ok := obj["function_call"].(map[string]any); ok {
+		if c.correctFunctionCall(functionCall) {
+			changed = true
+		}
+	}
+	return changed
+}
+
+// correctToolCallsArray rewrites tool names inside a "tool_calls" array.
+// It reports whether at least one entry was changed; malformed entries are
+// skipped silently.
+func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool {
+	changed := false
+	for _, raw := range toolCalls {
+		entry, ok := raw.(map[string]any)
+		if !ok {
+			continue
+		}
+		fn, ok := entry["function"].(map[string]any)
+		if !ok {
+			continue
+		}
+		if c.correctFunctionCall(fn) {
+			changed = true
+		}
+	}
+	return changed
+}
+
+// correctFunctionCall rewrites a single function call in place: first the tool
+// name (via codexToolNameMapping), then the tool's arguments. Reports whether
+// anything was modified.
+func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool {
+	name, ok := functionCall["name"].(string)
+	if !ok || name == "" {
+		return false
+	}
+
+	renamed := false
+	if mapped, found := codexToolNameMapping[name]; found {
+		functionCall["name"] = mapped
+		c.recordCorrection(name, mapped)
+		// Parameter fixes below key off the canonical (corrected) name.
+		name = mapped
+		renamed = true
+	}
+
+	paramsFixed := c.correctToolParameters(name, functionCall)
+	return renamed || paramsFixed
+}
+
+// correctToolParameters adjusts a tool call's arguments to match the OpenCode
+// schema for the given (already corrected) tool name. "arguments" may be a
+// JSON-encoded string or an already-decoded map; the original representation
+// is preserved on output. Reports whether arguments were actually rewritten.
+func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool {
+	arguments, ok := functionCall["arguments"]
+	if !ok {
+		return false
+	}
+
+	var argsMap map[string]any
+	wasString := false
+	switch v := arguments.(type) {
+	case string:
+		if err := json.Unmarshal([]byte(v), &argsMap); err != nil {
+			// Unparseable argument string: leave it alone.
+			return false
+		}
+		wasString = true
+	case map[string]any:
+		argsMap = v
+	default:
+		return false
+	}
+
+	corrected := false
+	switch toolName {
+	case "bash":
+		// OpenCode's bash tool has no working-directory parameter; drop both spellings.
+		for _, key := range []string{"workdir", "work_dir"} {
+			if _, exists := argsMap[key]; exists {
+				delete(argsMap, key)
+				corrected = true
+				log.Printf("[CodexToolCorrector] Removed %q parameter from bash tool", key)
+			}
+		}
+	case "edit":
+		// OpenCode's edit tool expects "file_path"; Codex may emit "path".
+		// Only rename when "file_path" is not already present.
+		if _, exists := argsMap["file_path"]; !exists {
+			if path, exists := argsMap["path"]; exists {
+				argsMap["file_path"] = path
+				delete(argsMap, "path")
+				corrected = true
+				log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool")
+			}
+		}
+	}
+
+	if !corrected {
+		return false
+	}
+
+	if wasString {
+		newArgsJSON, err := json.Marshal(argsMap)
+		if err != nil {
+			// FIX: previously a marshal failure left the stale argument string in
+			// place while still reporting corrected=true. Report false so callers
+			// do not treat an unchanged payload as corrected.
+			log.Printf("[CodexToolCorrector] Failed to re-marshal corrected arguments: %v", err)
+			return false
+		}
+		functionCall["arguments"] = string(newArgsJSON)
+	} else {
+		functionCall["arguments"] = argsMap
+	}
+	return true
+}
+
+// recordCorrection bumps the total counter and the per-mapping counter
+// (keyed "from->to"), then logs the rename. Safe for concurrent use.
+func (c *CodexToolCorrector) recordCorrection(from, to string) {
+	c.mu.Lock()
+	c.stats.TotalCorrected++
+	c.stats.CorrectionsByTool[from+"->"+to]++
+	total := c.stats.TotalCorrected
+	c.mu.Unlock()
+
+	log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)",
+		from, to, total)
+}
+
+// GetStats returns a deep copy of the current correction statistics so
+// callers can neither race with nor mutate the corrector's internal state.
+func (c *CodexToolCorrector) GetStats() ToolCorrectionStats {
+	c.mu.RLock()
+	defer c.mu.RUnlock()
+
+	out := ToolCorrectionStats{
+		TotalCorrected:    c.stats.TotalCorrected,
+		CorrectionsByTool: make(map[string]int, len(c.stats.CorrectionsByTool)),
+	}
+	for key, count := range c.stats.CorrectionsByTool {
+		out.CorrectionsByTool[key] = count
+	}
+	return out
+}
+
+// ResetStats clears all correction counters, leaving the corrector ready
+// to accumulate a fresh set of statistics.
+func (c *CodexToolCorrector) ResetStats() {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.stats = ToolCorrectionStats{CorrectionsByTool: make(map[string]int)}
+}
+
+// CorrectToolName maps a Codex-native tool name to its OpenCode equivalent
+// (for non-SSE call sites). Unknown names are returned unchanged with false.
+func CorrectToolName(name string) (string, bool) {
+	mapped, found := codexToolNameMapping[name]
+	if !found {
+		return name, false
+	}
+	return mapped, true
+}
+
+// GetToolNameMapping returns a defensive copy of the Codex→OpenCode tool-name
+// table; mutating the result does not affect the corrector.
+func GetToolNameMapping() map[string]string {
+	out := make(map[string]string, len(codexToolNameMapping))
+	for from, to := range codexToolNameMapping {
+		out[from] = to
+	}
+	return out
+}
diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go
new file mode 100644
index 00000000..3e885b4b
--- /dev/null
+++ b/backend/internal/service/openai_tool_corrector_test.go
@@ -0,0 +1,503 @@
+package service
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+// TestCorrectToolCallsInSSEData is a table-driven test covering every payload
+// shape CorrectToolCallsInSSEData inspects: top-level tool_calls/function_call,
+// delta.tool_calls, choices[].message.tool_calls, multiple calls in one array,
+// camelCase spellings, and pass-through of non-JSON / already-correct input.
+func TestCorrectToolCallsInSSEData(t *testing.T) {
+	corrector := NewCodexToolCorrector()
+
+	tests := []struct {
+		name            string
+		input           string
+		expectCorrected bool
+		checkFunc       func(t *testing.T, result string)
+	}{
+		// Non-JSON inputs must be passed through untouched.
+		{
+			name:            "empty string",
+			input:           "",
+			expectCorrected: false,
+		},
+		{
+			name:            "newline only",
+			input:           "\n",
+			expectCorrected: false,
+		},
+		{
+			name:            "invalid json",
+			input:           "not a json",
+			expectCorrected: false,
+		},
+		{
+			name:            "correct apply_patch in tool_calls",
+			input:           `{"tool_calls":[{"function":{"name":"apply_patch","arguments":"{}"}}]}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				toolCalls, ok := payload["tool_calls"].([]any)
+				if !ok || len(toolCalls) == 0 {
+					t.Fatal("No tool_calls found in result")
+				}
+				toolCall, ok := toolCalls[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid tool_call format")
+				}
+				functionCall, ok := toolCall["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid function format")
+				}
+				if functionCall["name"] != "edit" {
+					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
+				}
+			},
+		},
+		{
+			name:            "correct update_plan in function_call",
+			input:           `{"function_call":{"name":"update_plan","arguments":"{}"}}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				functionCall, ok := payload["function_call"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid function_call format")
+				}
+				if functionCall["name"] != "todowrite" {
+					t.Errorf("Expected tool name 'todowrite', got '%v'", functionCall["name"])
+				}
+			},
+		},
+		{
+			name:            "correct search_files in delta.tool_calls",
+			input:           `{"delta":{"tool_calls":[{"function":{"name":"search_files"}}]}}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				delta, ok := payload["delta"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid delta format")
+				}
+				toolCalls, ok := delta["tool_calls"].([]any)
+				if !ok || len(toolCalls) == 0 {
+					t.Fatal("No tool_calls found in delta")
+				}
+				toolCall, ok := toolCalls[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid tool_call format")
+				}
+				functionCall, ok := toolCall["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid function format")
+				}
+				if functionCall["name"] != "grep" {
+					t.Errorf("Expected tool name 'grep', got '%v'", functionCall["name"])
+				}
+			},
+		},
+		{
+			name:            "correct list_files in choices.message.tool_calls",
+			input:           `{"choices":[{"message":{"tool_calls":[{"function":{"name":"list_files"}}]}}]}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				choices, ok := payload["choices"].([]any)
+				if !ok || len(choices) == 0 {
+					t.Fatal("No choices found in result")
+				}
+				choice, ok := choices[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid choice format")
+				}
+				message, ok := choice["message"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid message format")
+				}
+				toolCalls, ok := message["tool_calls"].([]any)
+				if !ok || len(toolCalls) == 0 {
+					t.Fatal("No tool_calls found in message")
+				}
+				toolCall, ok := toolCalls[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid tool_call format")
+				}
+				functionCall, ok := toolCall["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid function format")
+				}
+				if functionCall["name"] != "glob" {
+					t.Errorf("Expected tool name 'glob', got '%v'", functionCall["name"])
+				}
+			},
+		},
+		// Already-canonical names must not be reported as corrected.
+		{
+			name:            "no correction needed",
+			input:           `{"tool_calls":[{"function":{"name":"read","arguments":"{}"}}]}`,
+			expectCorrected: false,
+		},
+		{
+			name:            "correct multiple tool calls",
+			input:           `{"tool_calls":[{"function":{"name":"apply_patch"}},{"function":{"name":"read_file"}}]}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				toolCalls, ok := payload["tool_calls"].([]any)
+				if !ok || len(toolCalls) < 2 {
+					t.Fatal("Expected at least 2 tool_calls")
+				}
+
+				toolCall1, ok := toolCalls[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid first tool_call format")
+				}
+				func1, ok := toolCall1["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid first function format")
+				}
+				if func1["name"] != "edit" {
+					t.Errorf("Expected first tool name 'edit', got '%v'", func1["name"])
+				}
+
+				toolCall2, ok := toolCalls[1].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid second tool_call format")
+				}
+				func2, ok := toolCall2["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid second function format")
+				}
+				if func2["name"] != "read" {
+					t.Errorf("Expected second tool name 'read', got '%v'", func2["name"])
+				}
+			},
+		},
+		{
+			name:            "camelCase format - applyPatch",
+			input:           `{"tool_calls":[{"function":{"name":"applyPatch"}}]}`,
+			expectCorrected: true,
+			checkFunc: func(t *testing.T, result string) {
+				var payload map[string]any
+				if err := json.Unmarshal([]byte(result), &payload); err != nil {
+					t.Fatalf("Failed to parse result: %v", err)
+				}
+				toolCalls, ok := payload["tool_calls"].([]any)
+				if !ok || len(toolCalls) == 0 {
+					t.Fatal("No tool_calls found in result")
+				}
+				toolCall, ok := toolCalls[0].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid tool_call format")
+				}
+				functionCall, ok := toolCall["function"].(map[string]any)
+				if !ok {
+					t.Fatal("Invalid function format")
+				}
+				if functionCall["name"] != "edit" {
+					t.Errorf("Expected tool name 'edit', got '%v'", functionCall["name"])
+				}
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result, corrected := corrector.CorrectToolCallsInSSEData(tt.input)
+
+			if corrected != tt.expectCorrected {
+				t.Errorf("Expected corrected=%v, got %v", tt.expectCorrected, corrected)
+			}
+
+			// When nothing was corrected the input must pass through unchanged.
+			if !corrected && result != tt.input {
+				t.Errorf("Expected unchanged result when not corrected")
+			}
+
+			if tt.checkFunc != nil {
+				tt.checkFunc(t, result)
+			}
+		})
+	}
+}
+
+// TestCorrectToolName covers every entry of the Codex→OpenCode name mapping
+// (both snake_case and camelCase spellings) plus names that must pass through
+// unchanged with corrected=false.
+func TestCorrectToolName(t *testing.T) {
+	cases := []struct {
+		input     string
+		expected  string
+		corrected bool
+	}{
+		{"apply_patch", "edit", true},
+		{"applyPatch", "edit", true},
+		{"update_plan", "todowrite", true},
+		{"updatePlan", "todowrite", true},
+		{"read_plan", "todoread", true},
+		{"readPlan", "todoread", true},
+		{"search_files", "grep", true},
+		{"searchFiles", "grep", true},
+		{"list_files", "glob", true},
+		{"listFiles", "glob", true},
+		{"read_file", "read", true},
+		{"readFile", "read", true},
+		{"write_file", "write", true},
+		{"writeFile", "write", true},
+		{"execute_bash", "bash", true},
+		{"executeBash", "bash", true},
+		{"exec_bash", "bash", true},
+		{"execBash", "bash", true},
+		// Unknown and already-canonical names are left alone.
+		{"unknown_tool", "unknown_tool", false},
+		{"read", "read", false},
+		{"edit", "edit", false},
+	}
+
+	for _, tc := range cases {
+		t.Run(tc.input, func(t *testing.T) {
+			got, changed := CorrectToolName(tc.input)
+			if changed != tc.corrected {
+				t.Errorf("Expected corrected=%v, got %v", tc.corrected, changed)
+			}
+			if got != tc.expected {
+				t.Errorf("Expected '%s', got '%s'", tc.expected, got)
+			}
+		})
+	}
+}
+
+// TestGetToolNameMapping checks representative entries of the exported mapping
+// and verifies the returned map is a copy that does not alias the package table.
+func TestGetToolNameMapping(t *testing.T) {
+	mapping := GetToolNameMapping()
+
+	want := map[string]string{
+		"apply_patch":  "edit",
+		"update_plan":  "todowrite",
+		"read_plan":    "todoread",
+		"search_files": "grep",
+		"list_files":   "glob",
+	}
+	for from, to := range want {
+		if mapping[from] != to {
+			t.Errorf("Expected mapping[%s] = %s, got %s", from, to, mapping[from])
+		}
+	}
+
+	// Mutating the returned copy must not leak into subsequent calls.
+	mapping["test_tool"] = "test_value"
+	if _, exists := GetToolNameMapping()["test_tool"]; exists {
+		t.Error("Modifications to returned mapping should not affect original")
+	}
+}
+
+// TestCorrectorStats exercises the statistics lifecycle: zeroed at
+// construction, accumulated per correction, and cleared by ResetStats.
+func TestCorrectorStats(t *testing.T) {
+	c := NewCodexToolCorrector()
+
+	initial := c.GetStats()
+	if initial.TotalCorrected != 0 {
+		t.Errorf("Expected TotalCorrected=0, got %d", initial.TotalCorrected)
+	}
+	if len(initial.CorrectionsByTool) != 0 {
+		t.Errorf("Expected empty CorrectionsByTool, got length %d", len(initial.CorrectionsByTool))
+	}
+
+	// Three corrections: apply_patch twice, update_plan once.
+	inputs := []string{
+		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
+		`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`,
+		`{"tool_calls":[{"function":{"name":"update_plan"}}]}`,
+	}
+	for _, in := range inputs {
+		c.CorrectToolCallsInSSEData(in)
+	}
+
+	after := c.GetStats()
+	if after.TotalCorrected != 3 {
+		t.Errorf("Expected TotalCorrected=3, got %d", after.TotalCorrected)
+	}
+	if after.CorrectionsByTool["apply_patch->edit"] != 2 {
+		t.Errorf("Expected apply_patch->edit count=2, got %d", after.CorrectionsByTool["apply_patch->edit"])
+	}
+	if after.CorrectionsByTool["update_plan->todowrite"] != 1 {
+		t.Errorf("Expected update_plan->todowrite count=1, got %d", after.CorrectionsByTool["update_plan->todowrite"])
+	}
+
+	c.ResetStats()
+	cleared := c.GetStats()
+	if cleared.TotalCorrected != 0 {
+		t.Errorf("Expected TotalCorrected=0 after reset, got %d", cleared.TotalCorrected)
+	}
+	if len(cleared.CorrectionsByTool) != 0 {
+		t.Errorf("Expected empty CorrectionsByTool after reset, got length %d", len(cleared.CorrectionsByTool))
+	}
+}
+
+// TestComplexSSEData feeds a realistic multi-field chat.completion.chunk SSE
+// payload through the corrector and verifies that apply_patch nested under
+// choices[0].delta.tool_calls[0].function is renamed to "edit".
+func TestComplexSSEData(t *testing.T) {
+	corrector := NewCodexToolCorrector()
+
+	input := `{
+		"id": "chatcmpl-123",
+		"object": "chat.completion.chunk",
+		"created": 1234567890,
+		"model": "gpt-5.1-codex",
+		"choices": [
+			{
+				"index": 0,
+				"delta": {
+					"tool_calls": [
+						{
+							"index": 0,
+							"function": {
+								"name": "apply_patch",
+								"arguments": "{\"file\":\"test.go\"}"
+							}
+						}
+					]
+				},
+				"finish_reason": null
+			}
+		]
+	}`
+
+	result, corrected := corrector.CorrectToolCallsInSSEData(input)
+
+	if !corrected {
+		t.Error("Expected data to be corrected")
+	}
+
+	// Walk the corrected payload down to the function object, failing fast on
+	// any structural surprise.
+	var payload map[string]any
+	if err := json.Unmarshal([]byte(result), &payload); err != nil {
+		t.Fatalf("Failed to parse result: %v", err)
+	}
+
+	choices, ok := payload["choices"].([]any)
+	if !ok || len(choices) == 0 {
+		t.Fatal("No choices found in result")
+	}
+	choice, ok := choices[0].(map[string]any)
+	if !ok {
+		t.Fatal("Invalid choice format")
+	}
+	delta, ok := choice["delta"].(map[string]any)
+	if !ok {
+		t.Fatal("Invalid delta format")
+	}
+	toolCalls, ok := delta["tool_calls"].([]any)
+	if !ok || len(toolCalls) == 0 {
+		t.Fatal("No tool_calls found in delta")
+	}
+	toolCall, ok := toolCalls[0].(map[string]any)
+	if !ok {
+		t.Fatal("Invalid tool_call format")
+	}
+	function, ok := toolCall["function"].(map[string]any)
+	if !ok {
+		t.Fatal("Invalid function format")
+	}
+
+	if function["name"] != "edit" {
+		t.Errorf("Expected tool name 'edit', got '%v'", function["name"])
+	}
+}
+
+// TestCorrectToolParameters verifies argument-level corrections: dropping
+// "workdir" for the bash tool and renaming "path" to "file_path" for the edit
+// tool (reached via the apply_patch -> edit name mapping).
+func TestCorrectToolParameters(t *testing.T) {
+	corrector := NewCodexToolCorrector()
+
+	tests := []struct {
+		name     string
+		input    string
+		expected map[string]bool // key: parameter name, value: true if it must exist after correction
+	}{
+		{
+			name: "remove workdir from bash tool",
+			input: `{
+				"tool_calls": [{
+					"function": {
+						"name": "bash",
+						"arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}"
+					}
+				}]
+			}`,
+			expected: map[string]bool{
+				"command": true,
+				"workdir": false,
+			},
+		},
+		{
+			name: "rename path to file_path in edit tool",
+			input: `{
+				"tool_calls": [{
+					"function": {
+						"name": "apply_patch",
+						"arguments": "{\"path\":\"/foo/bar.go\",\"old_string\":\"old\",\"new_string\":\"new\"}"
+					}
+				}]
+			}`,
+			expected: map[string]bool{
+				"file_path":  true,
+				"path":       false,
+				"old_string": true,
+				"new_string": true,
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			corrected, changed := corrector.CorrectToolCallsInSSEData(tt.input)
+			if !changed {
+				t.Error("expected data to be corrected")
+			}
+
+			// Parse the corrected payload.
+			var result map[string]any
+			if err := json.Unmarshal([]byte(corrected), &result); err != nil {
+				t.Fatalf("failed to parse corrected data: %v", err)
+			}
+
+			// Drill down to the (single) tool call's function object.
+			toolCalls, ok := result["tool_calls"].([]any)
+			if !ok || len(toolCalls) == 0 {
+				t.Fatal("no tool_calls found in corrected data")
+			}
+
+			toolCall, ok := toolCalls[0].(map[string]any)
+			if !ok {
+				t.Fatal("invalid tool_call structure")
+			}
+
+			function, ok := toolCall["function"].(map[string]any)
+			if !ok {
+				t.Fatal("no function found in tool_call")
+			}
+
+			// Arguments were provided as a JSON string and must remain one.
+			argumentsStr, ok := function["arguments"].(string)
+			if !ok {
+				t.Fatal("arguments is not a string")
+			}
+
+			var args map[string]any
+			if err := json.Unmarshal([]byte(argumentsStr), &args); err != nil {
+				t.Fatalf("failed to parse arguments: %v", err)
+			}
+
+			// Check presence/absence of each expected parameter.
+			for param, shouldExist := range tt.expected {
+				_, exists := args[param]
+				if shouldExist && !exists {
+					t.Errorf("expected parameter %q to exist, but it doesn't", param)
+				}
+				if !shouldExist && exists {
+					t.Errorf("expected parameter %q to not exist, but it does", param)
+				}
+			}
+		})
+	}
+}
diff --git a/backend/internal/service/ops_aggregation_service.go b/backend/internal/service/ops_aggregation_service.go
index 2a6afbba..972462ec 100644
--- a/backend/internal/service/ops_aggregation_service.go
+++ b/backend/internal/service/ops_aggregation_service.go
@@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
+ "fmt"
"log"
"strings"
"sync"
@@ -235,11 +236,13 @@ func (s *OpsAggregationService) aggregateHourly() {
successAt := finishedAt
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer hbCancel()
+ result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
JobName: opsAggHourlyJobName,
LastRunAt: &runAt,
LastSuccessAt: &successAt,
LastDurationMs: &dur,
+ LastResult: &result,
})
}
@@ -331,11 +334,13 @@ func (s *OpsAggregationService) aggregateDaily() {
successAt := finishedAt
hbCtx, hbCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer hbCancel()
+ result := truncateString(fmt.Sprintf("window=%s..%s", start.Format(time.RFC3339), end.Format(time.RFC3339)), 2048)
_ = s.opsRepo.UpsertJobHeartbeat(hbCtx, &OpsUpsertJobHeartbeatInput{
JobName: opsAggDailyJobName,
LastRunAt: &runAt,
LastSuccessAt: &successAt,
LastDurationMs: &dur,
+ LastResult: &result,
})
}
diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go
index f376c246..7c62e247 100644
--- a/backend/internal/service/ops_alert_evaluator_service.go
+++ b/backend/internal/service/ops_alert_evaluator_service.go
@@ -190,6 +190,13 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
return
}
+ rulesTotal := len(rules)
+ rulesEnabled := 0
+ rulesEvaluated := 0
+ eventsCreated := 0
+ eventsResolved := 0
+ emailsSent := 0
+
now := time.Now().UTC()
safeEnd := now.Truncate(time.Minute)
if safeEnd.IsZero() {
@@ -205,8 +212,9 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
if rule == nil || !rule.Enabled || rule.ID <= 0 {
continue
}
+ rulesEnabled++
- scopePlatform, scopeGroupID := parseOpsAlertRuleScope(rule.Filters)
+ scopePlatform, scopeGroupID, scopeRegion := parseOpsAlertRuleScope(rule.Filters)
windowMinutes := rule.WindowMinutes
if windowMinutes <= 0 {
@@ -220,6 +228,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
s.resetRuleState(rule.ID, now)
continue
}
+ rulesEvaluated++
breachedNow := compareMetric(metricValue, rule.Operator, rule.Threshold)
required := requiredSustainedBreaches(rule.SustainedMinutes, interval)
@@ -236,6 +245,17 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
+ // Scoped silencing: if a matching silence exists, skip creating a firing event.
+ if s.opsService != nil {
+ platform := strings.TrimSpace(scopePlatform)
+ region := scopeRegion
+ if platform != "" {
+ if ok, err := s.opsService.IsAlertSilenced(ctx, rule.ID, platform, scopeGroupID, region, now); err == nil && ok {
+ continue
+ }
+ }
+ }
+
latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID)
if err != nil {
log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err)
@@ -267,8 +287,11 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
continue
}
+ eventsCreated++
if created != nil && created.ID > 0 {
- s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created)
+ if s.maybeSendAlertEmail(ctx, runtimeCfg, rule, created) {
+ emailsSent++
+ }
}
continue
}
@@ -278,11 +301,14 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) {
resolvedAt := now
if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil {
log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err)
+ } else {
+ eventsResolved++
}
}
}
- s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
+ result := truncateString(fmt.Sprintf("rules=%d enabled=%d evaluated=%d created=%d resolved=%d emails_sent=%d", rulesTotal, rulesEnabled, rulesEvaluated, eventsCreated, eventsResolved, emailsSent), 2048)
+ s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
func (s *OpsAlertEvaluatorService) pruneRuleStates(rules []*OpsAlertRule) {
@@ -359,9 +385,9 @@ func requiredSustainedBreaches(sustainedMinutes int, interval time.Duration) int
return required
}
-func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64) {
+func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *int64, region *string) {
if filters == nil {
- return "", nil
+ return "", nil, nil
}
if v, ok := filters["platform"]; ok {
if s, ok := v.(string); ok {
@@ -392,7 +418,15 @@ func parseOpsAlertRuleScope(filters map[string]any) (platform string, groupID *i
}
}
}
- return platform, groupID
+ if v, ok := filters["region"]; ok {
+ if s, ok := v.(string); ok {
+ vv := strings.TrimSpace(s)
+ if vv != "" {
+ region = &vv
+ }
+ }
+ }
+ return platform, groupID, region
}
func (s *OpsAlertEvaluatorService) computeRuleMetric(
@@ -504,16 +538,6 @@ func (s *OpsAlertEvaluatorService) computeRuleMetric(
return 0, false
}
return overview.UpstreamErrorRate * 100, true
- case "p95_latency_ms":
- if overview.Duration.P95 == nil {
- return 0, false
- }
- return float64(*overview.Duration.P95), true
- case "p99_latency_ms":
- if overview.Duration.P99 == nil {
- return 0, false
- }
- return float64(*overview.Duration.P99), true
default:
return 0, false
}
@@ -576,32 +600,32 @@ func buildOpsAlertDescription(rule *OpsAlertRule, value float64, windowMinutes i
)
}
-func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) {
+func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runtimeCfg *OpsAlertRuntimeSettings, rule *OpsAlertRule, event *OpsAlertEvent) bool {
if s == nil || s.emailService == nil || s.opsService == nil || event == nil || rule == nil {
- return
+ return false
}
if event.EmailSent {
- return
+ return false
}
if !rule.NotifyEmail {
- return
+ return false
}
emailCfg, err := s.opsService.GetEmailNotificationConfig(ctx)
if err != nil || emailCfg == nil || !emailCfg.Alert.Enabled {
- return
+ return false
}
if len(emailCfg.Alert.Recipients) == 0 {
- return
+ return false
}
if !shouldSendOpsAlertEmailByMinSeverity(strings.TrimSpace(emailCfg.Alert.MinSeverity), strings.TrimSpace(rule.Severity)) {
- return
+ return false
}
if runtimeCfg != nil && runtimeCfg.Silencing.Enabled {
if isOpsAlertSilenced(time.Now().UTC(), rule, event, runtimeCfg.Silencing) {
- return
+ return false
}
}
@@ -630,6 +654,7 @@ func (s *OpsAlertEvaluatorService) maybeSendAlertEmail(ctx context.Context, runt
if anySent {
_ = s.opsRepo.UpdateAlertEventEmailSent(context.Background(), event.ID, true)
}
+ return anySent
}
func buildOpsAlertEmailBody(rule *OpsAlertRule, event *OpsAlertEvent) string {
@@ -797,7 +822,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) {
log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key)
}
-func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
+func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
if s == nil || s.opsRepo == nil {
return
}
@@ -805,11 +830,17 @@ func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, durat
durMs := duration.Milliseconds()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
+ msg := strings.TrimSpace(result)
+ if msg == "" {
+ msg = "ok"
+ }
+ msg = truncateString(msg, 2048)
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
JobName: opsAlertEvaluatorJobName,
LastRunAt: &runAt,
LastSuccessAt: &now,
LastDurationMs: &durMs,
+ LastResult: &msg,
})
}
diff --git a/backend/internal/service/ops_alert_models.go b/backend/internal/service/ops_alert_models.go
index 0acf13ab..a0caa990 100644
--- a/backend/internal/service/ops_alert_models.go
+++ b/backend/internal/service/ops_alert_models.go
@@ -8,8 +8,9 @@ import "time"
// with the existing ops dashboard frontend (backup style).
const (
- OpsAlertStatusFiring = "firing"
- OpsAlertStatusResolved = "resolved"
+ OpsAlertStatusFiring = "firing"
+ OpsAlertStatusResolved = "resolved"
+ OpsAlertStatusManualResolved = "manual_resolved"
)
type OpsAlertRule struct {
@@ -58,12 +59,32 @@ type OpsAlertEvent struct {
CreatedAt time.Time `json:"created_at"`
}
+// OpsAlertSilence suppresses firing events for one alert rule, scoped by
+// platform and optionally narrowed by group and region, until Until passes.
+type OpsAlertSilence struct {
+	ID int64 `json:"id"`
+
+	// Scope: RuleID and Platform are mandatory (validated in
+	// OpsService.CreateAlertSilence); GroupID and Region narrow further when set.
+	RuleID   int64   `json:"rule_id"`
+	Platform string  `json:"platform"`
+	GroupID  *int64  `json:"group_id,omitempty"`
+	Region   *string `json:"region,omitempty"`
+
+	// Until is the silence expiry; Reason is a free-form operator note.
+	Until  time.Time `json:"until"`
+	Reason string    `json:"reason"`
+
+	// Audit trail: who created the silence and when.
+	CreatedBy *int64    `json:"created_by,omitempty"`
+	CreatedAt time.Time `json:"created_at"`
+}
+
type OpsAlertEventFilter struct {
Limit int
+ // Cursor pagination (descending by fired_at, then id).
+ BeforeFiredAt *time.Time
+ BeforeID *int64
+
// Optional filters.
- Status string
- Severity string
+ Status string
+ Severity string
+ EmailSent *bool
StartTime *time.Time
EndTime *time.Time
diff --git a/backend/internal/service/ops_alerts.go b/backend/internal/service/ops_alerts.go
index b6c3d1c3..b4c09824 100644
--- a/backend/internal/service/ops_alerts.go
+++ b/backend/internal/service/ops_alerts.go
@@ -88,6 +88,29 @@ func (s *OpsService) ListAlertEvents(ctx context.Context, filter *OpsAlertEventF
return s.opsRepo.ListAlertEvents(ctx, filter)
}
+// GetAlertEventByID returns a single alert event by primary key.
+// It returns the monitoring-disabled error from RequireMonitoringEnabled,
+// SERVICE_UNAVAILABLE when the ops repository is not wired, BAD_REQUEST for a
+// non-positive id, and NOT_FOUND when no matching row exists.
+func (s *OpsService) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) {
+	if err := s.RequireMonitoringEnabled(ctx); err != nil {
+		return nil, err
+	}
+	if s.opsRepo == nil {
+		return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+	}
+	if eventID <= 0 {
+		return nil, infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
+	}
+	ev, err := s.opsRepo.GetAlertEventByID(ctx, eventID)
+	if err != nil {
+		// Map the repository's sql.ErrNoRows to a typed NOT_FOUND for API consumers.
+		if errors.Is(err, sql.ErrNoRows) {
+			return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
+		}
+		return nil, err
+	}
+	if ev == nil {
+		// Defensive: treat a (nil, nil) repo result as not found as well.
+		return nil, infraerrors.NotFound("OPS_ALERT_EVENT_NOT_FOUND", "alert event not found")
+	}
+	return ev, nil
+}
+
func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -101,6 +124,49 @@ func (s *OpsService) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*Op
return s.opsRepo.GetActiveAlertEvent(ctx, ruleID)
}
+func (s *OpsService) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) {
+ if err := s.RequireMonitoringEnabled(ctx); err != nil {
+ return nil, err
+ }
+ if s.opsRepo == nil {
+ return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+ }
+ if input == nil {
+ return nil, infraerrors.BadRequest("INVALID_SILENCE", "invalid silence")
+ }
+ if input.RuleID <= 0 {
+ return nil, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
+ }
+ if strings.TrimSpace(input.Platform) == "" {
+ return nil, infraerrors.BadRequest("INVALID_PLATFORM", "invalid platform")
+ }
+ if input.Until.IsZero() {
+ return nil, infraerrors.BadRequest("INVALID_UNTIL", "invalid until")
+ }
+
+ created, err := s.opsRepo.CreateAlertSilence(ctx, input)
+ if err != nil {
+ return nil, err
+ }
+ return created, nil
+}
+
+func (s *OpsService) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
+ if err := s.RequireMonitoringEnabled(ctx); err != nil {
+ return false, err
+ }
+ if s.opsRepo == nil {
+ return false, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+ }
+ if ruleID <= 0 {
+ return false, infraerrors.BadRequest("INVALID_RULE_ID", "invalid rule id")
+ }
+ if strings.TrimSpace(platform) == "" {
+ return false, nil
+ }
+ return s.opsRepo.IsAlertSilenced(ctx, ruleID, platform, groupID, region, now)
+}
+
func (s *OpsService) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -142,7 +208,11 @@ func (s *OpsService) UpdateAlertEventStatus(ctx context.Context, eventID int64,
if eventID <= 0 {
return infraerrors.BadRequest("INVALID_EVENT_ID", "invalid event id")
}
- if strings.TrimSpace(status) == "" {
+ status = strings.TrimSpace(status)
+ if status == "" {
+ return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
+ }
+ if status != OpsAlertStatusResolved && status != OpsAlertStatusManualResolved {
return infraerrors.BadRequest("INVALID_STATUS", "invalid status")
}
return s.opsRepo.UpdateAlertEventStatus(ctx, eventID, status, resolvedAt)
diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go
index afd2d22c..1ade7176 100644
--- a/backend/internal/service/ops_cleanup_service.go
+++ b/backend/internal/service/ops_cleanup_service.go
@@ -149,7 +149,7 @@ func (s *OpsCleanupService) runScheduled() {
log.Printf("[OpsCleanup] cleanup failed: %v", err)
return
}
- s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
+ s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts)
log.Printf("[OpsCleanup] cleanup complete: %s", counts)
}
@@ -330,12 +330,13 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b
return release, true
}
-func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
+func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, counts opsCleanupDeletedCounts) {
if s == nil || s.opsRepo == nil {
return
}
now := time.Now().UTC()
durMs := duration.Milliseconds()
+ result := truncateString(counts.String(), 2048)
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = s.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
@@ -343,6 +344,7 @@ func (s *OpsCleanupService) recordHeartbeatSuccess(runAt time.Time, duration tim
LastRunAt: &runAt,
LastSuccessAt: &now,
LastDurationMs: &durMs,
+ LastResult: &result,
})
}
diff --git a/backend/internal/service/ops_health_score.go b/backend/internal/service/ops_health_score.go
index feb0d843..5efae870 100644
--- a/backend/internal/service/ops_health_score.go
+++ b/backend/internal/service/ops_health_score.go
@@ -32,49 +32,38 @@ func computeDashboardHealthScore(now time.Time, overview *OpsDashboardOverview)
}
// computeBusinessHealth calculates business health score (0-100)
-// Components: SLA (50%) + Error Rate (30%) + Latency (20%)
+// Components: Error Rate (50%) + TTFT (50%)
func computeBusinessHealth(overview *OpsDashboardOverview) float64 {
- // SLA score: 99.5% → 100, 95% → 0 (linear)
- slaScore := 100.0
- slaPct := clampFloat64(overview.SLA*100, 0, 100)
- if slaPct < 99.5 {
- if slaPct >= 95 {
- slaScore = (slaPct - 95) / 4.5 * 100
- } else {
- slaScore = 0
- }
- }
-
- // Error rate score: 0.5% → 100, 5% → 0 (linear)
+ // Error rate score: 1% → 100, 10% → 0 (linear)
// Combines request errors and upstream errors
errorScore := 100.0
errorPct := clampFloat64(overview.ErrorRate*100, 0, 100)
upstreamPct := clampFloat64(overview.UpstreamErrorRate*100, 0, 100)
combinedErrorPct := math.Max(errorPct, upstreamPct) // Use worst case
- if combinedErrorPct > 0.5 {
- if combinedErrorPct <= 5 {
- errorScore = (5 - combinedErrorPct) / 4.5 * 100
+ if combinedErrorPct > 1.0 {
+ if combinedErrorPct <= 10.0 {
+ errorScore = (10.0 - combinedErrorPct) / 9.0 * 100
} else {
errorScore = 0
}
}
- // Latency score: 1s → 100, 10s → 0 (linear)
- // Uses P99 of duration (TTFT is less critical for overall health)
- latencyScore := 100.0
- if overview.Duration.P99 != nil {
- p99 := float64(*overview.Duration.P99)
+ // TTFT score: 1s → 100, 3s → 0 (linear)
+ // Time to first token is critical for user experience
+ ttftScore := 100.0
+ if overview.TTFT.P99 != nil {
+ p99 := float64(*overview.TTFT.P99)
if p99 > 1000 {
- if p99 <= 10000 {
- latencyScore = (10000 - p99) / 9000 * 100
+ if p99 <= 3000 {
+ ttftScore = (3000 - p99) / 2000 * 100
} else {
- latencyScore = 0
+ ttftScore = 0
}
}
}
- // Weighted combination
- return slaScore*0.5 + errorScore*0.3 + latencyScore*0.2
+ // Weighted combination: 50% error rate + 50% TTFT
+ return errorScore*0.5 + ttftScore*0.5
}
// computeInfraHealth calculates infrastructure health score (0-100)
diff --git a/backend/internal/service/ops_health_score_test.go b/backend/internal/service/ops_health_score_test.go
index 849ba146..25bfb43d 100644
--- a/backend/internal/service/ops_health_score_test.go
+++ b/backend/internal/service/ops_health_score_test.go
@@ -127,8 +127,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(75),
},
},
- wantMin: 60,
- wantMax: 85,
+ wantMin: 96,
+ wantMax: 97,
},
{
name: "DB failure",
@@ -203,8 +203,8 @@ func TestComputeDashboardHealthScore_Comprehensive(t *testing.T) {
MemoryUsagePercent: float64Ptr(30),
},
},
- wantMin: 25,
- wantMax: 50,
+ wantMin: 84,
+ wantMax: 85,
},
{
name: "combined failures - business healthy + infra degraded",
@@ -277,30 +277,41 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
- wantMin: 50,
- wantMax: 60,
+ wantMin: 100,
+ wantMax: 100,
},
{
- name: "error rate boundary 0.5%",
+ name: "error rate boundary 1%",
overview: &OpsDashboardOverview{
- SLA: 0.995,
- ErrorRate: 0.005,
+ SLA: 0.99,
+ ErrorRate: 0.01,
UpstreamErrorRate: 0,
Duration: OpsPercentiles{P99: intPtr(500)},
},
- wantMin: 95,
+ wantMin: 100,
wantMax: 100,
},
{
- name: "latency boundary 1000ms",
+ name: "error rate 5%",
overview: &OpsDashboardOverview{
- SLA: 0.995,
+ SLA: 0.95,
+ ErrorRate: 0.05,
+ UpstreamErrorRate: 0,
+ Duration: OpsPercentiles{P99: intPtr(500)},
+ },
+ wantMin: 77,
+ wantMax: 78,
+ },
+ {
+ name: "TTFT boundary 2s",
+ overview: &OpsDashboardOverview{
+ SLA: 0.99,
ErrorRate: 0,
UpstreamErrorRate: 0,
- Duration: OpsPercentiles{P99: intPtr(1000)},
+ TTFT: OpsPercentiles{P99: intPtr(2000)},
},
- wantMin: 95,
- wantMax: 100,
+ wantMin: 75,
+ wantMax: 75,
},
{
name: "upstream error dominates",
@@ -310,7 +321,7 @@ func TestComputeBusinessHealth(t *testing.T) {
UpstreamErrorRate: 0.03,
Duration: OpsPercentiles{P99: intPtr(500)},
},
- wantMin: 75,
+ wantMin: 88,
wantMax: 90,
},
}
diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go
index 996267fd..347cd52b 100644
--- a/backend/internal/service/ops_models.go
+++ b/backend/internal/service/ops_models.go
@@ -6,24 +6,43 @@ type OpsErrorLog struct {
ID int64 `json:"id"`
CreatedAt time.Time `json:"created_at"`
- Phase string `json:"phase"`
- Type string `json:"type"`
+ // Standardized classification
+ // - phase: request|auth|routing|upstream|network|internal
+ // - owner: client|provider|platform
+ // - source: client_request|upstream_http|gateway
+ Phase string `json:"phase"`
+ Type string `json:"type"`
+
+ Owner string `json:"error_owner"`
+ Source string `json:"error_source"`
+
Severity string `json:"severity"`
StatusCode int `json:"status_code"`
Platform string `json:"platform"`
Model string `json:"model"`
- LatencyMs *int `json:"latency_ms"`
+ IsRetryable bool `json:"is_retryable"`
+ RetryCount int `json:"retry_count"`
+
+ Resolved bool `json:"resolved"`
+ ResolvedAt *time.Time `json:"resolved_at"`
+ ResolvedByUserID *int64 `json:"resolved_by_user_id"`
+ ResolvedByUserName string `json:"resolved_by_user_name"`
+ ResolvedRetryID *int64 `json:"resolved_retry_id"`
+ ResolvedStatusRaw string `json:"-"`
ClientRequestID string `json:"client_request_id"`
RequestID string `json:"request_id"`
Message string `json:"message"`
- UserID *int64 `json:"user_id"`
- APIKeyID *int64 `json:"api_key_id"`
- AccountID *int64 `json:"account_id"`
- GroupID *int64 `json:"group_id"`
+ UserID *int64 `json:"user_id"`
+ UserEmail string `json:"user_email"`
+ APIKeyID *int64 `json:"api_key_id"`
+ AccountID *int64 `json:"account_id"`
+ AccountName string `json:"account_name"`
+ GroupID *int64 `json:"group_id"`
+ GroupName string `json:"group_name"`
ClientIP *string `json:"client_ip"`
RequestPath string `json:"request_path"`
@@ -67,9 +86,24 @@ type OpsErrorLogFilter struct {
GroupID *int64
AccountID *int64
- StatusCodes []int
- Phase string
- Query string
+ StatusCodes []int
+ StatusCodesOther bool
+ Phase string
+ Owner string
+ Source string
+ Resolved *bool
+ Query string
+ UserQuery string // Search by user email
+
+ // Optional correlation keys for exact matching.
+ RequestID string
+ ClientRequestID string
+
+ // View controls error categorization for list endpoints.
+ // - errors: show actionable errors (exclude business-limited / 429 / 529)
+ // - excluded: only show excluded errors
+ // - all: show everything
+ View string
Page int
PageSize int
@@ -90,12 +124,23 @@ type OpsRetryAttempt struct {
SourceErrorID int64 `json:"source_error_id"`
Mode string `json:"mode"`
PinnedAccountID *int64 `json:"pinned_account_id"`
+ PinnedAccountName string `json:"pinned_account_name"`
Status string `json:"status"`
StartedAt *time.Time `json:"started_at"`
FinishedAt *time.Time `json:"finished_at"`
DurationMs *int64 `json:"duration_ms"`
+ // Persisted execution results (best-effort)
+ Success *bool `json:"success"`
+ HTTPStatusCode *int `json:"http_status_code"`
+ UpstreamRequestID *string `json:"upstream_request_id"`
+ UsedAccountID *int64 `json:"used_account_id"`
+ UsedAccountName string `json:"used_account_name"`
+ ResponsePreview *string `json:"response_preview"`
+ ResponseTruncated *bool `json:"response_truncated"`
+
+ // Optional correlation
ResultRequestID *string `json:"result_request_id"`
ResultErrorID *int64 `json:"result_error_id"`
diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go
index 4df21c37..515b47bb 100644
--- a/backend/internal/service/ops_port.go
+++ b/backend/internal/service/ops_port.go
@@ -14,6 +14,8 @@ type OpsRepository interface {
InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error)
UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error
GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error)
+ ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error)
+ UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error
// Lightweight window stats (for realtime WS / quick sampling).
GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error)
@@ -39,12 +41,17 @@ type OpsRepository interface {
DeleteAlertRule(ctx context.Context, id int64) error
ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error)
+ GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error)
GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error)
CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error)
UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error
UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error
+ // Alert silences
+ CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error)
+ IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error)
+
// Pre-aggregation (hourly/daily) used for long-window dashboard performance.
UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error
UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error
@@ -91,7 +98,6 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON *string
- DurationMs *int
TimeToFirstTokenMs *int64
RequestBodyJSON *string // sanitized json string (not raw bytes)
@@ -124,7 +130,15 @@ type OpsUpdateRetryAttemptInput struct {
FinishedAt time.Time
DurationMs int64
- // Optional correlation
+ // Persisted execution results (best-effort)
+ Success *bool
+ HTTPStatusCode *int
+ UpstreamRequestID *string
+ UsedAccountID *int64
+ ResponsePreview *string
+ ResponseTruncated *bool
+
+ // Optional correlation (legacy fields kept)
ResultRequestID *string
ResultErrorID *int64
@@ -221,6 +235,9 @@ type OpsUpsertJobHeartbeatInput struct {
LastErrorAt *time.Time
LastError *string
LastDurationMs *int64
+
+ // LastResult is an optional human-readable summary of the last successful run.
+ LastResult *string
}
type OpsJobHeartbeat struct {
@@ -231,6 +248,7 @@ type OpsJobHeartbeat struct {
LastErrorAt *time.Time `json:"last_error_at"`
LastError *string `json:"last_error"`
LastDurationMs *int64 `json:"last_duration_ms"`
+ LastResult *string `json:"last_result"`
UpdatedAt time.Time `json:"updated_at"`
}
diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go
index 747aa3b8..8d98e43f 100644
--- a/backend/internal/service/ops_retry.go
+++ b/backend/internal/service/ops_retry.go
@@ -108,6 +108,10 @@ func (w *limitedResponseWriter) truncated() bool {
return w.totalWritten > int64(w.limit)
}
+const (
+ OpsRetryModeUpstreamEvent = "upstream_event"
+)
+
func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, errorID int64, mode string, pinnedAccountID *int64) (*OpsRetryResult, error) {
if err := s.RequireMonitoringEnabled(ctx); err != nil {
return nil, err
@@ -123,6 +127,81 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_MODE", "mode must be client or upstream")
}
+ errorLog, err := s.GetErrorLogByID(ctx, errorID)
+ if err != nil {
+ return nil, err
+ }
+ if errorLog == nil {
+ return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
+ }
+ if strings.TrimSpace(errorLog.RequestBody) == "" {
+ return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
+ }
+
+ var pinned *int64
+ if mode == OpsRetryModeUpstream {
+ if pinnedAccountID != nil && *pinnedAccountID > 0 {
+ pinned = pinnedAccountID
+ } else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
+ pinned = errorLog.AccountID
+ } else {
+ return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
+ }
+ }
+
+ return s.retryWithErrorLog(ctx, requestedByUserID, errorID, mode, mode, pinned, errorLog)
+}
+
+// RetryUpstreamEvent retries a specific upstream attempt captured inside ops_error_logs.upstream_errors.
+// idx is 0-based. It always pins the original event account_id.
+func (s *OpsService) RetryUpstreamEvent(ctx context.Context, requestedByUserID int64, errorID int64, idx int) (*OpsRetryResult, error) {
+ if err := s.RequireMonitoringEnabled(ctx); err != nil {
+ return nil, err
+ }
+ if s.opsRepo == nil {
+ return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+ }
+ if idx < 0 {
+ return nil, infraerrors.BadRequest("OPS_RETRY_INVALID_UPSTREAM_IDX", "invalid upstream idx")
+ }
+
+ errorLog, err := s.GetErrorLogByID(ctx, errorID)
+ if err != nil {
+ return nil, err
+ }
+ if errorLog == nil {
+ return nil, infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
+ }
+
+ events, err := ParseOpsUpstreamErrors(errorLog.UpstreamErrors)
+ if err != nil {
+ return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENTS_INVALID", "invalid upstream_errors")
+ }
+ if idx >= len(events) {
+ return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_IDX_OOB", "upstream idx out of range")
+ }
+ ev := events[idx]
+ if ev == nil {
+ return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_EVENT_MISSING", "upstream event missing")
+ }
+ if ev.AccountID <= 0 {
+ return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
+ }
+
+ upstreamBody := strings.TrimSpace(ev.UpstreamRequestBody)
+ if upstreamBody == "" {
+ return nil, infraerrors.BadRequest("OPS_RETRY_UPSTREAM_NO_REQUEST_BODY", "No upstream request body found to retry")
+ }
+
+ override := *errorLog
+ override.RequestBody = upstreamBody
+ pinned := ev.AccountID
+
+ // Persist as upstream_event, execute as upstream pinned retry.
+ return s.retryWithErrorLog(ctx, requestedByUserID, errorID, OpsRetryModeUpstreamEvent, OpsRetryModeUpstream, &pinned, &override)
+}
+
+func (s *OpsService) retryWithErrorLog(ctx context.Context, requestedByUserID int64, errorID int64, mode string, execMode string, pinnedAccountID *int64, errorLog *OpsErrorLogDetail) (*OpsRetryResult, error) {
latest, err := s.opsRepo.GetLatestRetryAttemptForError(ctx, errorID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return nil, infraerrors.InternalServer("OPS_RETRY_LOAD_LATEST_FAILED", "Failed to check retry status").WithCause(err)
@@ -144,22 +223,18 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
}
}
- errorLog, err := s.GetErrorLogByID(ctx, errorID)
- if err != nil {
- return nil, err
- }
- if strings.TrimSpace(errorLog.RequestBody) == "" {
+ if errorLog == nil || strings.TrimSpace(errorLog.RequestBody) == "" {
return nil, infraerrors.BadRequest("OPS_RETRY_NO_REQUEST_BODY", "No request body found to retry")
}
var pinned *int64
- if mode == OpsRetryModeUpstream {
+ if execMode == OpsRetryModeUpstream {
if pinnedAccountID != nil && *pinnedAccountID > 0 {
pinned = pinnedAccountID
} else if errorLog.AccountID != nil && *errorLog.AccountID > 0 {
pinned = errorLog.AccountID
} else {
- return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "pinned_account_id is required for upstream retry")
+ return nil, infraerrors.BadRequest("OPS_RETRY_PINNED_ACCOUNT_REQUIRED", "account_id is required for upstream retry")
}
}
@@ -196,7 +271,7 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
execCtx, cancel := context.WithTimeout(ctx, opsRetryTimeout)
defer cancel()
- execRes := s.executeRetry(execCtx, errorLog, mode, pinned)
+ execRes := s.executeRetry(execCtx, errorLog, execMode, pinned)
finishedAt := time.Now()
result.FinishedAt = finishedAt
@@ -220,27 +295,40 @@ func (s *OpsService) RetryError(ctx context.Context, requestedByUserID int64, er
msg := result.ErrorMessage
updateErrMsg = &msg
}
+ // Keep legacy result_request_id empty; use upstream_request_id instead.
var resultRequestID *string
- if strings.TrimSpace(result.UpstreamRequestID) != "" {
- v := result.UpstreamRequestID
- resultRequestID = &v
- }
finalStatus := result.Status
if strings.TrimSpace(finalStatus) == "" {
finalStatus = opsRetryStatusFailed
}
+ success := strings.EqualFold(finalStatus, opsRetryStatusSucceeded)
+ httpStatus := result.HTTPStatusCode
+ upstreamReqID := result.UpstreamRequestID
+ usedAccountID := result.UsedAccountID
+ preview := result.ResponsePreview
+ truncated := result.ResponseTruncated
+
if err := s.opsRepo.UpdateRetryAttempt(updateCtx, &OpsUpdateRetryAttemptInput{
- ID: attemptID,
- Status: finalStatus,
- FinishedAt: finishedAt,
- DurationMs: result.DurationMs,
- ResultRequestID: resultRequestID,
- ErrorMessage: updateErrMsg,
+ ID: attemptID,
+ Status: finalStatus,
+ FinishedAt: finishedAt,
+ DurationMs: result.DurationMs,
+ Success: &success,
+ HTTPStatusCode: &httpStatus,
+ UpstreamRequestID: &upstreamReqID,
+ UsedAccountID: usedAccountID,
+ ResponsePreview: &preview,
+ ResponseTruncated: &truncated,
+ ResultRequestID: resultRequestID,
+ ErrorMessage: updateErrMsg,
}); err != nil {
- // Best-effort: retry itself already executed; do not fail the API response.
log.Printf("[Ops] UpdateRetryAttempt failed: %v", err)
+ } else if success {
+ if err := s.opsRepo.UpdateErrorResolution(updateCtx, errorID, true, &requestedByUserID, &attemptID, &finishedAt); err != nil {
+ log.Printf("[Ops] UpdateErrorResolution failed: %v", err)
+ }
}
return result, nil
@@ -426,7 +514,7 @@ func (s *OpsService) selectAccountForRetry(ctx context.Context, reqType opsRetry
if s.gatewayService == nil {
return nil, fmt.Errorf("gateway service not available")
}
- return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs)
+ return s.gatewayService.SelectAccountWithLoadAwareness(ctx, groupID, "", model, excludedIDs, "") // 重试不使用会话限制
default:
return nil, fmt.Errorf("unsupported retry type: %s", reqType)
}
diff --git a/backend/internal/service/ops_scheduled_report_service.go b/backend/internal/service/ops_scheduled_report_service.go
index 28902cbc..98b2045d 100644
--- a/backend/internal/service/ops_scheduled_report_service.go
+++ b/backend/internal/service/ops_scheduled_report_service.go
@@ -177,6 +177,10 @@ func (s *OpsScheduledReportService) runOnce() {
return
}
+ reportsTotal := len(reports)
+ reportsDue := 0
+ sentAttempts := 0
+
for _, report := range reports {
if report == nil || !report.Enabled {
continue
@@ -184,14 +188,18 @@ func (s *OpsScheduledReportService) runOnce() {
if report.NextRunAt.After(now) {
continue
}
+ reportsDue++
- if err := s.runReport(ctx, report, now); err != nil {
+ attempts, err := s.runReport(ctx, report, now)
+ if err != nil {
s.recordHeartbeatError(runAt, time.Since(startedAt), err)
return
}
+ sentAttempts += attempts
}
- s.recordHeartbeatSuccess(runAt, time.Since(startedAt))
+ result := truncateString(fmt.Sprintf("reports=%d due=%d send_attempts=%d", reportsTotal, reportsDue, sentAttempts), 2048)
+ s.recordHeartbeatSuccess(runAt, time.Since(startedAt), result)
}
type opsScheduledReport struct {
@@ -297,9 +305,9 @@ func (s *OpsScheduledReportService) listScheduledReports(ctx context.Context, no
return out
}
-func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) error {
+func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsScheduledReport, now time.Time) (int, error) {
if s == nil || s.opsService == nil || s.emailService == nil || report == nil {
- return nil
+ return 0, nil
}
if ctx == nil {
ctx = context.Background()
@@ -310,11 +318,11 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
content, err := s.generateReportHTML(ctx, report, now)
if err != nil {
- return err
+ return 0, err
}
if strings.TrimSpace(content) == "" {
// Skip sending when the report decides not to emit content (e.g., digest below min count).
- return nil
+ return 0, nil
}
recipients := report.Recipients
@@ -325,22 +333,24 @@ func (s *OpsScheduledReportService) runReport(ctx context.Context, report *opsSc
}
}
if len(recipients) == 0 {
- return nil
+ return 0, nil
}
subject := fmt.Sprintf("[Ops Report] %s", strings.TrimSpace(report.Name))
+ attempts := 0
for _, to := range recipients {
addr := strings.TrimSpace(to)
if addr == "" {
continue
}
+ attempts++
if err := s.emailService.SendEmail(ctx, addr, subject, content); err != nil {
// Ignore per-recipient failures; continue best-effort.
continue
}
}
- return nil
+ return attempts, nil
}
func (s *OpsScheduledReportService) generateReportHTML(ctx context.Context, report *opsScheduledReport, now time.Time) (string, error) {
@@ -650,7 +660,7 @@ func (s *OpsScheduledReportService) setLastRunAt(ctx context.Context, reportType
_ = s.redisClient.Set(ctx, key, strconv.FormatInt(t.UTC().Unix(), 10), 14*24*time.Hour).Err()
}
-func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration) {
+func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) {
if s == nil || s.opsService == nil || s.opsService.opsRepo == nil {
return
}
@@ -658,11 +668,17 @@ func (s *OpsScheduledReportService) recordHeartbeatSuccess(runAt time.Time, dura
durMs := duration.Milliseconds()
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
+ msg := strings.TrimSpace(result)
+ if msg == "" {
+ msg = "ok"
+ }
+ msg = truncateString(msg, 2048)
_ = s.opsService.opsRepo.UpsertJobHeartbeat(ctx, &OpsUpsertJobHeartbeatInput{
JobName: opsScheduledReportJobName,
LastRunAt: &runAt,
LastSuccessAt: &now,
LastDurationMs: &durMs,
+ LastResult: &msg,
})
}
diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go
index 426d46f1..abb8ae12 100644
--- a/backend/internal/service/ops_service.go
+++ b/backend/internal/service/ops_service.go
@@ -208,6 +208,25 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn
out.Detail = ""
}
+ out.UpstreamRequestBody = strings.TrimSpace(out.UpstreamRequestBody)
+ if out.UpstreamRequestBody != "" {
+ // Reuse the same sanitization/trimming strategy as request body storage.
+ // Keep it small so it is safe to persist in ops_error_logs JSON.
+ sanitized, truncated, _ := sanitizeAndTrimRequestBody([]byte(out.UpstreamRequestBody), 10*1024)
+ if sanitized != "" {
+ out.UpstreamRequestBody = sanitized
+ if truncated {
+ out.Kind = strings.TrimSpace(out.Kind)
+ if out.Kind == "" {
+ out.Kind = "upstream"
+ }
+ out.Kind = out.Kind + ":request_body_truncated"
+ }
+ } else {
+ out.UpstreamRequestBody = ""
+ }
+ }
+
// Drop fully-empty events (can happen if only status code was known).
if out.UpstreamStatusCode == 0 && out.Message == "" && out.Detail == "" {
continue
@@ -236,7 +255,13 @@ func (s *OpsService) GetErrorLogs(ctx context.Context, filter *OpsErrorLogFilter
if s.opsRepo == nil {
return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Total: 0, Page: 1, PageSize: 20}, nil
}
- return s.opsRepo.ListErrorLogs(ctx, filter)
+ result, err := s.opsRepo.ListErrorLogs(ctx, filter)
+ if err != nil {
+ log.Printf("[Ops] GetErrorLogs failed: %v", err)
+ return nil, err
+ }
+
+ return result, nil
}
func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) {
@@ -256,6 +281,46 @@ func (s *OpsService) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLo
return detail, nil
}
+func (s *OpsService) ListRetryAttemptsByErrorID(ctx context.Context, errorID int64, limit int) ([]*OpsRetryAttempt, error) {
+ if err := s.RequireMonitoringEnabled(ctx); err != nil {
+ return nil, err
+ }
+ if s.opsRepo == nil {
+ return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+ }
+ if errorID <= 0 {
+ return nil, infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
+ }
+ items, err := s.opsRepo.ListRetryAttemptsByErrorID(ctx, errorID, limit)
+ if err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return []*OpsRetryAttempt{}, nil
+ }
+ return nil, infraerrors.InternalServer("OPS_RETRY_LIST_FAILED", "Failed to list retry attempts").WithCause(err)
+ }
+ return items, nil
+}
+
+func (s *OpsService) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64) error {
+ if err := s.RequireMonitoringEnabled(ctx); err != nil {
+ return err
+ }
+ if s.opsRepo == nil {
+ return infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available")
+ }
+ if errorID <= 0 {
+ return infraerrors.BadRequest("OPS_ERROR_INVALID_ID", "invalid error id")
+ }
+ // Best-effort ensure the error exists
+ if _, err := s.opsRepo.GetErrorLogByID(ctx, errorID); err != nil {
+ if errors.Is(err, sql.ErrNoRows) {
+ return infraerrors.NotFound("OPS_ERROR_NOT_FOUND", "ops error log not found")
+ }
+ return infraerrors.InternalServer("OPS_ERROR_LOAD_FAILED", "Failed to load ops error log").WithCause(err)
+ }
+ return s.opsRepo.UpdateErrorResolution(ctx, errorID, resolved, resolvedByUserID, resolvedRetryID, nil)
+}
+
func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, truncated bool, bytesLen int) {
bytesLen = len(raw)
if len(raw) == 0 {
@@ -296,14 +361,34 @@ func sanitizeAndTrimRequestBody(raw []byte, maxBytes int) (jsonString string, tr
}
}
- // Last resort: store a minimal placeholder (still valid JSON).
- placeholder := map[string]any{
- "request_body_truncated": true,
+ // Last resort: keep JSON shape but drop big fields.
+ // This avoids downstream code that expects certain top-level keys from crashing.
+ if root, ok := decoded.(map[string]any); ok {
+ placeholder := shallowCopyMap(root)
+ placeholder["request_body_truncated"] = true
+
+ // Replace potentially huge arrays/strings, but keep the keys present.
+ for _, k := range []string{"messages", "contents", "input", "prompt"} {
+ if _, exists := placeholder[k]; exists {
+ placeholder[k] = []any{}
+ }
+ }
+ for _, k := range []string{"text"} {
+ if _, exists := placeholder[k]; exists {
+ placeholder[k] = ""
+ }
+ }
+
+ encoded4, err4 := json.Marshal(placeholder)
+ if err4 == nil {
+ if len(encoded4) <= maxBytes {
+ return string(encoded4), true, bytesLen
+ }
+ }
}
- if model := extractString(decoded, "model"); model != "" {
- placeholder["model"] = model
- }
- encoded4, err4 := json.Marshal(placeholder)
+
+ // Final fallback: minimal valid JSON.
+ encoded4, err4 := json.Marshal(map[string]any{"request_body_truncated": true})
if err4 != nil {
return "", true, bytesLen
}
@@ -526,12 +611,3 @@ func sanitizeErrorBodyForStorage(raw string, maxBytes int) (sanitized string, tr
}
return raw, false
}
-
-func extractString(v any, key string) string {
- root, ok := v.(map[string]any)
- if !ok {
- return ""
- }
- s, _ := root[key].(string)
- return strings.TrimSpace(s)
-}
diff --git a/backend/internal/service/ops_settings.go b/backend/internal/service/ops_settings.go
index 53c78fed..a6a4a0d7 100644
--- a/backend/internal/service/ops_settings.go
+++ b/backend/internal/service/ops_settings.go
@@ -368,9 +368,11 @@ func defaultOpsAdvancedSettings() *OpsAdvancedSettings {
Aggregation: OpsAggregationSettings{
AggregationEnabled: false,
},
- IgnoreCountTokensErrors: false,
- AutoRefreshEnabled: false,
- AutoRefreshIntervalSec: 30,
+ IgnoreCountTokensErrors: false,
+ IgnoreContextCanceled: true, // Default to true - client disconnects are not errors
+ IgnoreNoAvailableAccounts: false, // Default to false - this is a real routing issue
+ AutoRefreshEnabled: false,
+ AutoRefreshIntervalSec: 30,
}
}
@@ -482,13 +484,11 @@ const SettingKeyOpsMetricThresholds = "ops_metric_thresholds"
func defaultOpsMetricThresholds() *OpsMetricThresholds {
slaMin := 99.5
- latencyMax := 2000.0
ttftMax := 500.0
reqErrMax := 5.0
upstreamErrMax := 5.0
return &OpsMetricThresholds{
SLAPercentMin: &slaMin,
- LatencyP99MsMax: &latencyMax,
TTFTp99MsMax: &ttftMax,
RequestErrorRatePercentMax: &reqErrMax,
UpstreamErrorRatePercentMax: &upstreamErrMax,
@@ -538,9 +538,6 @@ func (s *OpsService) UpdateMetricThresholds(ctx context.Context, cfg *OpsMetricT
if cfg.SLAPercentMin != nil && (*cfg.SLAPercentMin < 0 || *cfg.SLAPercentMin > 100) {
return nil, errors.New("sla_percent_min must be between 0 and 100")
}
- if cfg.LatencyP99MsMax != nil && *cfg.LatencyP99MsMax < 0 {
- return nil, errors.New("latency_p99_ms_max must be >= 0")
- }
if cfg.TTFTp99MsMax != nil && *cfg.TTFTp99MsMax < 0 {
return nil, errors.New("ttft_p99_ms_max must be >= 0")
}
diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go
index 229488a1..df06f578 100644
--- a/backend/internal/service/ops_settings_models.go
+++ b/backend/internal/service/ops_settings_models.go
@@ -63,7 +63,6 @@ type OpsAlertSilencingSettings struct {
type OpsMetricThresholds struct {
SLAPercentMin *float64 `json:"sla_percent_min,omitempty"` // SLA低于此值变红
- LatencyP99MsMax *float64 `json:"latency_p99_ms_max,omitempty"` // 延迟P99高于此值变红
TTFTp99MsMax *float64 `json:"ttft_p99_ms_max,omitempty"` // TTFT P99高于此值变红
RequestErrorRatePercentMax *float64 `json:"request_error_rate_percent_max,omitempty"` // 请求错误率高于此值变红
UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红
@@ -79,11 +78,13 @@ type OpsAlertRuntimeSettings struct {
// OpsAdvancedSettings stores advanced ops configuration (data retention, aggregation).
type OpsAdvancedSettings struct {
- DataRetention OpsDataRetentionSettings `json:"data_retention"`
- Aggregation OpsAggregationSettings `json:"aggregation"`
- IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
- AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
- AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
+ DataRetention OpsDataRetentionSettings `json:"data_retention"`
+ Aggregation OpsAggregationSettings `json:"aggregation"`
+ IgnoreCountTokensErrors bool `json:"ignore_count_tokens_errors"`
+ IgnoreContextCanceled bool `json:"ignore_context_canceled"`
+ IgnoreNoAvailableAccounts bool `json:"ignore_no_available_accounts"`
+ AutoRefreshEnabled bool `json:"auto_refresh_enabled"`
+ AutoRefreshIntervalSec int `json:"auto_refresh_interval_seconds"`
}
type OpsDataRetentionSettings struct {
diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go
index 615ae6a1..96bcc9fe 100644
--- a/backend/internal/service/ops_upstream_context.go
+++ b/backend/internal/service/ops_upstream_context.go
@@ -15,6 +15,11 @@ const (
OpsUpstreamErrorMessageKey = "ops_upstream_error_message"
OpsUpstreamErrorDetailKey = "ops_upstream_error_detail"
OpsUpstreamErrorsKey = "ops_upstream_errors"
+
+ // Best-effort capture of the current upstream request body so ops can
+ // retry the specific upstream attempt (not just the client request).
+ // This value is sanitized+trimmed before being persisted.
+ OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
)
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
@@ -38,13 +43,21 @@ type OpsUpstreamErrorEvent struct {
AtUnixMs int64 `json:"at_unix_ms,omitempty"`
// Context
- Platform string `json:"platform,omitempty"`
- AccountID int64 `json:"account_id,omitempty"`
+ Platform string `json:"platform,omitempty"`
+ AccountID int64 `json:"account_id,omitempty"`
+ AccountName string `json:"account_name,omitempty"`
// Outcome
UpstreamStatusCode int `json:"upstream_status_code,omitempty"`
UpstreamRequestID string `json:"upstream_request_id,omitempty"`
+ // Best-effort upstream request capture (sanitized+trimmed).
+ // Required for retrying a specific upstream attempt.
+ UpstreamRequestBody string `json:"upstream_request_body,omitempty"`
+
+ // Best-effort upstream response capture (sanitized+trimmed).
+ UpstreamResponseBody string `json:"upstream_response_body,omitempty"`
+
// Kind: http_error | request_error | retry_exhausted | failover
Kind string `json:"kind,omitempty"`
@@ -61,6 +74,8 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
}
ev.Platform = strings.TrimSpace(ev.Platform)
ev.UpstreamRequestID = strings.TrimSpace(ev.UpstreamRequestID)
+ ev.UpstreamRequestBody = strings.TrimSpace(ev.UpstreamRequestBody)
+ ev.UpstreamResponseBody = strings.TrimSpace(ev.UpstreamResponseBody)
ev.Kind = strings.TrimSpace(ev.Kind)
ev.Message = strings.TrimSpace(ev.Message)
ev.Detail = strings.TrimSpace(ev.Detail)
@@ -68,6 +83,16 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
ev.Message = sanitizeUpstreamErrorMessage(ev.Message)
}
+ // If the caller didn't explicitly pass upstream request body but the gateway
+ // stored it on the context, attach it so ops can retry this specific attempt.
+ if ev.UpstreamRequestBody == "" {
+ if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
+ if s, ok := v.(string); ok {
+ ev.UpstreamRequestBody = strings.TrimSpace(s)
+ }
+ }
+ }
+
var existing []*OpsUpstreamErrorEvent
if v, ok := c.Get(OpsUpstreamErrorsKey); ok {
if arr, ok := v.([]*OpsUpstreamErrorEvent); ok {
@@ -92,3 +117,15 @@ func marshalOpsUpstreamErrors(events []*OpsUpstreamErrorEvent) *string {
s := string(raw)
return &s
}
+
+func ParseOpsUpstreamErrors(raw string) ([]*OpsUpstreamErrorEvent, error) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return []*OpsUpstreamErrorEvent{}, nil
+ }
+ var out []*OpsUpstreamErrorEvent
+ if err := json.Unmarshal([]byte(raw), &out); err != nil {
+ return nil, err
+ }
+ return out, nil
+}
diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go
index 768e2a0a..7eb7728f 100644
--- a/backend/internal/service/proxy.go
+++ b/backend/internal/service/proxy.go
@@ -31,5 +31,21 @@ func (p *Proxy) URL() string {
type ProxyWithAccountCount struct {
Proxy
- AccountCount int64
+ AccountCount int64
+ LatencyMs *int64
+ LatencyStatus string
+ LatencyMessage string
+ IPAddress string
+ Country string
+ CountryCode string
+ Region string
+ City string
+}
+
+type ProxyAccountSummary struct {
+ ID int64
+ Name string
+ Platform string
+ Type string
+ Notes *string
}
diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go
new file mode 100644
index 00000000..4a1cc77b
--- /dev/null
+++ b/backend/internal/service/proxy_latency_cache.go
@@ -0,0 +1,23 @@
+package service
+
+import (
+ "context"
+ "time"
+)
+
+type ProxyLatencyInfo struct {
+ Success bool `json:"success"`
+ LatencyMs *int64 `json:"latency_ms,omitempty"`
+ Message string `json:"message,omitempty"`
+ IPAddress string `json:"ip_address,omitempty"`
+ Country string `json:"country,omitempty"`
+ CountryCode string `json:"country_code,omitempty"`
+ Region string `json:"region,omitempty"`
+ City string `json:"city,omitempty"`
+ UpdatedAt time.Time `json:"updated_at"`
+}
+
+type ProxyLatencyCache interface {
+ GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*ProxyLatencyInfo, error)
+ SetProxyLatency(ctx context.Context, proxyID int64, info *ProxyLatencyInfo) error
+}
diff --git a/backend/internal/service/proxy_service.go b/backend/internal/service/proxy_service.go
index 58408d04..a5d897f6 100644
--- a/backend/internal/service/proxy_service.go
+++ b/backend/internal/service/proxy_service.go
@@ -10,6 +10,7 @@ import (
var (
ErrProxyNotFound = infraerrors.NotFound("PROXY_NOT_FOUND", "proxy not found")
+ ErrProxyInUse = infraerrors.Conflict("PROXY_IN_USE", "proxy is in use by accounts")
)
type ProxyRepository interface {
@@ -26,6 +27,7 @@ type ProxyRepository interface {
ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error)
CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error)
+ ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error)
}
// CreateProxyRequest 创建代理请求
diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go
index 62b0c6b8..47a04cf5 100644
--- a/backend/internal/service/ratelimit_service.go
+++ b/backend/internal/service/ratelimit_service.go
@@ -3,7 +3,7 @@ package service
import (
"context"
"encoding/json"
- "log"
+ "log/slog"
"net/http"
"strconv"
"strings"
@@ -15,15 +15,16 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
- accountRepo AccountRepository
- usageRepo UsageLogRepository
- cfg *config.Config
- geminiQuotaService *GeminiQuotaService
- tempUnschedCache TempUnschedCache
- timeoutCounterCache TimeoutCounterCache
- settingService *SettingService
- usageCacheMu sync.RWMutex
- usageCache map[int64]*geminiUsageCacheEntry
+ accountRepo AccountRepository
+ usageRepo UsageLogRepository
+ cfg *config.Config
+ geminiQuotaService *GeminiQuotaService
+ tempUnschedCache TempUnschedCache
+ timeoutCounterCache TimeoutCounterCache
+ settingService *SettingService
+ tokenCacheInvalidator TokenCacheInvalidator
+ usageCacheMu sync.RWMutex
+ usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
@@ -56,6 +57,11 @@ func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
+// SetTokenCacheInvalidator 设置 token 缓存清理器(可选依赖)
+func (s *RateLimitService) SetTokenCacheInvalidator(invalidator TokenCacheInvalidator) {
+ s.tokenCacheInvalidator = invalidator
+}
+
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -63,11 +69,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
// 如果启用且错误码不在列表中,则不处理(不停止调度、不标记限流/过载)
customErrorCodesEnabled := account.IsCustomErrorCodesEnabled()
if !account.ShouldHandleErrorCode(statusCode) {
- log.Printf("Account %d: error %d skipped (not in custom error codes)", account.ID, statusCode)
+ slog.Info("account_error_code_skipped", "account_id", account.ID, "status_code", statusCode)
return false
}
- tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
+ tempMatched := false
+ if statusCode != 401 {
+ tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
+ }
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody))
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
if upstreamMsg != "" {
@@ -76,7 +85,25 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
switch statusCode {
case 401:
- // 认证失败:停止调度,记录错误
+ // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新
+ if account.Type == AccountTypeOAuth {
+ // 1. 失效缓存
+ if s.tokenCacheInvalidator != nil {
+ if err := s.tokenCacheInvalidator.InvalidateToken(ctx, account); err != nil {
+ slog.Warn("oauth_401_invalidate_cache_failed", "account_id", account.ID, "error", err)
+ }
+ }
+ // 2. 设置 expires_at 为当前时间,强制下次请求刷新 token
+ if account.Credentials == nil {
+ account.Credentials = make(map[string]any)
+ }
+ account.Credentials["expires_at"] = time.Now().Format(time.RFC3339)
+ if err := s.accountRepo.Update(ctx, account); err != nil {
+ slog.Warn("oauth_401_force_refresh_update_failed", "account_id", account.ID, "error", err)
+ } else {
+ slog.Info("oauth_401_force_refresh_set", "account_id", account.ID, "platform", account.Platform)
+ }
+ }
msg := "Authentication failed (401): invalid or expired credentials"
if upstreamMsg != "" {
msg = "Authentication failed (401): " + upstreamMsg
@@ -100,7 +127,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
s.handleAuthError(ctx, account, msg)
shouldDisable = true
case 429:
- s.handle429(ctx, account, headers)
+ s.handle429(ctx, account, headers, responseBody)
shouldDisable = false
case 529:
s.handle529(ctx, account)
@@ -116,7 +143,7 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
shouldDisable = true
} else if statusCode >= 500 {
// 未启用自定义错误码时:仅记录5xx错误
- log.Printf("Account %d received upstream error %d", account.ID, statusCode)
+ slog.Warn("account_upstream_error", "account_id", account.ID, "status_code", statusCode)
shouldDisable = false
}
}
@@ -163,7 +190,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -188,7 +215,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
- log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
+ slog.Info("gemini_precheck_daily_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -210,7 +237,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if limit > 0 {
start := now.Truncate(time.Minute)
- stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
+ stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil)
if err != nil {
return true, err
}
@@ -231,7 +258,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
- log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
+ slog.Info("gemini_precheck_minute_quota_reached", "account_id", account.ID, "used", used, "limit", limit, "reset_at", resetAt)
return false, nil
}
}
@@ -288,32 +315,40 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
- log.Printf("SetError failed for account %d: %v", account.ID, err)
+ slog.Warn("account_set_error_failed", "account_id", account.ID, "error", err)
return
}
- log.Printf("Account %d disabled due to auth error: %s", account.ID, errorMsg)
+ slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg)
}
// handleCustomErrorCode 处理自定义错误码,停止账号调度
func (s *RateLimitService) handleCustomErrorCode(ctx context.Context, account *Account, statusCode int, errorMsg string) {
msg := "Custom error code " + strconv.Itoa(statusCode) + ": " + errorMsg
if err := s.accountRepo.SetError(ctx, account.ID, msg); err != nil {
- log.Printf("SetError failed for account %d: %v", account.ID, err)
+ slog.Warn("account_set_error_failed", "account_id", account.ID, "status_code", statusCode, "error", err)
return
}
- log.Printf("Account %d disabled due to custom error code %d: %s", account.ID, statusCode, errorMsg)
+ slog.Warn("account_disabled_custom_error", "account_id", account.ID, "status_code", statusCode, "error", errorMsg)
}
// handle429 处理429限流错误
// 解析响应头获取重置时间,标记账号为限流状态
-func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header) {
+func (s *RateLimitService) handle429(ctx context.Context, account *Account, headers http.Header, responseBody []byte) {
// 解析重置时间戳
resetTimestamp := headers.Get("anthropic-ratelimit-unified-reset")
if resetTimestamp == "" {
// 没有重置时间,使用默认5分钟
resetAt := time.Now().Add(5 * time.Minute)
+ if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
+ if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
+ slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
+ } else {
+ slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
+ }
+ return
+ }
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
- log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
+ slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
@@ -321,19 +356,36 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
// 解析Unix时间戳
ts, err := strconv.ParseInt(resetTimestamp, 10, 64)
if err != nil {
- log.Printf("Parse reset timestamp failed: %v", err)
+ slog.Warn("rate_limit_reset_parse_failed", "reset_timestamp", resetTimestamp, "error", err)
resetAt := time.Now().Add(5 * time.Minute)
+ if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
+ if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
+ slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
+ } else {
+ slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
+ }
+ return
+ }
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
- log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
+ slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
}
return
}
resetAt := time.Unix(ts, 0)
+ if s.shouldScopeClaudeSonnetRateLimit(account, responseBody) {
+ if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelRateLimitScopeClaudeSonnet, resetAt); err != nil {
+ slog.Warn("model_rate_limit_set_failed", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "error", err)
+ return
+ }
+ slog.Info("account_model_rate_limited", "account_id", account.ID, "scope", modelRateLimitScopeClaudeSonnet, "reset_at", resetAt)
+ return
+ }
+
// 标记限流状态
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
- log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
+ slog.Warn("rate_limit_set_failed", "account_id", account.ID, "error", err)
return
}
@@ -341,10 +393,21 @@ func (s *RateLimitService) handle429(ctx context.Context, account *Account, head
windowEnd := resetAt
windowStart := resetAt.Add(-5 * time.Hour)
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, &windowStart, &windowEnd, "rejected"); err != nil {
- log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
+ slog.Warn("rate_limit_update_session_window_failed", "account_id", account.ID, "error", err)
}
- log.Printf("Account %d rate limited until %v", account.ID, resetAt)
+ slog.Info("account_rate_limited", "account_id", account.ID, "reset_at", resetAt)
+}
+
+func (s *RateLimitService) shouldScopeClaudeSonnetRateLimit(account *Account, responseBody []byte) bool {
+ if account == nil || account.Platform != PlatformAnthropic {
+ return false
+ }
+ msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(responseBody)))
+ if msg == "" {
+ return false
+ }
+ return strings.Contains(msg, "sonnet")
}
// handle529 处理529过载错误
@@ -357,11 +420,11 @@ func (s *RateLimitService) handle529(ctx context.Context, account *Account) {
until := time.Now().Add(time.Duration(cooldownMinutes) * time.Minute)
if err := s.accountRepo.SetOverloaded(ctx, account.ID, until); err != nil {
- log.Printf("SetOverloaded failed for account %d: %v", account.ID, err)
+ slog.Warn("overload_set_failed", "account_id", account.ID, "error", err)
return
}
- log.Printf("Account %d overloaded until %v", account.ID, until)
+ slog.Info("account_overloaded", "account_id", account.ID, "until", until)
}
// UpdateSessionWindow 从成功响应更新5h窗口状态
@@ -384,17 +447,17 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
end := start.Add(5 * time.Hour)
windowStart = &start
windowEnd = &end
- log.Printf("Account %d: initializing 5h window from %v to %v (status: %s)", account.ID, start, end, status)
+ slog.Info("account_session_window_initialized", "account_id", account.ID, "window_start", start, "window_end", end, "status", status)
}
if err := s.accountRepo.UpdateSessionWindow(ctx, account.ID, windowStart, windowEnd, status); err != nil {
- log.Printf("UpdateSessionWindow failed for account %d: %v", account.ID, err)
+ slog.Warn("session_window_update_failed", "account_id", account.ID, "error", err)
}
// 如果状态为allowed且之前有限流,说明窗口已重置,清除限流状态
if status == "allowed" && account.IsRateLimited() {
if err := s.ClearRateLimit(ctx, account.ID); err != nil {
- log.Printf("ClearRateLimit failed for account %d: %v", account.ID, err)
+ slog.Warn("rate_limit_clear_failed", "account_id", account.ID, "error", err)
}
}
}
@@ -404,7 +467,10 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64)
if err := s.accountRepo.ClearRateLimit(ctx, accountID); err != nil {
return err
}
- return s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID)
+ if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil {
+ return err
+ }
+ return s.accountRepo.ClearModelRateLimits(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
@@ -413,7 +479,7 @@ func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
- log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
+ slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err)
}
}
return nil
@@ -460,7 +526,7 @@ func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID i
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
- log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
+ slog.Warn("temp_unsched_cache_set_failed", "account_id", accountID, "error", err)
}
}
@@ -563,17 +629,17 @@ func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
- log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
+ slog.Warn("temp_unsched_set_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
- log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
+ slog.Warn("temp_unsched_cache_set_failed", "account_id", account.ID, "error", err)
}
}
- log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
+ slog.Info("account_temp_unschedulable", "account_id", account.ID, "until", until, "rule_index", ruleIndex, "status_code", statusCode)
return true
}
@@ -597,13 +663,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
// 获取系统设置
if s.settingService == nil {
- log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
+ slog.Warn("stream_timeout_setting_service_missing", "account_id", account.ID)
return false
}
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil {
- log.Printf("[StreamTimeout] Failed to get settings: %v", err)
+ slog.Warn("stream_timeout_get_settings_failed", "account_id", account.ID, "error", err)
return false
}
@@ -620,14 +686,13 @@ func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Acc
if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil {
- log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_increment_count_failed", "account_id", account.ID, "error", err)
// 继续处理,使用 count=1
count = 1
}
}
- log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
- account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
+ slog.Info("stream_timeout_count", "account_id", account.ID, "count", count, "threshold", settings.ThresholdCount, "window_minutes", settings.ThresholdWindowMinutes, "model", model)
// 检查是否达到阈值
if count < int64(settings.ThresholdCount) {
@@ -668,24 +733,24 @@ func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context,
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
- log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_set_temp_unsched_failed", "account_id", account.ID, "error", err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
- log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_set_temp_unsched_cache_failed", "account_id", account.ID, "error", err)
}
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
- log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
- log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
+ slog.Info("stream_timeout_temp_unschedulable", "account_id", account.ID, "until", until, "model", model)
return true
}
@@ -694,17 +759,17 @@ func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, accoun
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
- log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_set_error_failed", "account_id", account.ID, "error", err)
return false
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
- log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
+ slog.Warn("stream_timeout_reset_count_failed", "account_id", account.ID, "error", err)
}
}
- log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
+ slog.Warn("stream_timeout_account_error", "account_id", account.ID, "model", model)
return true
}
diff --git a/backend/internal/service/ratelimit_service_401_test.go b/backend/internal/service/ratelimit_service_401_test.go
new file mode 100644
index 00000000..36357a4b
--- /dev/null
+++ b/backend/internal/service/ratelimit_service_401_test.go
@@ -0,0 +1,121 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type rateLimitAccountRepoStub struct {
+ mockAccountRepoForGemini
+ setErrorCalls int
+ tempCalls int
+ lastErrorMsg string
+}
+
+func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
+ r.setErrorCalls++
+ r.lastErrorMsg = errorMsg
+ return nil
+}
+
+func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
+ r.tempCalls++
+ return nil
+}
+
+type tokenCacheInvalidatorRecorder struct {
+ accounts []*Account
+ err error
+}
+
+func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error {
+ r.accounts = append(r.accounts, account)
+ return r.err
+}
+
+func TestRateLimitService_HandleUpstreamError_OAuth401MarksError(t *testing.T) {
+ tests := []struct {
+ name string
+ platform string
+ }{
+ {name: "gemini", platform: PlatformGemini},
+ {name: "antigravity", platform: PlatformAntigravity},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ invalidator := &tokenCacheInvalidatorRecorder{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetTokenCacheInvalidator(invalidator)
+ account := &Account{
+ ID: 100,
+ Platform: tt.platform,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "temp_unschedulable_enabled": true,
+ "temp_unschedulable_rules": []any{
+ map[string]any{
+ "error_code": 401,
+ "keywords": []any{"unauthorized"},
+ "duration_minutes": 30,
+ "description": "custom rule",
+ },
+ },
+ },
+ }
+
+ shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Equal(t, 0, repo.tempCalls)
+ require.Contains(t, repo.lastErrorMsg, "Authentication failed (401)")
+ require.Len(t, invalidator.accounts, 1)
+ })
+ }
+}
+
+func TestRateLimitService_HandleUpstreamError_OAuth401InvalidatorError(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ invalidator := &tokenCacheInvalidatorRecorder{err: errors.New("boom")}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetTokenCacheInvalidator(invalidator)
+ account := &Account{
+ ID: 101,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+
+ shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Len(t, invalidator.accounts, 1)
+}
+
+func TestRateLimitService_HandleUpstreamError_NonOAuth401(t *testing.T) {
+ repo := &rateLimitAccountRepoStub{}
+ invalidator := &tokenCacheInvalidatorRecorder{}
+ service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil)
+ service.SetTokenCacheInvalidator(invalidator)
+ account := &Account{
+ ID: 102,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ }
+
+ shouldDisable := service.HandleUpstreamError(context.Background(), account, 401, http.Header{}, []byte("unauthorized"))
+
+ require.True(t, shouldDisable)
+ require.Equal(t, 1, repo.setErrorCalls)
+ require.Empty(t, invalidator.accounts)
+}
diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go
new file mode 100644
index 00000000..f6f0c26a
--- /dev/null
+++ b/backend/internal/service/session_limit_cache.go
@@ -0,0 +1,63 @@
+package service
+
+import (
+ "context"
+ "time"
+)
+
+// SessionLimitCache 管理账号级别的活跃会话跟踪
+// 用于 Anthropic OAuth/SetupToken 账号的会话数量限制
+//
+// Key 格式: session_limit:account:{accountID}
+// 数据结构: Sorted Set (member=sessionUUID, score=timestamp)
+//
+// 会话在空闲超时后自动过期,无需手动清理
+type SessionLimitCache interface {
+ // RegisterSession 注册会话活动
+ // - 如果会话已存在,刷新其时间戳并返回 true
+ // - 如果会话不存在且活跃会话数 < maxSessions,添加新会话并返回 true
+ // - 如果会话不存在且活跃会话数 >= maxSessions,返回 false(拒绝)
+ //
+ // 参数:
+ // accountID: 账号 ID
+ // sessionUUID: 从 metadata.user_id 中提取的会话 UUID
+ // maxSessions: 最大并发会话数限制
+ // idleTimeout: 会话空闲超时时间
+ //
+ // 返回:
+ // allowed: true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话)
+ // error: 操作错误
+ RegisterSession(ctx context.Context, accountID int64, sessionUUID string, maxSessions int, idleTimeout time.Duration) (allowed bool, err error)
+
+ // RefreshSession 刷新现有会话的时间戳
+ // 用于活跃会话保持活动状态
+ RefreshSession(ctx context.Context, accountID int64, sessionUUID string, idleTimeout time.Duration) error
+
+ // GetActiveSessionCount 获取当前活跃会话数
+ // 返回未过期的会话数量
+ GetActiveSessionCount(ctx context.Context, accountID int64) (int, error)
+
+ // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数
+ // 返回 map[accountID]count,查询失败的账号不在 map 中
+ GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
+
+ // IsSessionActive 检查特定会话是否活跃(未过期)
+ IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error)
+
+ // ========== 5h窗口费用缓存 ==========
+ // Key 格式: window_cost:account:{accountID}
+ // 用于缓存账号在当前5h窗口内的标准费用,减少数据库聚合查询压力
+
+ // GetWindowCost 获取缓存的窗口费用
+ // 返回 (cost, true, nil) 如果缓存命中
+ // 返回 (0, false, nil) 如果缓存未命中
+ // 返回 (0, false, err) 如果发生错误
+ GetWindowCost(ctx context.Context, accountID int64) (cost float64, hit bool, err error)
+
+ // SetWindowCost 设置窗口费用缓存
+ SetWindowCost(ctx context.Context, accountID int64, cost float64) error
+
+ // GetWindowCostBatch 批量获取窗口费用缓存
+ // 返回 map[accountID]cost,缓存未命中的账号不在 map 中
+ GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error)
+}
diff --git a/backend/internal/service/timing_wheel_service.go b/backend/internal/service/timing_wheel_service.go
index c4e64e33..5a2dea75 100644
--- a/backend/internal/service/timing_wheel_service.go
+++ b/backend/internal/service/timing_wheel_service.go
@@ -1,6 +1,7 @@
package service
import (
+ "fmt"
"log"
"sync"
"time"
@@ -8,6 +9,8 @@ import (
"github.com/zeromicro/go-zero/core/collection"
)
+var newTimingWheel = collection.NewTimingWheel
+
// TimingWheelService wraps go-zero's TimingWheel for task scheduling
type TimingWheelService struct {
tw *collection.TimingWheel
@@ -15,18 +18,18 @@ type TimingWheelService struct {
}
// NewTimingWheelService creates a new TimingWheelService instance
-func NewTimingWheelService() *TimingWheelService {
+func NewTimingWheelService() (*TimingWheelService, error) {
// 1 second tick, 3600 slots = supports up to 1 hour delay
// execute function: runs func() type tasks
- tw, err := collection.NewTimingWheel(1*time.Second, 3600, func(key, value any) {
+ tw, err := newTimingWheel(1*time.Second, 3600, func(key, value any) {
if fn, ok := value.(func()); ok {
fn()
}
})
if err != nil {
- panic(err)
+ return nil, fmt.Errorf("创建 timing wheel 失败: %w", err)
}
- return &TimingWheelService{tw: tw}
+ return &TimingWheelService{tw: tw}, nil
}
// Start starts the timing wheel
diff --git a/backend/internal/service/timing_wheel_service_test.go b/backend/internal/service/timing_wheel_service_test.go
new file mode 100644
index 00000000..cd0bffb7
--- /dev/null
+++ b/backend/internal/service/timing_wheel_service_test.go
@@ -0,0 +1,146 @@
+package service
+
+import (
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/zeromicro/go-zero/core/collection"
+)
+
+func TestNewTimingWheelService_InitFail_NoPanicAndReturnError(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
+ return nil, errors.New("boom")
+ }
+
+ svc, err := NewTimingWheelService()
+ if err == nil {
+ t.Fatalf("期望返回 error,但得到 nil")
+ }
+ if svc != nil {
+ t.Fatalf("期望返回 nil svc,但得到非空")
+ }
+}
+
+func TestNewTimingWheelService_Success(t *testing.T) {
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if svc == nil {
+ t.Fatalf("期望 svc 非空,但得到 nil")
+ }
+ svc.Stop()
+}
+
+func TestNewTimingWheelService_ExecuteCallbackRunsFunc(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ var captured collection.Execute
+ newTimingWheel = func(interval time.Duration, numSlots int, execute collection.Execute) (*collection.TimingWheel, error) {
+ captured = execute
+ return original(interval, numSlots, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if captured == nil {
+ t.Fatalf("期望 captured 非空,但得到 nil")
+ }
+
+ called := false
+ captured("k", func() { called = true })
+ if !called {
+ t.Fatalf("期望 execute 回调触发传入函数执行")
+ }
+
+ svc.Stop()
+}
+
+func TestTimingWheelService_Schedule_ExecutesOnce(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ ch := make(chan struct{}, 1)
+ svc.Schedule("once", 30*time.Millisecond, func() { ch <- struct{}{} })
+
+ select {
+ case <-ch:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatalf("等待任务执行超时")
+ }
+
+ select {
+ case <-ch:
+ t.Fatalf("任务不应重复执行")
+ case <-time.After(80 * time.Millisecond):
+ }
+}
+
+func TestTimingWheelService_Cancel_PreventsExecution(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ ch := make(chan struct{}, 1)
+ svc.Schedule("cancel", 80*time.Millisecond, func() { ch <- struct{}{} })
+ svc.Cancel("cancel")
+
+ select {
+ case <-ch:
+ t.Fatalf("任务已取消,不应执行")
+ case <-time.After(200 * time.Millisecond):
+ }
+}
+
+func TestTimingWheelService_ScheduleRecurring_ExecutesMultipleTimes(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, execute collection.Execute) (*collection.TimingWheel, error) {
+ return original(10*time.Millisecond, 128, execute)
+ }
+
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ defer svc.Stop()
+
+ var count int32
+ svc.ScheduleRecurring("rec", 30*time.Millisecond, func() { atomic.AddInt32(&count, 1) })
+
+ deadline := time.Now().Add(500 * time.Millisecond)
+ for atomic.LoadInt32(&count) < 2 && time.Now().Before(deadline) {
+ time.Sleep(10 * time.Millisecond)
+ }
+ if atomic.LoadInt32(&count) < 2 {
+ t.Fatalf("期望周期任务至少执行 2 次,但只执行了 %d 次", atomic.LoadInt32(&count))
+ }
+}
diff --git a/backend/internal/service/token_cache_invalidator.go b/backend/internal/service/token_cache_invalidator.go
new file mode 100644
index 00000000..1117d2f1
--- /dev/null
+++ b/backend/internal/service/token_cache_invalidator.go
@@ -0,0 +1,41 @@
+package service
+
+import "context"
+
+type TokenCacheInvalidator interface {
+ InvalidateToken(ctx context.Context, account *Account) error
+}
+
+type CompositeTokenCacheInvalidator struct {
+ cache GeminiTokenCache // 统一使用一个缓存接口,通过缓存键前缀区分平台
+}
+
+func NewCompositeTokenCacheInvalidator(cache GeminiTokenCache) *CompositeTokenCacheInvalidator {
+ return &CompositeTokenCacheInvalidator{
+ cache: cache,
+ }
+}
+
+func (c *CompositeTokenCacheInvalidator) InvalidateToken(ctx context.Context, account *Account) error {
+ if c == nil || c.cache == nil || account == nil {
+ return nil
+ }
+ if account.Type != AccountTypeOAuth {
+ return nil
+ }
+
+ var cacheKey string
+ switch account.Platform {
+ case PlatformGemini:
+ cacheKey = GeminiTokenCacheKey(account)
+ case PlatformAntigravity:
+ cacheKey = AntigravityTokenCacheKey(account)
+ case PlatformOpenAI:
+ cacheKey = OpenAITokenCacheKey(account)
+ case PlatformAnthropic:
+ cacheKey = ClaudeTokenCacheKey(account)
+ default:
+ return nil
+ }
+ return c.cache.DeleteAccessToken(ctx, cacheKey)
+}
diff --git a/backend/internal/service/token_cache_invalidator_test.go b/backend/internal/service/token_cache_invalidator_test.go
new file mode 100644
index 00000000..30d208ce
--- /dev/null
+++ b/backend/internal/service/token_cache_invalidator_test.go
@@ -0,0 +1,268 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/require"
+)
+
+type geminiTokenCacheStub struct {
+ deletedKeys []string
+ deleteErr error
+}
+
+func (s *geminiTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
+ return "", nil
+}
+
+func (s *geminiTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
+ return nil
+}
+
+func (s *geminiTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
+ s.deletedKeys = append(s.deletedKeys, cacheKey)
+ return s.deleteErr
+}
+
+func (s *geminiTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
+ return true, nil
+}
+
+func (s *geminiTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
+ return nil
+}
+
+func TestCompositeTokenCacheInvalidator_Gemini(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+ account := &Account{
+ ID: 10,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "project_id": "project-x",
+ },
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, []string{"gemini:project-x"}, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_Antigravity(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+ account := &Account{
+ ID: 99,
+ Platform: PlatformAntigravity,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "project_id": "ag-project",
+ },
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, []string{"ag:ag-project"}, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_OpenAI(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+ account := &Account{
+ ID: 500,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "openai-token",
+ },
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, []string{"openai:account:500"}, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_Claude(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+ account := &Account{
+ ID: 600,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "claude-token",
+ },
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, []string{"claude:account:600"}, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_SkipNonOAuth(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+
+ tests := []struct {
+ name string
+ account *Account
+ }{
+ {
+ name: "gemini_api_key",
+ account: &Account{
+ ID: 1,
+ Platform: PlatformGemini,
+ Type: AccountTypeAPIKey,
+ },
+ },
+ {
+ name: "openai_api_key",
+ account: &Account{
+ ID: 2,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ },
+ },
+ {
+ name: "claude_api_key",
+ account: &Account{
+ ID: 3,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeAPIKey,
+ },
+ },
+ {
+ name: "claude_setup_token",
+ account: &Account{
+ ID: 4,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeSetupToken,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ cache.deletedKeys = nil
+ err := invalidator.InvalidateToken(context.Background(), tt.account)
+ require.NoError(t, err)
+ require.Empty(t, cache.deletedKeys)
+ })
+ }
+}
+
+func TestCompositeTokenCacheInvalidator_SkipUnsupportedPlatform(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+ account := &Account{
+ ID: 100,
+ Platform: "unknown-platform",
+ Type: AccountTypeOAuth,
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Empty(t, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_NilCache(t *testing.T) {
+ invalidator := NewCompositeTokenCacheInvalidator(nil)
+ account := &Account{
+ ID: 2,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+}
+
+func TestCompositeTokenCacheInvalidator_NilAccount(t *testing.T) {
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+
+ err := invalidator.InvalidateToken(context.Background(), nil)
+ require.NoError(t, err)
+ require.Empty(t, cache.deletedKeys)
+}
+
+func TestCompositeTokenCacheInvalidator_NilInvalidator(t *testing.T) {
+ var invalidator *CompositeTokenCacheInvalidator
+ account := &Account{
+ ID: 5,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+
+ err := invalidator.InvalidateToken(context.Background(), account)
+ require.NoError(t, err)
+}
+
+func TestCompositeTokenCacheInvalidator_DeleteError(t *testing.T) {
+ expectedErr := errors.New("redis connection failed")
+ cache := &geminiTokenCacheStub{deleteErr: expectedErr}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+
+ tests := []struct {
+ name string
+ account *Account
+ }{
+ {
+ name: "openai_delete_error",
+ account: &Account{
+ ID: 700,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ },
+ },
+ {
+ name: "claude_delete_error",
+ account: &Account{
+ ID: 800,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := invalidator.InvalidateToken(context.Background(), tt.account)
+ require.Error(t, err)
+ require.Equal(t, expectedErr, err)
+ })
+ }
+}
+
+func TestCompositeTokenCacheInvalidator_AllPlatformsIntegration(t *testing.T) {
+ // 测试所有平台的缓存键生成和删除
+ cache := &geminiTokenCacheStub{}
+ invalidator := NewCompositeTokenCacheInvalidator(cache)
+
+ accounts := []*Account{
+ {ID: 1, Platform: PlatformGemini, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "gemini-proj"}},
+ {ID: 2, Platform: PlatformAntigravity, Type: AccountTypeOAuth, Credentials: map[string]any{"project_id": "ag-proj"}},
+ {ID: 3, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 4, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
+ }
+
+ expectedKeys := []string{
+ "gemini:gemini-proj",
+ "ag:ag-proj",
+ "openai:account:3",
+ "claude:account:4",
+ }
+
+ for _, acc := range accounts {
+ err := invalidator.InvalidateToken(context.Background(), acc)
+ require.NoError(t, err)
+ }
+
+ require.Equal(t, expectedKeys, cache.deletedKeys)
+}
diff --git a/backend/internal/service/token_cache_key.go b/backend/internal/service/token_cache_key.go
new file mode 100644
index 00000000..df0c025e
--- /dev/null
+++ b/backend/internal/service/token_cache_key.go
@@ -0,0 +1,15 @@
+package service
+
+import "strconv"
+
+// OpenAITokenCacheKey builds the access-token cache key for an OpenAI OAuth account.
+// Format: "openai:account:{account_id}"
+func OpenAITokenCacheKey(account *Account) string {
+	return "openai:account:" + strconv.FormatInt(account.ID, 10)
+}
+
+// ClaudeTokenCacheKey builds the access-token cache key for a Claude (Anthropic) OAuth account.
+// Format: "claude:account:{account_id}"
+func ClaudeTokenCacheKey(account *Account) string {
+	return "claude:account:" + strconv.FormatInt(account.ID, 10)
+}
diff --git a/backend/internal/service/token_cache_key_test.go b/backend/internal/service/token_cache_key_test.go
new file mode 100644
index 00000000..6215eeaf
--- /dev/null
+++ b/backend/internal/service/token_cache_key_test.go
@@ -0,0 +1,259 @@
+//go:build unit
+
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestGeminiTokenCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ expected string
+ }{
+ {
+ name: "with_project_id",
+ account: &Account{
+ ID: 100,
+ Credentials: map[string]any{
+ "project_id": "my-project-123",
+ },
+ },
+ expected: "gemini:my-project-123",
+ },
+ {
+ name: "project_id_with_whitespace",
+ account: &Account{
+ ID: 101,
+ Credentials: map[string]any{
+ "project_id": " project-with-spaces ",
+ },
+ },
+ expected: "gemini:project-with-spaces",
+ },
+ {
+ name: "empty_project_id_fallback_to_account_id",
+ account: &Account{
+ ID: 102,
+ Credentials: map[string]any{
+ "project_id": "",
+ },
+ },
+ expected: "gemini:account:102",
+ },
+ {
+ name: "whitespace_only_project_id_fallback_to_account_id",
+ account: &Account{
+ ID: 103,
+ Credentials: map[string]any{
+ "project_id": " ",
+ },
+ },
+ expected: "gemini:account:103",
+ },
+ {
+ name: "no_project_id_key_fallback_to_account_id",
+ account: &Account{
+ ID: 104,
+ Credentials: map[string]any{},
+ },
+ expected: "gemini:account:104",
+ },
+ {
+ name: "nil_credentials_fallback_to_account_id",
+ account: &Account{
+ ID: 105,
+ Credentials: nil,
+ },
+ expected: "gemini:account:105",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := GeminiTokenCacheKey(tt.account)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestAntigravityTokenCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ expected string
+ }{
+ {
+ name: "with_project_id",
+ account: &Account{
+ ID: 200,
+ Credentials: map[string]any{
+ "project_id": "ag-project-456",
+ },
+ },
+ expected: "ag:ag-project-456",
+ },
+ {
+ name: "project_id_with_whitespace",
+ account: &Account{
+ ID: 201,
+ Credentials: map[string]any{
+ "project_id": " ag-project-spaces ",
+ },
+ },
+ expected: "ag:ag-project-spaces",
+ },
+ {
+ name: "empty_project_id_fallback_to_account_id",
+ account: &Account{
+ ID: 202,
+ Credentials: map[string]any{
+ "project_id": "",
+ },
+ },
+ expected: "ag:account:202",
+ },
+ {
+ name: "whitespace_only_project_id_fallback_to_account_id",
+ account: &Account{
+ ID: 203,
+ Credentials: map[string]any{
+ "project_id": " ",
+ },
+ },
+ expected: "ag:account:203",
+ },
+ {
+ name: "no_project_id_key_fallback_to_account_id",
+ account: &Account{
+ ID: 204,
+ Credentials: map[string]any{},
+ },
+ expected: "ag:account:204",
+ },
+ {
+ name: "nil_credentials_fallback_to_account_id",
+ account: &Account{
+ ID: 205,
+ Credentials: nil,
+ },
+ expected: "ag:account:205",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := AntigravityTokenCacheKey(tt.account)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestOpenAITokenCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ expected string
+ }{
+ {
+ name: "basic_account",
+ account: &Account{
+ ID: 300,
+ },
+ expected: "openai:account:300",
+ },
+ {
+ name: "account_with_credentials",
+ account: &Account{
+ ID: 301,
+ Credentials: map[string]any{
+ "access_token": "test-token",
+ },
+ },
+ expected: "openai:account:301",
+ },
+ {
+ name: "account_id_zero",
+ account: &Account{
+ ID: 0,
+ },
+ expected: "openai:account:0",
+ },
+ {
+ name: "large_account_id",
+ account: &Account{
+ ID: 9999999999,
+ },
+ expected: "openai:account:9999999999",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := OpenAITokenCacheKey(tt.account)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestClaudeTokenCacheKey(t *testing.T) {
+ tests := []struct {
+ name string
+ account *Account
+ expected string
+ }{
+ {
+ name: "basic_account",
+ account: &Account{
+ ID: 400,
+ },
+ expected: "claude:account:400",
+ },
+ {
+ name: "account_with_credentials",
+ account: &Account{
+ ID: 401,
+ Credentials: map[string]any{
+ "access_token": "claude-token",
+ },
+ },
+ expected: "claude:account:401",
+ },
+ {
+ name: "account_id_zero",
+ account: &Account{
+ ID: 0,
+ },
+ expected: "claude:account:0",
+ },
+ {
+ name: "large_account_id",
+ account: &Account{
+ ID: 9999999999,
+ },
+ expected: "claude:account:9999999999",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := ClaudeTokenCacheKey(tt.account)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestCacheKeyUniqueness(t *testing.T) {
+ // 确保不同平台的缓存键不会冲突
+ account := &Account{ID: 123}
+
+ openaiKey := OpenAITokenCacheKey(account)
+ claudeKey := ClaudeTokenCacheKey(account)
+
+ require.NotEqual(t, openaiKey, claudeKey, "OpenAI and Claude cache keys should be different")
+ require.Contains(t, openaiKey, "openai:")
+ require.Contains(t, claudeKey, "claude:")
+}
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index 3ed35f04..26cfd97d 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -14,9 +14,10 @@ import (
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
- accountRepo AccountRepository
- refreshers []TokenRefresher
- cfg *config.TokenRefreshConfig
+ accountRepo AccountRepository
+ refreshers []TokenRefresher
+ cfg *config.TokenRefreshConfig
+ cacheInvalidator TokenCacheInvalidator
stopCh chan struct{}
wg sync.WaitGroup
@@ -29,12 +30,14 @@ func NewTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
+ cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
s := &TokenRefreshService{
- accountRepo: accountRepo,
- cfg: &cfg.TokenRefresh,
- stopCh: make(chan struct{}),
+ accountRepo: accountRepo,
+ cfg: &cfg.TokenRefresh,
+ cacheInvalidator: cacheInvalidator,
+ stopCh: make(chan struct{}),
}
// 注册平台特定的刷新器
@@ -169,6 +172,14 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("failed to save credentials: %w", err)
}
+	// Invalidate the token cache for every OAuth account (InvalidateToken itself decides per platform whether any work is needed)
+ if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth {
+ if err := s.cacheInvalidator.InvalidateToken(ctx, account); err != nil {
+ log.Printf("[TokenRefresh] Failed to invalidate token cache for account %d: %v", account.ID, err)
+ } else {
+ log.Printf("[TokenRefresh] Token cache invalidated for account %d", account.ID)
+ }
+ }
return nil
}
diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go
new file mode 100644
index 00000000..d23a0bb6
--- /dev/null
+++ b/backend/internal/service/token_refresh_service_test.go
@@ -0,0 +1,361 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type tokenRefreshAccountRepo struct {
+ mockAccountRepoForGemini
+ updateCalls int
+ setErrorCalls int
+ lastAccount *Account
+ updateErr error
+}
+
+func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
+ r.updateCalls++
+ r.lastAccount = account
+ return r.updateErr
+}
+
+func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
+ r.setErrorCalls++
+ return nil
+}
+
+type tokenCacheInvalidatorStub struct {
+ calls int
+ err error
+}
+
+func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account *Account) error {
+ s.calls++
+ return s.err
+}
+
+type tokenRefresherStub struct {
+ credentials map[string]any
+ err error
+}
+
+func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
+ return true
+}
+
+func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuration time.Duration) bool {
+ return true
+}
+
+func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ if r.err != nil {
+ return nil, r.err
+ }
+ return r.credentials, nil
+}
+
+func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 5,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "new-token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 1, invalidator.calls)
+ require.Equal(t, "new-token", account.GetCredential("access_token"))
+}
+
+func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 6,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 1, invalidator.calls)
+}
+
+func TestTokenRefreshService_RefreshWithRetry_NilInvalidator(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, cfg)
+ account := &Account{
+ ID: 7,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+}
+
+// TestTokenRefreshService_RefreshWithRetry_Antigravity 测试 Antigravity 平台的缓存失效
+func TestTokenRefreshService_RefreshWithRetry_Antigravity(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 8,
+ Platform: PlatformAntigravity,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "ag-token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 1, invalidator.calls) // Antigravity 也应触发缓存失效
+}
+
+// TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount 测试非 OAuth 账号不触发缓存失效
+func TestTokenRefreshService_RefreshWithRetry_NonOAuthAccount(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 9,
+ Platform: PlatformGemini,
+ Type: AccountTypeAPIKey, // 非 OAuth
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 0, invalidator.calls) // 非 OAuth 不触发缓存失效
+}
+
+// TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth 测试所有 OAuth 平台都触发缓存失效
+func TestTokenRefreshService_RefreshWithRetry_OtherPlatformOAuth(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 10,
+ Platform: PlatformOpenAI, // OpenAI OAuth 账户
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.NoError(t, err)
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 1, invalidator.calls) // 所有 OAuth 账户刷新后触发缓存失效
+}
+
+// TestTokenRefreshService_RefreshWithRetry_UpdateFailed 测试更新失败的情况
+func TestTokenRefreshService_RefreshWithRetry_UpdateFailed(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{updateErr: errors.New("update failed")}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 11,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "failed to save credentials")
+ require.Equal(t, 1, repo.updateCalls)
+ require.Equal(t, 0, invalidator.calls) // 更新失败时不应触发缓存失效
+}
+
+// TestTokenRefreshService_RefreshWithRetry_RefreshFailed 测试刷新失败的情况
+func TestTokenRefreshService_RefreshWithRetry_RefreshFailed(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 2,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 12,
+ Platform: PlatformGemini,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ err: errors.New("refresh failed"),
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.Error(t, err)
+ require.Equal(t, 0, repo.updateCalls) // 刷新失败不应更新
+ require.Equal(t, 0, invalidator.calls) // 刷新失败不应触发缓存失效
+ require.Equal(t, 1, repo.setErrorCalls) // 应设置错误状态
+}
+
+// TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed 测试 Antigravity 刷新失败不设置错误状态
+func TestTokenRefreshService_RefreshWithRetry_AntigravityRefreshFailed(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 13,
+ Platform: PlatformAntigravity,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ err: errors.New("network error"), // 可重试错误
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.Error(t, err)
+ require.Equal(t, 0, repo.updateCalls)
+ require.Equal(t, 0, invalidator.calls)
+ require.Equal(t, 0, repo.setErrorCalls) // Antigravity 可重试错误不设置错误状态
+}
+
+// TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError 测试 Antigravity 不可重试错误
+func TestTokenRefreshService_RefreshWithRetry_AntigravityNonRetryableError(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ invalidator := &tokenCacheInvalidatorStub{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 3,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, invalidator, cfg)
+ account := &Account{
+ ID: 14,
+ Platform: PlatformAntigravity,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ err: errors.New("invalid_grant: token revoked"), // 不可重试错误
+ }
+
+ err := service.refreshWithRetry(context.Background(), account, refresher)
+ require.Error(t, err)
+ require.Equal(t, 0, repo.updateCalls)
+ require.Equal(t, 0, invalidator.calls)
+ require.Equal(t, 1, repo.setErrorCalls) // 不可重试错误应设置错误状态
+}
+
+// TestIsNonRetryableRefreshError 测试不可重试错误判断
+func TestIsNonRetryableRefreshError(t *testing.T) {
+ tests := []struct {
+ name string
+ err error
+ expected bool
+ }{
+ {name: "nil_error", err: nil, expected: false},
+ {name: "network_error", err: errors.New("network timeout"), expected: false},
+ {name: "invalid_grant", err: errors.New("invalid_grant"), expected: true},
+ {name: "invalid_client", err: errors.New("invalid_client"), expected: true},
+ {name: "unauthorized_client", err: errors.New("unauthorized_client"), expected: true},
+ {name: "access_denied", err: errors.New("access_denied"), expected: true},
+ {name: "invalid_grant_with_desc", err: errors.New("Error: invalid_grant - token revoked"), expected: true},
+ {name: "case_insensitive", err: errors.New("INVALID_GRANT"), expected: true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := isNonRetryableRefreshError(tt.err)
+ require.Equal(t, tt.expected, result)
+ })
+ }
+}
diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go
index 62d7fae0..3b0e934f 100644
--- a/backend/internal/service/usage_log.go
+++ b/backend/internal/service/usage_log.go
@@ -33,6 +33,8 @@ type UsageLog struct {
TotalCost float64
ActualCost float64
RateMultiplier float64
+ // AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理)
+ AccountRateMultiplier *float64
BillingType int8
Stream bool
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index 5326bace..acc0a5fb 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -42,9 +42,10 @@ func ProvideTokenRefreshService(
openaiOAuthService *OpenAIOAuthService,
geminiOAuthService *GeminiOAuthService,
antigravityOAuthService *AntigravityOAuthService,
+ cacheInvalidator TokenCacheInvalidator,
cfg *config.Config,
) *TokenRefreshService {
- svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cfg)
+ svc := NewTokenRefreshService(accountRepo, oauthService, openaiOAuthService, geminiOAuthService, antigravityOAuthService, cacheInvalidator, cfg)
svc.Start()
return svc
}
@@ -64,10 +65,13 @@ func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpirySe
}
// ProvideTimingWheelService creates and starts TimingWheelService
-func ProvideTimingWheelService() *TimingWheelService {
- svc := NewTimingWheelService()
+func ProvideTimingWheelService() (*TimingWheelService, error) {
+ svc, err := NewTimingWheelService()
+ if err != nil {
+ return nil, err
+ }
svc.Start()
- return svc
+ return svc, nil
}
// ProvideDeferredService creates and starts DeferredService
@@ -108,10 +112,12 @@ func ProvideRateLimitService(
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
settingService *SettingService,
+ tokenCacheInvalidator TokenCacheInvalidator,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetSettingService(settingService)
+ svc.SetTokenCacheInvalidator(tokenCacheInvalidator)
return svc
}
@@ -210,10 +216,14 @@ var ProviderSet = wire.NewSet(
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
+ NewCompositeTokenCacheInvalidator,
+ wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)),
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
+ NewOpenAITokenProvider,
+ NewClaudeTokenProvider,
NewAntigravityGatewayService,
ProvideRateLimitService,
NewAccountUsageService,
diff --git a/backend/internal/service/wire_test.go b/backend/internal/service/wire_test.go
new file mode 100644
index 00000000..5f7866f6
--- /dev/null
+++ b/backend/internal/service/wire_test.go
@@ -0,0 +1,37 @@
+package service
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/zeromicro/go-zero/core/collection"
+)
+
+func TestProvideTimingWheelService_ReturnsError(t *testing.T) {
+ original := newTimingWheel
+ t.Cleanup(func() { newTimingWheel = original })
+
+ newTimingWheel = func(_ time.Duration, _ int, _ collection.Execute) (*collection.TimingWheel, error) {
+ return nil, errors.New("boom")
+ }
+
+ svc, err := ProvideTimingWheelService()
+ if err == nil {
+ t.Fatalf("期望返回 error,但得到 nil")
+ }
+ if svc != nil {
+ t.Fatalf("期望返回 nil svc,但得到非空")
+ }
+}
+
+func TestProvideTimingWheelService_Success(t *testing.T) {
+ svc, err := ProvideTimingWheelService()
+ if err != nil {
+ t.Fatalf("期望 err 为 nil,但得到: %v", err)
+ }
+ if svc == nil {
+ t.Fatalf("期望 svc 非空,但得到 nil")
+ }
+ svc.Stop()
+}
diff --git a/backend/internal/web/embed_on.go b/backend/internal/web/embed_on.go
index 35697fbb..7f37d59c 100644
--- a/backend/internal/web/embed_on.go
+++ b/backend/internal/web/embed_on.go
@@ -13,9 +13,15 @@ import (
"strings"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/gin-gonic/gin"
)
+const (
+ // NonceHTMLPlaceholder is the placeholder for nonce in HTML script tags
+ NonceHTMLPlaceholder = "__CSP_NONCE_VALUE__"
+)
+
//go:embed all:dist
var frontendFS embed.FS
@@ -115,6 +121,9 @@ func (s *FrontendServer) fileExists(path string) bool {
}
func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
+ // Get nonce from context (generated by SecurityHeaders middleware)
+ nonce := middleware.GetNonceFromContext(c)
+
// Check cache first
cached := s.cache.Get()
if cached != nil {
@@ -125,9 +134,12 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
return
}
+ // Replace nonce placeholder with actual nonce before serving
+ content := replaceNoncePlaceholder(cached.Content, nonce)
+
c.Header("ETag", cached.ETag)
c.Header("Cache-Control", "no-cache") // Must revalidate
- c.Data(http.StatusOK, "text/html; charset=utf-8", cached.Content)
+ c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
return
}
@@ -155,24 +167,33 @@ func (s *FrontendServer) serveIndexHTML(c *gin.Context) {
rendered := s.injectSettings(settingsJSON)
s.cache.Set(rendered, settingsJSON)
+ // Replace nonce placeholder with actual nonce before serving
+ content := replaceNoncePlaceholder(rendered, nonce)
+
cached = s.cache.Get()
if cached != nil {
c.Header("ETag", cached.ETag)
}
c.Header("Cache-Control", "no-cache")
- c.Data(http.StatusOK, "text/html; charset=utf-8", rendered)
+ c.Data(http.StatusOK, "text/html; charset=utf-8", content)
c.Abort()
}
func (s *FrontendServer) injectSettings(settingsJSON []byte) []byte {
- // Create the script tag to inject
- script := []byte(``)
+ // Create the script tag to inject with nonce placeholder
+ // The placeholder will be replaced with actual nonce at request time
+ script := []byte(``)
// Inject before
headClose := []byte("")
return bytes.Replace(s.baseHTML, headClose, append(script, headClose...), 1)
}
+// replaceNoncePlaceholder replaces the nonce placeholder with actual nonce value
+func replaceNoncePlaceholder(html []byte, nonce string) []byte {
+ return bytes.ReplaceAll(html, []byte(NonceHTMLPlaceholder), []byte(nonce))
+}
+
// ServeEmbeddedFrontend returns a middleware for serving embedded frontend
// This is the legacy function for backward compatibility when no settings provider is available
func ServeEmbeddedFrontend() gin.HandlerFunc {
diff --git a/backend/internal/web/embed_test.go b/backend/internal/web/embed_test.go
new file mode 100644
index 00000000..50f5a323
--- /dev/null
+++ b/backend/internal/web/embed_test.go
@@ -0,0 +1,660 @@
+//go:build embed
+
+package web
+
+import (
+ "bytes"
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+func init() {
+ gin.SetMode(gin.TestMode)
+}
+
// TestReplaceNoncePlaceholder covers replaceNoncePlaceholder across the basic
// shapes of input: a single placeholder, multiple placeholders, an empty nonce,
// input without a placeholder, and empty input.
//
// NOTE(review): several raw-string literals below are empty (``). They look
// like they originally contained HTML/script markup that was stripped by a
// sanitizer — the subtest names and the surviving assertions (e.g. the
// `nonce="nonce123"` count) imply non-empty fixtures. Restore the original
// fixtures before trusting these subtests; as written, some pass vacuously.
func TestReplaceNoncePlaceholder(t *testing.T) {
	t.Run("replaces_single_placeholder", func(t *testing.T) {
		// NOTE(review): fixture literal appears stripped — expected to hold
		// markup containing one NonceHTMLPlaceholder occurrence.
		html := []byte(``)
		nonce := "abc123xyz"

		result := replaceNoncePlaceholder(html, nonce)

		// NOTE(review): expected literal also appears stripped.
		expected := ``
		assert.Equal(t, expected, string(result))
	})

	t.Run("replaces_multiple_placeholders", func(t *testing.T) {
		// NOTE(review): fixture literal appears stripped — the assertions below
		// imply it held two placeholder occurrences inside nonce="…" attributes.
		html := []byte(``)
		nonce := "nonce123"

		result := replaceNoncePlaceholder(html, nonce)

		// Both occurrences must be rewritten and no placeholder text remain.
		assert.Equal(t, 2, strings.Count(string(result), `nonce="nonce123"`))
		assert.NotContains(t, string(result), NonceHTMLPlaceholder)
	})

	t.Run("handles_empty_nonce", func(t *testing.T) {
		// An empty nonce should simply erase the placeholder text.
		// NOTE(review): fixture literal appears stripped.
		html := []byte(``)
		nonce := ""

		result := replaceNoncePlaceholder(html, nonce)

		assert.Equal(t, ``, string(result))
	})

	t.Run("no_placeholder_returns_unchanged", func(t *testing.T) {
		// Input without the placeholder must pass through byte-identical.
		// NOTE(review): fixture literal appears stripped.
		html := []byte(``)
		nonce := "abc123"

		result := replaceNoncePlaceholder(html, nonce)

		assert.Equal(t, string(html), string(result))
	})

	t.Run("handles_empty_html", func(t *testing.T) {
		// Empty input yields empty output regardless of the nonce.
		html := []byte(``)
		nonce := "abc123"

		result := replaceNoncePlaceholder(html, nonce)

		assert.Empty(t, result)
	})
}
+
+func TestNonceHTMLPlaceholder(t *testing.T) {
+ t.Run("constant_value", func(t *testing.T) {
+ assert.Equal(t, "__CSP_NONCE_VALUE__", NonceHTMLPlaceholder)
+ })
+}
+
// mockSettingsProvider is a test double for PublicSettingsProvider. It records
// how many times it was queried and hands back canned results.
type mockSettingsProvider struct {
	settings any   // value returned from GetPublicSettingsForInjection
	err      error // error returned alongside settings
	called   int   // number of invocations observed so far
}

// GetPublicSettingsForInjection returns the configured settings/err pair,
// bumping the call counter so tests can assert on invocation counts.
func (m *mockSettingsProvider) GetPublicSettingsForInjection(_ context.Context) (any, error) {
	m.called++
	return m.settings, m.err
}
+
+func TestFrontendServer_InjectSettings(t *testing.T) {
+ t.Run("injects_settings_with_nonce_placeholder", func(t *testing.T) {
+ provider := &mockSettingsProvider{
+ settings: map[string]string{"key": "value"},
+ }
+
+ server, err := NewFrontendServer(provider)
+ require.NoError(t, err)
+
+ settingsJSON := []byte(`{"test":"data"}`)
+ result := server.injectSettings(settingsJSON)
+
+ // Should contain the script with nonce placeholder
+ assert.Contains(t, string(result), ``)
+ })
+
+ t.Run("injects_before_head_close", func(t *testing.T) {
+ provider := &mockSettingsProvider{
+ settings: map[string]string{"key": "value"},
+ }
+
+ server, err := NewFrontendServer(provider)
+ require.NoError(t, err)
+
+ settingsJSON := []byte(`{}`)
+ result := server.injectSettings(settingsJSON)
+
+ // Script should be injected before
+ headCloseIndex := bytes.Index(result, []byte(""))
+ scriptIndex := bytes.Index(result, []byte(`