diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml index 2596a18c..4fd22aff 100644 --- a/.github/workflows/backend-ci.yml +++ b/.github/workflows/backend-ci.yml @@ -44,4 +44,4 @@ jobs: with: version: v2.7 args: --timeout=5m - working-directory: backend + working-directory: backend \ No newline at end of file diff --git a/.gitignore b/.gitignore index 48172982..515ce84f 100644 --- a/.gitignore +++ b/.gitignore @@ -121,7 +121,6 @@ AGENTS.md scripts .code-review-state openspec/ -docs/ code-reviews/ AGENTS.md backend/cmd/server/server @@ -129,4 +128,8 @@ deploy/docker-compose.override.yml .gocache/ vite.config.js docs/* -.serena/ \ No newline at end of file +.serena/ +.codex/ +frontend/coverage/ +aicodex + diff --git a/Dockerfile b/Dockerfile index c9fcf301..645465f1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -36,7 +36,7 @@ RUN pnpm run build FROM ${GOLANG_IMAGE} AS backend-builder # Build arguments for version info (set by CI) -ARG VERSION=docker +ARG VERSION= ARG COMMIT=docker ARG DATE ARG GOPROXY @@ -61,9 +61,13 @@ COPY backend/ ./ COPY --from=frontend-builder /app/backend/internal/web/dist ./internal/web/dist # Build the binary (BuildType=release for CI builds, embed frontend) -RUN CGO_ENABLED=0 GOOS=linux go build \ +# Version precedence: build arg VERSION > cmd/server/VERSION +RUN VERSION_VALUE="${VERSION}" && \ + if [ -z "${VERSION_VALUE}" ]; then VERSION_VALUE="$(tr -d '\r\n' < ./cmd/server/VERSION)"; fi && \ + DATE_VALUE="${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)}" && \ + CGO_ENABLED=0 GOOS=linux go build \ -tags embed \ - -ldflags="-s -w -X main.Commit=${COMMIT} -X main.Date=${DATE:-$(date -u +%Y-%m-%dT%H:%M:%SZ)} -X main.BuildType=release" \ + -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ -o /app/sub2api \ ./cmd/server diff --git a/Makefile b/Makefile index a5e18a37..b97404eb 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build 
build-backend build-frontend test test-backend test-frontend +.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan # 一键编译前后端 build: build-backend build-frontend @@ -20,3 +20,6 @@ test-backend: test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck + +secret-scan: + @python3 tools/secret_scan.py diff --git a/README.md b/README.md index 36949b0a..a5f680bf 100644 --- a/README.md +++ b/README.md @@ -363,6 +363,12 @@ default: rate_multiplier: 1.0 ``` +### Sora Status (Temporarily Unavailable) + +> ⚠️ Sora-related features are temporarily unavailable due to technical issues in upstream integration and media delivery. +> Please do not rely on Sora in production at this time. +> Existing `gateway.sora_*` configuration keys are reserved and may not take effect until these issues are resolved. + Additional security-related options are available in `config.yaml`: - `cors.allowed_origins` for CORS allowlist diff --git a/README_CN.md b/README_CN.md index 1e0d1d62..ea35a19d 100644 --- a/README_CN.md +++ b/README_CN.md @@ -139,6 +139,8 @@ curl -sSL https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/deploy/install 使用 Docker Compose 部署,包含 PostgreSQL 和 Redis 容器。 +如果你的服务器是 **Ubuntu 24.04**,建议直接参考:`deploy/ubuntu24-docker-compose-aicodex.md`,其中包含「安装最新版 Docker + docker-compose-aicodex.yml 部署」的完整步骤。 + #### 前置条件 - Docker 20.10+ @@ -370,6 +372,33 @@ default: rate_multiplier: 1.0 ``` +### Sora 功能状态(暂不可用) + +> ⚠️ 当前 Sora 相关功能因上游接入与媒体链路存在技术问题,暂时不可用。 +> 现阶段请勿在生产环境依赖 Sora 能力。 +> 文档中的 `gateway.sora_*` 配置仅作预留,待技术问题修复后再恢复可用。 + +### Sora 媒体签名 URL(功能恢复后可选) + +当配置 `gateway.sora_media_signing_key` 且 `gateway.sora_media_signed_url_ttl_seconds > 0` 时,网关会将 Sora 输出的媒体地址改写为临时签名 URL(`/sora/media-signed/...`)。这样无需 API Key 即可在浏览器中直接访问,且具备过期控制与防篡改能力(签名包含 path + query)。 + +```yaml +gateway: + # /sora/media 是否强制要求 API Key(默认 false) + sora_media_require_api_key: false + # 媒体临时签名密钥(为空则禁用签名) + sora_media_signing_key: "your-signing-key" + # 临时签名 
URL 有效期(秒) + sora_media_signed_url_ttl_seconds: 900 +``` + +> 若未配置签名密钥,`/sora/media-signed` 将返回 503。 +> 如需更严格的访问控制,可将 `sora_media_require_api_key` 设为 true,仅允许携带 API Key 的 `/sora/media` 访问。 + +访问策略说明: +- `/sora/media`:内部调用或客户端携带 API Key 才能下载 +- `/sora/media-signed`:外部可访问,但有签名 + 过期控制 + `config.yaml` 还支持以下安全相关配置: - `cors.allowed_origins` 配置 CORS 白名单 @@ -383,6 +412,14 @@ default: - `server.trusted_proxies` 启用可信代理解析 X-Forwarded-For - `turnstile.required` 在 release 模式强制启用 Turnstile +**网关防御纵深建议(重点)** + +- `gateway.upstream_response_read_max_bytes`:限制非流式上游响应读取大小(默认 `8MB`),用于防止异常响应导致内存放大。 +- `gateway.proxy_probe_response_read_max_bytes`:限制代理探测响应读取大小(默认 `1MB`)。 +- `gateway.gemini_debug_response_headers`:默认 `false`,仅在排障时短时开启,避免高频请求日志开销。 +- `/auth/register`、`/auth/login`、`/auth/login/2fa`、`/auth/send-verify-code` 已提供服务端兜底限流(Redis 故障时 fail-close)。 +- 推荐将 WAF/CDN 作为第一层防护,服务端限流与响应读取上限作为第二层兜底;两层同时保留,避免旁路流量与误配置风险。 + **⚠️ 安全警告:HTTP URL 配置** 当 `security.url_allowlist.enabled=false` 时,系统默认执行最小 URL 校验,**拒绝 HTTP URL**,仅允许 HTTPS。要允许 HTTP URL(例如用于开发或内网测试),必须显式设置: @@ -428,6 +465,29 @@ Invalid base URL: invalid url scheme: http ./sub2api ``` +#### HTTP/2 (h2c) 与 HTTP/1.1 回退 + +后端明文端口默认支持 h2c,并保留 HTTP/1.1 回退用于 WebSocket 与旧客户端。浏览器通常不支持 h2c,性能收益主要在反向代理或内网链路。 + +**反向代理示例(Caddy):** + +```caddyfile +transport http { + versions h2c h1 +} +``` + +**验证:** + +```bash +# h2c prior knowledge +curl --http2-prior-knowledge -I http://localhost:8080/health +# HTTP/1.1 回退 +curl --http1.1 -I http://localhost:8080/health +# WebSocket 回退验证(需管理员 token) +websocat -H="Sec-WebSocket-Protocol: sub2api-admin, jwt." ws://localhost:8080/api/v1/admin/ops/ws/qps +``` + #### 开发模式 ```bash diff --git a/backend/Makefile b/backend/Makefile index 6a5d2caa..89db1104 100644 --- a/backend/Makefile +++ b/backend/Makefile @@ -14,4 +14,7 @@ test-integration: go test -tags=integration ./... test-e2e: - go test -tags=e2e ./... + ./scripts/e2e-test.sh + +test-e2e-local: + go test -tags=e2e -v -timeout=300s ./internal/integration/... 
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index ce4718bf..2ff7358b 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -17,7 +17,7 @@ func main() { email := flag.String("email", "", "Admin email to issue a JWT for (defaults to first active admin)") flag.Parse() - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("failed to load config: %v", err) } diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 8b063cd5..0a752ff7 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.83 \ No newline at end of file +0.1.83.4 diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index f8a7d313..63095209 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -8,7 +8,6 @@ import ( "errors" "flag" "log" - "log/slog" "net/http" "os" "os/signal" @@ -19,11 +18,14 @@ import ( _ "github.com/Wei-Shaw/sub2api/ent/runtime" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/setup" "github.com/Wei-Shaw/sub2api/internal/web" "github.com/gin-gonic/gin" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) //go:embed VERSION @@ -38,7 +40,12 @@ var ( ) func init() { - // Read version from embedded VERSION file + // 如果 Version 已通过 ldflags 注入(例如 -X main.Version=...),则不要覆盖。 + if strings.TrimSpace(Version) != "" { + return + } + + // 默认从 embedded VERSION 文件读取版本号(编译期打包进二进制)。 Version = strings.TrimSpace(embeddedVersion) if Version == "" { Version = "0.0.0-dev" @@ -47,22 +54,9 @@ func init() { // initLogger configures the default slog handler based on gin.Mode(). // In non-release mode, Debug level logs are enabled. 
-func initLogger() { - var level slog.Level - if gin.Mode() == gin.ReleaseMode { - level = slog.LevelInfo - } else { - level = slog.LevelDebug - } - handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ - Level: level, - }) - slog.SetDefault(slog.New(handler)) -} - func main() { - // Initialize slog logger based on gin mode - initLogger() + logger.InitBootstrap() + defer logger.Sync() // Parse command line flags setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode") @@ -122,16 +116,26 @@ func runSetupServer() { log.Printf("Setup wizard available at http://%s", addr) log.Println("Complete the setup wizard to configure Sub2API") - if err := r.Run(addr); err != nil { + server := &http.Server{ + Addr: addr, + Handler: h2c.NewHandler(r, &http2.Server{}), + ReadHeaderTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + } + + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatalf("Failed to start setup server: %v", err) } } func runMainServer() { - cfg, err := config.Load() + cfg, err := config.LoadForBootstrap() if err != nil { log.Fatalf("Failed to load config: %v", err) } + if err := logger.Init(logger.OptionsFromConfig(cfg.Log)); err != nil { + log.Fatalf("Failed to initialize logger: %v", err) + } if cfg.RunMode == config.RunModeSimple { log.Println("⚠️ WARNING: Running in SIMPLE mode - billing and quota checks are DISABLED") } diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index d9ff788e..1ba6b184 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -67,14 +67,19 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry 
*service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -101,6 +106,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -131,6 +148,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil { + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -143,6 +166,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -155,6 +184,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 5ccd797e..7a277112 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -65,8 +65,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { 
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) - userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) + userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -98,10 +98,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { dashboardHandler := admin.NewDashboardHandler(dashboardService, dashboardAggregationService) schedulerCache := repository.NewSchedulerCache(redisClient) accountRepository := repository.NewAccountRepository(client, db, schedulerCache) + soraAccountRepository := repository.NewSoraAccountRepository(db) proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, 
apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -159,14 +160,17 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) - opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService) + opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) + opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, 
opsService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) serviceBuildInfo := provideServiceBuildInfo(buildInfo) updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo) - systemHandler := handler.ProvideSystemHandler(updateService) + idempotencyRepository := repository.NewIdempotencyRepository(client, db) + systemOperationLockService := service.ProvideSystemOperationLockService(idempotencyRepository, configConfig) + systemHandler := handler.ProvideSystemHandler(updateService, systemOperationLockService) adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService) usageCleanupRepository := repository.NewUsageCleanupRepository(client, db) usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig) @@ -180,11 +184,18 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, 
concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) + usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) + soraDirectClient := service.ProvideSoraDirectClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) + soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) + soraGatewayService := service.NewSoraGatewayService(soraDirectClient, soraMediaStorage, rateLimitService, configConfig) + soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) + idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) + idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, 
idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -195,10 +206,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { opsAlertEvaluatorService := service.ProvideOpsAlertEvaluatorService(opsService, opsRepository, emailService, redisClient, configConfig) opsCleanupService := service.ProvideOpsCleanupService(opsRepository, db, redisClient, configConfig) opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig) - tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) + soraMediaCleanupService := service.ProvideSoraMediaCleanupService(soraMediaStorage, configConfig) + tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, 
opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) application := &Application{ Server: httpServer, Cleanup: v, @@ -228,14 +240,19 @@ func provideCleanup( opsAlertEvaluator *service.OpsAlertEvaluatorService, opsCleanup *service.OpsCleanupService, opsScheduledReport *service.OpsScheduledReportService, + opsSystemLogSink *service.OpsSystemLogSink, + soraMediaCleanup *service.SoraMediaCleanupService, schedulerSnapshot *service.SchedulerSnapshotService, tokenRefresh *service.TokenRefreshService, accountExpiry *service.AccountExpiryService, subscriptionExpiry *service.SubscriptionExpiryService, usageCleanup *service.UsageCleanupService, + idempotencyCleanup *service.IdempotencyCleanupService, pricing *service.PricingService, emailQueue *service.EmailQueueService, billingCache *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + subscriptionService *service.SubscriptionService, oauth *service.OAuthService, openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, @@ -261,6 +278,18 @@ func provideCleanup( } return nil }}, + {"OpsSystemLogSink", func() error { + if opsSystemLogSink != nil { + opsSystemLogSink.Stop() + } + return nil + }}, + {"SoraMediaCleanupService", func() error { + if soraMediaCleanup != nil { + soraMediaCleanup.Stop() + } + return nil + }}, {"OpsAlertEvaluatorService", func() error { if opsAlertEvaluator != nil { opsAlertEvaluator.Stop() @@ -291,6 +320,12 @@ func provideCleanup( } return nil }}, + {"IdempotencyCleanupService", func() error { + if idempotencyCleanup != nil 
{ + idempotencyCleanup.Stop() + } + return nil + }}, {"TokenRefreshService", func() error { tokenRefresh.Stop() return nil @@ -303,6 +338,12 @@ func provideCleanup( subscriptionExpiry.Stop() return nil }}, + {"SubscriptionService", func() error { + if subscriptionService != nil { + subscriptionService.Stop() + } + return nil + }}, {"PricingService", func() error { pricing.Stop() return nil @@ -315,6 +356,12 @@ func provideCleanup( billingCache.Stop() return nil }}, + {"UsageRecordWorkerPool", func() error { + if usageRecordWorkerPool != nil { + usageRecordWorkerPool.Stop() + } + return nil + }}, {"OAuthService", func() error { oauth.Stop() return nil diff --git a/backend/ent/apikey.go b/backend/ent/apikey.go index 91d71964..760851c8 100644 --- a/backend/ent/apikey.go +++ b/backend/ent/apikey.go @@ -36,6 +36,8 @@ type APIKey struct { GroupID *int64 `json:"group_id,omitempty"` // Status holds the value of the "status" field. Status string `json:"status,omitempty"` + // Last usage time of this API key + LastUsedAt *time.Time `json:"last_used_at,omitempty"` // Allowed IPs/CIDRs, e.g. 
["192.168.1.100", "10.0.0.0/8"] IPWhitelist []string `json:"ip_whitelist,omitempty"` // Blocked IPs/CIDRs @@ -109,7 +111,7 @@ func (*APIKey) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullInt64) case apikey.FieldKey, apikey.FieldName, apikey.FieldStatus: values[i] = new(sql.NullString) - case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldExpiresAt: + case apikey.FieldCreatedAt, apikey.FieldUpdatedAt, apikey.FieldDeletedAt, apikey.FieldLastUsedAt, apikey.FieldExpiresAt: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -182,6 +184,13 @@ func (_m *APIKey) assignValues(columns []string, values []any) error { } else if value.Valid { _m.Status = value.String } + case apikey.FieldLastUsedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field last_used_at", values[i]) + } else if value.Valid { + _m.LastUsedAt = new(time.Time) + *_m.LastUsedAt = value.Time + } case apikey.FieldIPWhitelist: if value, ok := values[i].(*[]byte); !ok { return fmt.Errorf("unexpected type %T for field ip_whitelist", values[i]) @@ -296,6 +305,11 @@ func (_m *APIKey) String() string { builder.WriteString("status=") builder.WriteString(_m.Status) builder.WriteString(", ") + if v := _m.LastUsedAt; v != nil { + builder.WriteString("last_used_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") builder.WriteString("ip_whitelist=") builder.WriteString(fmt.Sprintf("%v", _m.IPWhitelist)) builder.WriteString(", ") diff --git a/backend/ent/apikey/apikey.go b/backend/ent/apikey/apikey.go index ac2a6008..6abea56b 100644 --- a/backend/ent/apikey/apikey.go +++ b/backend/ent/apikey/apikey.go @@ -31,6 +31,8 @@ const ( FieldGroupID = "group_id" // FieldStatus holds the string denoting the status field in the database. FieldStatus = "status" + // FieldLastUsedAt holds the string denoting the last_used_at field in the database. 
+ FieldLastUsedAt = "last_used_at" // FieldIPWhitelist holds the string denoting the ip_whitelist field in the database. FieldIPWhitelist = "ip_whitelist" // FieldIPBlacklist holds the string denoting the ip_blacklist field in the database. @@ -83,6 +85,7 @@ var Columns = []string{ FieldName, FieldGroupID, FieldStatus, + FieldLastUsedAt, FieldIPWhitelist, FieldIPBlacklist, FieldQuota, @@ -176,6 +179,11 @@ func ByStatus(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldStatus, opts...).ToFunc() } +// ByLastUsedAt orders the results by the last_used_at field. +func ByLastUsedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLastUsedAt, opts...).ToFunc() +} + // ByQuota orders the results by the quota field. func ByQuota(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldQuota, opts...).ToFunc() diff --git a/backend/ent/apikey/where.go b/backend/ent/apikey/where.go index f54f44b7..c1900ee1 100644 --- a/backend/ent/apikey/where.go +++ b/backend/ent/apikey/where.go @@ -95,6 +95,11 @@ func Status(v string) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldStatus, v)) } +// LastUsedAt applies equality check predicate on the "last_used_at" field. It's identical to LastUsedAtEQ. +func LastUsedAt(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + // Quota applies equality check predicate on the "quota" field. It's identical to QuotaEQ. func Quota(v float64) predicate.APIKey { return predicate.APIKey(sql.FieldEQ(FieldQuota, v)) @@ -485,6 +490,56 @@ func StatusContainsFold(v string) predicate.APIKey { return predicate.APIKey(sql.FieldContainsFold(FieldStatus, v)) } +// LastUsedAtEQ applies the EQ predicate on the "last_used_at" field. +func LastUsedAtEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtNEQ applies the NEQ predicate on the "last_used_at" field. 
+func LastUsedAtNEQ(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNEQ(FieldLastUsedAt, v)) +} + +// LastUsedAtIn applies the In predicate on the "last_used_at" field. +func LastUsedAtIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtNotIn applies the NotIn predicate on the "last_used_at" field. +func LastUsedAtNotIn(vs ...time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldNotIn(FieldLastUsedAt, vs...)) +} + +// LastUsedAtGT applies the GT predicate on the "last_used_at" field. +func LastUsedAtGT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGT(FieldLastUsedAt, v)) +} + +// LastUsedAtGTE applies the GTE predicate on the "last_used_at" field. +func LastUsedAtGTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldGTE(FieldLastUsedAt, v)) +} + +// LastUsedAtLT applies the LT predicate on the "last_used_at" field. +func LastUsedAtLT(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLT(FieldLastUsedAt, v)) +} + +// LastUsedAtLTE applies the LTE predicate on the "last_used_at" field. +func LastUsedAtLTE(v time.Time) predicate.APIKey { + return predicate.APIKey(sql.FieldLTE(FieldLastUsedAt, v)) +} + +// LastUsedAtIsNil applies the IsNil predicate on the "last_used_at" field. +func LastUsedAtIsNil() predicate.APIKey { + return predicate.APIKey(sql.FieldIsNull(FieldLastUsedAt)) +} + +// LastUsedAtNotNil applies the NotNil predicate on the "last_used_at" field. +func LastUsedAtNotNil() predicate.APIKey { + return predicate.APIKey(sql.FieldNotNull(FieldLastUsedAt)) +} + // IPWhitelistIsNil applies the IsNil predicate on the "ip_whitelist" field. 
func IPWhitelistIsNil() predicate.APIKey { return predicate.APIKey(sql.FieldIsNull(FieldIPWhitelist)) diff --git a/backend/ent/apikey_create.go b/backend/ent/apikey_create.go index 71540975..bc506585 100644 --- a/backend/ent/apikey_create.go +++ b/backend/ent/apikey_create.go @@ -113,6 +113,20 @@ func (_c *APIKeyCreate) SetNillableStatus(v *string) *APIKeyCreate { return _c } +// SetLastUsedAt sets the "last_used_at" field. +func (_c *APIKeyCreate) SetLastUsedAt(v time.Time) *APIKeyCreate { + _c.mutation.SetLastUsedAt(v) + return _c +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_c *APIKeyCreate) SetNillableLastUsedAt(v *time.Time) *APIKeyCreate { + if v != nil { + _c.SetLastUsedAt(*v) + } + return _c +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_c *APIKeyCreate) SetIPWhitelist(v []string) *APIKeyCreate { _c.mutation.SetIPWhitelist(v) @@ -353,6 +367,10 @@ func (_c *APIKeyCreate) createSpec() (*APIKey, *sqlgraph.CreateSpec) { _spec.SetField(apikey.FieldStatus, field.TypeString, value) _node.Status = value } + if value, ok := _c.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + _node.LastUsedAt = &value + } if value, ok := _c.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) _node.IPWhitelist = value @@ -571,6 +589,24 @@ func (u *APIKeyUpsert) UpdateStatus() *APIKeyUpsert { return u } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsert) SetLastUsedAt(v time.Time) *APIKeyUpsert { + u.Set(apikey.FieldLastUsedAt, v) + return u +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsert) UpdateLastUsedAt() *APIKeyUpsert { + u.SetExcluded(apikey.FieldLastUsedAt) + return u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. 
+func (u *APIKeyUpsert) ClearLastUsedAt() *APIKeyUpsert { + u.SetNull(apikey.FieldLastUsedAt) + return u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsert) SetIPWhitelist(v []string) *APIKeyUpsert { u.Set(apikey.FieldIPWhitelist, v) @@ -818,6 +854,27 @@ func (u *APIKeyUpsertOne) UpdateStatus() *APIKeyUpsertOne { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertOne) SetLastUsedAt(v time.Time) *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsertOne) UpdateLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertOne) ClearLastUsedAt() *APIKeyUpsertOne { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (u *APIKeyUpsertOne) SetIPWhitelist(v []string) *APIKeyUpsertOne { return u.Update(func(s *APIKeyUpsert) { @@ -1246,6 +1303,27 @@ func (u *APIKeyUpsertBulk) UpdateStatus() *APIKeyUpsertBulk { }) } +// SetLastUsedAt sets the "last_used_at" field. +func (u *APIKeyUpsertBulk) SetLastUsedAt(v time.Time) *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.SetLastUsedAt(v) + }) +} + +// UpdateLastUsedAt sets the "last_used_at" field to the value that was provided on create. +func (u *APIKeyUpsertBulk) UpdateLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.UpdateLastUsedAt() + }) +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (u *APIKeyUpsertBulk) ClearLastUsedAt() *APIKeyUpsertBulk { + return u.Update(func(s *APIKeyUpsert) { + s.ClearLastUsedAt() + }) +} + // SetIPWhitelist sets the "ip_whitelist" field. 
func (u *APIKeyUpsertBulk) SetIPWhitelist(v []string) *APIKeyUpsertBulk { return u.Update(func(s *APIKeyUpsert) { diff --git a/backend/ent/apikey_update.go b/backend/ent/apikey_update.go index b4ff230b..6ca01854 100644 --- a/backend/ent/apikey_update.go +++ b/backend/ent/apikey_update.go @@ -134,6 +134,26 @@ func (_u *APIKeyUpdate) SetNillableStatus(v *string) *APIKeyUpdate { return _u } +// SetLastUsedAt sets the "last_used_at" field. +func (_u *APIKeyUpdate) SetLastUsedAt(v time.Time) *APIKeyUpdate { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdate) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdate { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (_u *APIKeyUpdate) ClearLastUsedAt() *APIKeyUpdate { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdate) SetIPWhitelist(v []string) *APIKeyUpdate { _u.mutation.SetIPWhitelist(v) @@ -390,6 +410,12 @@ func (_u *APIKeyUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } @@ -655,6 +681,26 @@ func (_u *APIKeyUpdateOne) SetNillableStatus(v *string) *APIKeyUpdateOne { return _u } +// SetLastUsedAt sets the "last_used_at" field. 
+func (_u *APIKeyUpdateOne) SetLastUsedAt(v time.Time) *APIKeyUpdateOne { + _u.mutation.SetLastUsedAt(v) + return _u +} + +// SetNillableLastUsedAt sets the "last_used_at" field if the given value is not nil. +func (_u *APIKeyUpdateOne) SetNillableLastUsedAt(v *time.Time) *APIKeyUpdateOne { + if v != nil { + _u.SetLastUsedAt(*v) + } + return _u +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. +func (_u *APIKeyUpdateOne) ClearLastUsedAt() *APIKeyUpdateOne { + _u.mutation.ClearLastUsedAt() + return _u +} + // SetIPWhitelist sets the "ip_whitelist" field. func (_u *APIKeyUpdateOne) SetIPWhitelist(v []string) *APIKeyUpdateOne { _u.mutation.SetIPWhitelist(v) @@ -941,6 +987,12 @@ func (_u *APIKeyUpdateOne) sqlSave(ctx context.Context) (_node *APIKey, err erro if value, ok := _u.mutation.Status(); ok { _spec.SetField(apikey.FieldStatus, field.TypeString, value) } + if value, ok := _u.mutation.LastUsedAt(); ok { + _spec.SetField(apikey.FieldLastUsedAt, field.TypeTime, value) + } + if _u.mutation.LastUsedAtCleared() { + _spec.ClearField(apikey.FieldLastUsedAt, field.TypeTime) + } if value, ok := _u.mutation.IPWhitelist(); ok { _spec.SetField(apikey.FieldIPWhitelist, field.TypeJSON, value) } diff --git a/backend/ent/client.go b/backend/ent/client.go index a791c081..504c1755 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -26,6 +26,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -65,6 +66,8 @@ type Client struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. 
+ SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. @@ -103,6 +106,7 @@ func (c *Client) init() { c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) c.RedeemCode = NewRedeemCodeClient(c.config) + c.SecuritySecret = NewSecuritySecretClient(c.config) c.Setting = NewSettingClient(c.config) c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config) c.UsageLog = NewUsageLogClient(c.config) @@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -252,6 +257,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), RedeemCode: NewRedeemCodeClient(cfg), + SecuritySecret: NewSecuritySecretClient(cfg), Setting: NewSettingClient(cfg), UsageCleanupTask: NewUsageCleanupTaskClient(cfg), UsageLog: NewUsageLogClient(cfg), @@ -291,8 +297,8 @@ func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) 
@@ -305,8 +311,8 @@ func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, - c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, + c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) @@ -338,6 +344,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Proxy.mutate(ctx, m) case *RedeemCodeMutation: return c.RedeemCode.mutate(ctx, m) + case *SecuritySecretMutation: + return c.SecuritySecret.mutate(ctx, m) case *SettingMutation: return c.Setting.mutate(ctx, m) case *UsageCleanupTaskMutation: @@ -2197,6 +2205,139 @@ func (c *RedeemCodeClient) mutate(ctx context.Context, m *RedeemCodeMutation) (V } } +// SecuritySecretClient is a client for the SecuritySecret schema. +type SecuritySecretClient struct { + config +} + +// NewSecuritySecretClient returns a client for the SecuritySecret from the given config. +func NewSecuritySecretClient(c config) *SecuritySecretClient { + return &SecuritySecretClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `securitysecret.Hooks(f(g(h())))`. +func (c *SecuritySecretClient) Use(hooks ...Hook) { + c.hooks.SecuritySecret = append(c.hooks.SecuritySecret, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `securitysecret.Intercept(f(g(h())))`. +func (c *SecuritySecretClient) Intercept(interceptors ...Interceptor) { + c.inters.SecuritySecret = append(c.inters.SecuritySecret, interceptors...) 
+} + +// Create returns a builder for creating a SecuritySecret entity. +func (c *SecuritySecretClient) Create() *SecuritySecretCreate { + mutation := newSecuritySecretMutation(c.config, OpCreate) + return &SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of SecuritySecret entities. +func (c *SecuritySecretClient) CreateBulk(builders ...*SecuritySecretCreate) *SecuritySecretCreateBulk { + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *SecuritySecretClient) MapCreateBulk(slice any, setFunc func(*SecuritySecretCreate, int)) *SecuritySecretCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &SecuritySecretCreateBulk{err: fmt.Errorf("calling to SecuritySecretClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*SecuritySecretCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &SecuritySecretCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for SecuritySecret. +func (c *SecuritySecretClient) Update() *SecuritySecretUpdate { + mutation := newSecuritySecretMutation(c.config, OpUpdate) + return &SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *SecuritySecretClient) UpdateOne(_m *SecuritySecret) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecret(_m)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. 
+func (c *SecuritySecretClient) UpdateOneID(id int64) *SecuritySecretUpdateOne { + mutation := newSecuritySecretMutation(c.config, OpUpdateOne, withSecuritySecretID(id)) + return &SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for SecuritySecret. +func (c *SecuritySecretClient) Delete() *SecuritySecretDelete { + mutation := newSecuritySecretMutation(c.config, OpDelete) + return &SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *SecuritySecretClient) DeleteOne(_m *SecuritySecret) *SecuritySecretDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *SecuritySecretClient) DeleteOneID(id int64) *SecuritySecretDeleteOne { + builder := c.Delete().Where(securitysecret.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &SecuritySecretDeleteOne{builder} +} + +// Query returns a query builder for SecuritySecret. +func (c *SecuritySecretClient) Query() *SecuritySecretQuery { + return &SecuritySecretQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeSecuritySecret}, + inters: c.Interceptors(), + } +} + +// Get returns a SecuritySecret entity by its id. +func (c *SecuritySecretClient) Get(ctx context.Context, id int64) (*SecuritySecret, error) { + return c.Query().Where(securitysecret.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *SecuritySecretClient) GetX(ctx context.Context, id int64) *SecuritySecret { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *SecuritySecretClient) Hooks() []Hook { + return c.hooks.SecuritySecret +} + +// Interceptors returns the client interceptors. 
+func (c *SecuritySecretClient) Interceptors() []Interceptor { + return c.inters.SecuritySecret +} + +func (c *SecuritySecretClient) mutate(ctx context.Context, m *SecuritySecretMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&SecuritySecretCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&SecuritySecretUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&SecuritySecretUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&SecuritySecretDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown SecuritySecret mutation op: %q", m.Op()) + } +} + // SettingClient is a client for the Setting schema. type SettingClient struct { config @@ -3607,13 +3748,13 @@ type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 5767a167..c4ec3387 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -23,6 +23,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" 
"github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -102,6 +103,7 @@ func checkColumn(t, c string) error { promocodeusage.Table: promocodeusage.ValidColumn, proxy.Table: proxy.ValidColumn, redeemcode.Table: redeemcode.ValidColumn, + securitysecret.Table: securitysecret.ValidColumn, setting.Table: setting.ValidColumn, usagecleanuptask.Table: usagecleanuptask.ValidColumn, usagelog.Table: usagelog.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index 3c8d68b5..79ec5bf5 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -52,6 +52,14 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` + // SoraImagePrice360 holds the value of the "sora_image_price_360" field. + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + // SoraImagePrice540 holds the value of the "sora_image_price_540" field. + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + // SoraVideoPricePerRequest holds the value of the "sora_video_price_per_request" field. + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. 
+ SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -178,7 +186,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new([]byte) case group.FieldIsExclusive, group.FieldClaudeCodeOnly, group.FieldModelRoutingEnabled, group.FieldMcpXMLInject: values[i] = new(sql.NullBool) - case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: + case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) @@ -317,6 +325,34 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.ImagePrice4k = new(float64) *_m.ImagePrice4k = value.Float64 } + case group.FieldSoraImagePrice360: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_360", values[i]) + } else if value.Valid { + _m.SoraImagePrice360 = new(float64) + *_m.SoraImagePrice360 = value.Float64 + } + case group.FieldSoraImagePrice540: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_image_price_540", values[i]) + } else if value.Valid { + _m.SoraImagePrice540 = new(float64) + *_m.SoraImagePrice540 = value.Float64 + } + case group.FieldSoraVideoPricePerRequest: + if value, ok := values[i].(*sql.NullFloat64); !ok { + 
return fmt.Errorf("unexpected type %T for field sora_video_price_per_request", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequest = new(float64) + *_m.SoraVideoPricePerRequest = value.Float64 + } + case group.FieldSoraVideoPricePerRequestHd: + if value, ok := values[i].(*sql.NullFloat64); !ok { + return fmt.Errorf("unexpected type %T for field sora_video_price_per_request_hd", values[i]) + } else if value.Valid { + _m.SoraVideoPricePerRequestHd = new(float64) + *_m.SoraVideoPricePerRequestHd = value.Float64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -514,6 +550,26 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + if v := _m.SoraImagePrice360; v != nil { + builder.WriteString("sora_image_price_360=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraImagePrice540; v != nil { + builder.WriteString("sora_image_price_540=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequest; v != nil { + builder.WriteString("sora_video_price_per_request=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.SoraVideoPricePerRequestHd; v != nil { + builder.WriteString("sora_video_price_per_request_hd=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 31c67756..133123a1 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -49,6 +49,14 @@ const ( FieldImagePrice2k = "image_price_2k" // FieldImagePrice4k holds the string denoting the image_price_4k field in the database. 
FieldImagePrice4k = "image_price_4k" + // FieldSoraImagePrice360 holds the string denoting the sora_image_price_360 field in the database. + FieldSoraImagePrice360 = "sora_image_price_360" + // FieldSoraImagePrice540 holds the string denoting the sora_image_price_540 field in the database. + FieldSoraImagePrice540 = "sora_image_price_540" + // FieldSoraVideoPricePerRequest holds the string denoting the sora_video_price_per_request field in the database. + FieldSoraVideoPricePerRequest = "sora_video_price_per_request" + // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. + FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -157,6 +165,10 @@ var Columns = []string{ FieldImagePrice1k, FieldImagePrice2k, FieldImagePrice4k, + FieldSoraImagePrice360, + FieldSoraImagePrice540, + FieldSoraVideoPricePerRequest, + FieldSoraVideoPricePerRequestHd, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -325,6 +337,26 @@ func ByImagePrice4k(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImagePrice4k, opts...).ToFunc() } +// BySoraImagePrice360 orders the results by the sora_image_price_360 field. +func BySoraImagePrice360(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice360, opts...).ToFunc() +} + +// BySoraImagePrice540 orders the results by the sora_image_price_540 field. +func BySoraImagePrice540(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraImagePrice540, opts...).ToFunc() +} + +// BySoraVideoPricePerRequest orders the results by the sora_video_price_per_request field. 
+func BySoraVideoPricePerRequest(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequest, opts...).ToFunc() +} + +// BySoraVideoPricePerRequestHd orders the results by the sora_video_price_per_request_hd field. +func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index cd5197c9..127d4ae9 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -140,6 +140,26 @@ func ImagePrice4k(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldImagePrice4k, v)) } +// SoraImagePrice360 applies equality check predicate on the "sora_image_price_360" field. It's identical to SoraImagePrice360EQ. +func SoraImagePrice360(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice540 applies equality check predicate on the "sora_image_price_540" field. It's identical to SoraImagePrice540EQ. +func SoraImagePrice540(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraVideoPricePerRequest applies equality check predicate on the "sora_video_price_per_request" field. It's identical to SoraVideoPricePerRequestEQ. +func SoraVideoPricePerRequest(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestHd applies equality check predicate on the "sora_video_price_per_request_hd" field. It's identical to SoraVideoPricePerRequestHdEQ. 
+func SoraVideoPricePerRequestHd(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1025,6 +1045,206 @@ func ImagePrice4kNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldImagePrice4k)) } +// SoraImagePrice360EQ applies the EQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360NEQ applies the NEQ predicate on the "sora_image_price_360" field. +func SoraImagePrice360NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360In applies the In predicate on the "sora_image_price_360" field. +func SoraImagePrice360In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360NotIn applies the NotIn predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice360, vs...)) +} + +// SoraImagePrice360GT applies the GT predicate on the "sora_image_price_360" field. +func SoraImagePrice360GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360GTE applies the GTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LT applies the LT predicate on the "sora_image_price_360" field. 
+func SoraImagePrice360LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360LTE applies the LTE predicate on the "sora_image_price_360" field. +func SoraImagePrice360LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice360, v)) +} + +// SoraImagePrice360IsNil applies the IsNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice360NotNil applies the NotNil predicate on the "sora_image_price_360" field. +func SoraImagePrice360NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice360)) +} + +// SoraImagePrice540EQ applies the EQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540EQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540NEQ applies the NEQ predicate on the "sora_image_price_540" field. +func SoraImagePrice540NEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540In applies the In predicate on the "sora_image_price_540" field. +func SoraImagePrice540In(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540NotIn applies the NotIn predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraImagePrice540, vs...)) +} + +// SoraImagePrice540GT applies the GT predicate on the "sora_image_price_540" field. +func SoraImagePrice540GT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540GTE applies the GTE predicate on the "sora_image_price_540" field. 
+func SoraImagePrice540GTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LT applies the LT predicate on the "sora_image_price_540" field. +func SoraImagePrice540LT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540LTE applies the LTE predicate on the "sora_image_price_540" field. +func SoraImagePrice540LTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraImagePrice540, v)) +} + +// SoraImagePrice540IsNil applies the IsNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540IsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraImagePrice540)) +} + +// SoraImagePrice540NotNil applies the NotNil predicate on the "sora_image_price_540" field. +func SoraImagePrice540NotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraImagePrice540)) +} + +// SoraVideoPricePerRequestEQ applies the EQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestNEQ applies the NEQ predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIn applies the In predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestNotIn applies the NotIn predicate on the "sora_video_price_per_request" field. 
+func SoraVideoPricePerRequestNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequest, vs...)) +} + +// SoraVideoPricePerRequestGT applies the GT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestGTE applies the GTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLT applies the LT predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestLTE applies the LTE predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequest, v)) +} + +// SoraVideoPricePerRequestIsNil applies the IsNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestNotNil applies the NotNil predicate on the "sora_video_price_per_request" field. +func SoraVideoPricePerRequestNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequest)) +} + +// SoraVideoPricePerRequestHdEQ applies the EQ predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdNEQ applies the NEQ predicate on the "sora_video_price_per_request_hd" field. 
+func SoraVideoPricePerRequestHdNEQ(v float64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIn applies the In predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdNotIn applies the NotIn predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotIn(vs ...float64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraVideoPricePerRequestHd, vs...)) +} + +// SoraVideoPricePerRequestHdGT applies the GT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGT(v float64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdGTE applies the GTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdGTE(v float64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLT applies the LT predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLT(v float64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdLTE applies the LTE predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdLTE(v float64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraVideoPricePerRequestHd, v)) +} + +// SoraVideoPricePerRequestHdIsNil applies the IsNil predicate on the "sora_video_price_per_request_hd" field. 
+func SoraVideoPricePerRequestHdIsNil() predicate.Group { + return predicate.Group(sql.FieldIsNull(FieldSoraVideoPricePerRequestHd)) +} + +// SoraVideoPricePerRequestHdNotNil applies the NotNil predicate on the "sora_video_price_per_request_hd" field. +func SoraVideoPricePerRequestHdNotNil() predicate.Group { + return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 707600a7..4416516b 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -258,6 +258,62 @@ func (_c *GroupCreate) SetNillableImagePrice4k(v *float64) *GroupCreate { return _c } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_c *GroupCreate) SetSoraImagePrice360(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice360(v) + return _c +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice360(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice360(*v) + } + return _c +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_c *GroupCreate) SetSoraImagePrice540(v float64) *GroupCreate { + _c.mutation.SetSoraImagePrice540(v) + return _c +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraImagePrice540(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraImagePrice540(*v) + } + return _c +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (_c *GroupCreate) SetSoraVideoPricePerRequest(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequest(v) + return _c +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequest(*v) + } + return _c +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_c *GroupCreate) SetSoraVideoPricePerRequestHd(v float64) *GroupCreate { + _c.mutation.SetSoraVideoPricePerRequestHd(v) + return _c +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupCreate { + if v != nil { + _c.SetSoraVideoPricePerRequestHd(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -701,6 +757,22 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldImagePrice4k, field.TypeFloat64, value) _node.ImagePrice4k = &value } + if value, ok := _c.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + _node.SoraImagePrice360 = &value + } + if value, ok := _c.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + _node.SoraImagePrice540 = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + _node.SoraVideoPricePerRequest = &value + } + if value, ok := _c.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + _node.SoraVideoPricePerRequestHd = &value + } if 
value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1177,6 +1249,102 @@ func (u *GroupUpsert) ClearImagePrice4k() *GroupUpsert { return u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsert) SetSoraImagePrice360(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice360, v) + return u +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice360() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice360) + return u +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsert) AddSoraImagePrice360(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice360, v) + return u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsert) ClearSoraImagePrice360() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice360) + return u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsert) SetSoraImagePrice540(v float64) *GroupUpsert { + u.Set(group.FieldSoraImagePrice540, v) + return u +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraImagePrice540() *GroupUpsert { + u.SetExcluded(group.FieldSoraImagePrice540) + return u +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsert) AddSoraImagePrice540(v float64) *GroupUpsert { + u.Add(group.FieldSoraImagePrice540, v) + return u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsert) ClearSoraImagePrice540() *GroupUpsert { + u.SetNull(group.FieldSoraImagePrice540) + return u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (u *GroupUpsert) SetSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequest() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequest) + return u +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequest(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequest, v) + return u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequest() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequest) + return u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Set(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetExcluded(group.FieldSoraVideoPricePerRequestHd) + return u +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsert { + u.Add(group.FieldSoraVideoPricePerRequestHd, v) + return u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { + u.SetNull(group.FieldSoraVideoPricePerRequestHd) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1690,6 +1858,118 @@ func (u *GroupUpsertOne) ClearImagePrice4k() *GroupUpsertOne { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertOne) SetSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertOne) AddSoraImagePrice360(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertOne) ClearSoraImagePrice360() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertOne) SetSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. +func (u *GroupUpsertOne) AddSoraImagePrice540(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. 
+func (u *GroupUpsertOne) ClearSoraImagePrice540() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequest(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequest() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. 
+func (u *GroupUpsertOne) UpdateSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2391,6 +2671,118 @@ func (u *GroupUpsertBulk) ClearImagePrice4k() *GroupUpsertBulk { }) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (u *GroupUpsertBulk) SetSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice360(v) + }) +} + +// AddSoraImagePrice360 adds v to the "sora_image_price_360" field. +func (u *GroupUpsertBulk) AddSoraImagePrice360(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice360(v) + }) +} + +// UpdateSoraImagePrice360 sets the "sora_image_price_360" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice360() + }) +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice360() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice360() + }) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (u *GroupUpsertBulk) SetSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraImagePrice540(v) + }) +} + +// AddSoraImagePrice540 adds v to the "sora_image_price_540" field. 
+func (u *GroupUpsertBulk) AddSoraImagePrice540(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraImagePrice540(v) + }) +} + +// UpdateSoraImagePrice540 sets the "sora_image_price_540" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraImagePrice540() + }) +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (u *GroupUpsertBulk) ClearSoraImagePrice540() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraImagePrice540() + }) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) SetSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequest(v) + }) +} + +// AddSoraVideoPricePerRequest adds v to the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequest(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequest(v) + }) +} + +// UpdateSoraVideoPricePerRequest sets the "sora_video_price_per_request" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequest() + }) +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequest() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequest() + }) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. 
+func (u *GroupUpsertBulk) SetSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraVideoPricePerRequestHd(v) + }) +} + +// AddSoraVideoPricePerRequestHd adds v to the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) AddSoraVideoPricePerRequestHd(v float64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraVideoPricePerRequestHd(v) + }) +} + +// UpdateSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraVideoPricePerRequestHd() + }) +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.ClearSoraVideoPricePerRequestHd() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 393fd304..db510e05 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -355,6 +355,114 @@ func (_u *GroupUpdate) ClearImagePrice4k() *GroupUpdate { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdate) SetSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice360(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. 
+func (_u *GroupUpdate) AddSoraImagePrice360(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdate) ClearSoraImagePrice360() *GroupUpdate { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdate) SetSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraImagePrice540(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdate) AddSoraImagePrice540(v float64) *GroupUpdate { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (_u *GroupUpdate) ClearSoraImagePrice540() *GroupUpdate { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. 
+func (_u *GroupUpdate) AddSoraVideoPricePerRequest(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequest() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdate { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdate { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. 
func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -892,6 +1000,42 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok 
{ _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1573,6 +1717,114 @@ func (_u *GroupUpdateOne) ClearImagePrice4k() *GroupUpdateOne { return _u } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. +func (_u *GroupUpdateOne) SetSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice360() + _u.mutation.SetSoraImagePrice360(v) + return _u +} + +// SetNillableSoraImagePrice360 sets the "sora_image_price_360" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice360(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice360(*v) + } + return _u +} + +// AddSoraImagePrice360 adds value to the "sora_image_price_360" field. +func (_u *GroupUpdateOne) AddSoraImagePrice360(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice360(v) + return _u +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. +func (_u *GroupUpdateOne) ClearSoraImagePrice360() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice360() + return _u +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (_u *GroupUpdateOne) SetSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraImagePrice540() + _u.mutation.SetSoraImagePrice540(v) + return _u +} + +// SetNillableSoraImagePrice540 sets the "sora_image_price_540" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraImagePrice540(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraImagePrice540(*v) + } + return _u +} + +// AddSoraImagePrice540 adds value to the "sora_image_price_540" field. +func (_u *GroupUpdateOne) AddSoraImagePrice540(v float64) *GroupUpdateOne { + _u.mutation.AddSoraImagePrice540(v) + return _u +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. 
+func (_u *GroupUpdateOne) ClearSoraImagePrice540() *GroupUpdateOne { + _u.mutation.ClearSoraImagePrice540() + return _u +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequest() + _u.mutation.SetSoraVideoPricePerRequest(v) + return _u +} + +// SetNillableSoraVideoPricePerRequest sets the "sora_video_price_per_request" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequest(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequest(*v) + } + return _u +} + +// AddSoraVideoPricePerRequest adds value to the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) AddSoraVideoPricePerRequest(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequest(v) + return _u +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequest() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequest() + return _u +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) SetSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.ResetSoraVideoPricePerRequestHd() + _u.mutation.SetSoraVideoPricePerRequestHd(v) + return _u +} + +// SetNillableSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupUpdateOne { + if v != nil { + _u.SetSoraVideoPricePerRequestHd(*v) + } + return _u +} + +// AddSoraVideoPricePerRequestHd adds value to the "sora_video_price_per_request_hd" field. 
+func (_u *GroupUpdateOne) AddSoraVideoPricePerRequestHd(v float64) *GroupUpdateOne { + _u.mutation.AddSoraVideoPricePerRequestHd(v) + return _u +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { + _u.mutation.ClearSoraVideoPricePerRequestHd() + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2140,6 +2392,42 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.ImagePrice4kCleared() { _spec.ClearField(group.FieldImagePrice4k, field.TypeFloat64) } + if value, ok := _u.mutation.SoraImagePrice360(); ok { + _spec.SetField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice360(); ok { + _spec.AddField(group.FieldSoraImagePrice360, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice360Cleared() { + _spec.ClearField(group.FieldSoraImagePrice360, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraImagePrice540(); ok { + _spec.SetField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraImagePrice540(); ok { + _spec.AddField(group.FieldSoraImagePrice540, field.TypeFloat64, value) + } + if _u.mutation.SoraImagePrice540Cleared() { + _spec.ClearField(group.FieldSoraImagePrice540, field.TypeFloat64) + } + if value, ok := _u.mutation.SoraVideoPricePerRequest(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequest(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequest, field.TypeFloat64) + } + if value, ok := 
_u.mutation.SoraVideoPricePerRequestHd(); ok { + _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if value, ok := _u.mutation.AddedSoraVideoPricePerRequestHd(); ok { + _spec.AddField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) + } + if _u.mutation.SoraVideoPricePerRequestHdCleared() { + _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1b15685c..aff9caa0 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -141,6 +141,18 @@ func (f RedeemCodeFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.RedeemCodeMutation", m) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary +// function as SecuritySecret mutator. +type SecuritySecretFunc func(context.Context, *ent.SecuritySecretMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f SecuritySecretFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.SecuritySecretMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SecuritySecretMutation", m) +} + // The SettingFunc type is an adapter to allow the use of ordinary // function as Setting mutator. 
type SettingFunc func(context.Context, *ent.SettingMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 8ee42db3..290fb163 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -383,6 +384,33 @@ func (f TraverseRedeemCode) Traverse(ctx context.Context, q ent.Query) error { return fmt.Errorf("unexpected query type %T. expect *ent.RedeemCodeQuery", q) } +// The SecuritySecretFunc type is an adapter to allow the use of ordinary function as a Querier. +type SecuritySecretFunc func(context.Context, *ent.SecuritySecretQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f SecuritySecretFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.SecuritySecretQuery", q) +} + +// The TraverseSecuritySecret type is an adapter to allow the use of ordinary function as Traverser. +type TraverseSecuritySecret func(context.Context, *ent.SecuritySecretQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseSecuritySecret) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseSecuritySecret) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.SecuritySecretQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. 
expect *ent.SecuritySecretQuery", q) +} + // The SettingFunc type is an adapter to allow the use of ordinary function as a Querier. type SettingFunc func(context.Context, *ent.SettingQuery) (ent.Value, error) @@ -624,6 +652,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ProxyQuery, predicate.Proxy, proxy.OrderOption]{typ: ent.TypeProxy, tq: q}, nil case *ent.RedeemCodeQuery: return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil + case *ent.SecuritySecretQuery: + return &query[*ent.SecuritySecretQuery, predicate.SecuritySecret, securitysecret.OrderOption]{typ: ent.TypeSecuritySecret, tq: q}, nil case *ent.SettingQuery: return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil case *ent.UsageCleanupTaskQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 07f2a68e..aba00d4f 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -18,6 +18,7 @@ var ( {Name: "key", Type: field.TypeString, Unique: true, Size: 128}, {Name: "name", Type: field.TypeString, Size: 100}, {Name: "status", Type: field.TypeString, Size: 20, Default: "active"}, + {Name: "last_used_at", Type: field.TypeTime, Nullable: true}, {Name: "ip_whitelist", Type: field.TypeJSON, Nullable: true}, {Name: "ip_blacklist", Type: field.TypeJSON, Nullable: true}, {Name: "quota", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, @@ -34,13 +35,13 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "api_keys_groups_api_keys", - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[13]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "api_keys_users_api_keys", - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[14]}, RefColumns: []*schema.Column{UsersColumns[0]}, 
OnDelete: schema.NoAction, }, @@ -49,12 +50,12 @@ var ( { Name: "apikey_user_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[13]}, + Columns: []*schema.Column{APIKeysColumns[14]}, }, { Name: "apikey_group_id", Unique: false, - Columns: []*schema.Column{APIKeysColumns[12]}, + Columns: []*schema.Column{APIKeysColumns[13]}, }, { Name: "apikey_status", @@ -66,15 +67,20 @@ var ( Unique: false, Columns: []*schema.Column{APIKeysColumns[3]}, }, + { + Name: "apikey_last_used_at", + Unique: false, + Columns: []*schema.Column{APIKeysColumns[7]}, + }, { Name: "apikey_quota_quota_used", Unique: false, - Columns: []*schema.Column{APIKeysColumns[9], APIKeysColumns[10]}, + Columns: []*schema.Column{APIKeysColumns[10], APIKeysColumns[11]}, }, { Name: "apikey_expires_at", Unique: false, - Columns: []*schema.Column{APIKeysColumns[11]}, + Columns: []*schema.Column{APIKeysColumns[12]}, }, }, } @@ -366,6 +372,10 @@ var ( {Name: "image_price_1k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_2k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "image_price_4k", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_360", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: 
true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -409,7 +419,7 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[25]}, + Columns: []*schema.Column{GroupsColumns[29]}, }, }, } @@ -572,6 +582,20 @@ var ( }, }, } + // SecuritySecretsColumns holds the columns for the "security_secrets" table. + SecuritySecretsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "key", Type: field.TypeString, Unique: true, Size: 100}, + {Name: "value", Type: field.TypeString, SchemaType: map[string]string{"postgres": "text"}}, + } + // SecuritySecretsTable holds the schema information for the "security_secrets" table. + SecuritySecretsTable = &schema.Table{ + Name: "security_secrets", + Columns: SecuritySecretsColumns, + PrimaryKey: []*schema.Column{SecuritySecretsColumns[0]}, + } // SettingsColumns holds the columns for the "settings" table. 
SettingsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -650,6 +674,7 @@ var ( {Name: "ip_address", Type: field.TypeString, Nullable: true, Size: 45}, {Name: "image_count", Type: field.TypeInt, Default: 0}, {Name: "image_size", Type: field.TypeString, Nullable: true, Size: 10}, + {Name: "media_type", Type: field.TypeString, Nullable: true, Size: 16}, {Name: "cache_ttl_overridden", Type: field.TypeBool, Default: false}, {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "api_key_id", Type: field.TypeInt64}, @@ -666,31 +691,31 @@ var ( ForeignKeys: []*schema.ForeignKey{ { Symbol: "usage_logs_api_keys_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, RefColumns: []*schema.Column{APIKeysColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_accounts_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, RefColumns: []*schema.Column{AccountsColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_groups_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, RefColumns: []*schema.Column{GroupsColumns[0]}, OnDelete: schema.SetNull, }, { Symbol: "usage_logs_users_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, RefColumns: []*schema.Column{UsersColumns[0]}, OnDelete: schema.NoAction, }, { Symbol: "usage_logs_user_subscriptions_usage_logs", - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, RefColumns: []*schema.Column{UserSubscriptionsColumns[0]}, OnDelete: schema.SetNull, }, @@ -699,32 +724,32 @@ var ( { Name: "usagelog_user_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30]}, + Columns: []*schema.Column{UsageLogsColumns[31]}, }, { Name: 
"usagelog_api_key_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27]}, + Columns: []*schema.Column{UsageLogsColumns[28]}, }, { Name: "usagelog_account_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[28]}, + Columns: []*schema.Column{UsageLogsColumns[29]}, }, { Name: "usagelog_group_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[29]}, + Columns: []*schema.Column{UsageLogsColumns[30]}, }, { Name: "usagelog_subscription_id", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[31]}, + Columns: []*schema.Column{UsageLogsColumns[32]}, }, { Name: "usagelog_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[27]}, }, { Name: "usagelog_model", @@ -739,12 +764,12 @@ var ( { Name: "usagelog_user_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[31], UsageLogsColumns[27]}, }, { Name: "usagelog_api_key_id_created_at", Unique: false, - Columns: []*schema.Column{UsageLogsColumns[27], UsageLogsColumns[26]}, + Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, }, }, } @@ -1000,6 +1025,7 @@ var ( PromoCodeUsagesTable, ProxiesTable, RedeemCodesTable, + SecuritySecretsTable, SettingsTable, UsageCleanupTasksTable, UsageLogsTable, @@ -1056,6 +1082,9 @@ func init() { RedeemCodesTable.Annotation = &entsql.Annotation{ Table: "redeem_codes", } + SecuritySecretsTable.Annotation = &entsql.Annotation{ + Table: "security_secrets", + } SettingsTable.Annotation = &entsql.Annotation{ Table: "settings", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 34b3268e..7d5bf180 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -24,6 +24,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" + 
"github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -55,6 +56,7 @@ const ( TypePromoCodeUsage = "PromoCodeUsage" TypeProxy = "Proxy" TypeRedeemCode = "RedeemCode" + TypeSecuritySecret = "SecuritySecret" TypeSetting = "Setting" TypeUsageCleanupTask = "UsageCleanupTask" TypeUsageLog = "UsageLog" @@ -77,6 +79,7 @@ type APIKeyMutation struct { key *string name *string status *string + last_used_at *time.Time ip_whitelist *[]string appendip_whitelist []string ip_blacklist *[]string @@ -511,6 +514,55 @@ func (m *APIKeyMutation) ResetStatus() { m.status = nil } +// SetLastUsedAt sets the "last_used_at" field. +func (m *APIKeyMutation) SetLastUsedAt(t time.Time) { + m.last_used_at = &t +} + +// LastUsedAt returns the value of the "last_used_at" field in the mutation. +func (m *APIKeyMutation) LastUsedAt() (r time.Time, exists bool) { + v := m.last_used_at + if v == nil { + return + } + return *v, true +} + +// OldLastUsedAt returns the old "last_used_at" field's value of the APIKey entity. +// If the APIKey object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *APIKeyMutation) OldLastUsedAt(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLastUsedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLastUsedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLastUsedAt: %w", err) + } + return oldValue.LastUsedAt, nil +} + +// ClearLastUsedAt clears the value of the "last_used_at" field. 
+func (m *APIKeyMutation) ClearLastUsedAt() { + m.last_used_at = nil + m.clearedFields[apikey.FieldLastUsedAt] = struct{}{} +} + +// LastUsedAtCleared returns if the "last_used_at" field was cleared in this mutation. +func (m *APIKeyMutation) LastUsedAtCleared() bool { + _, ok := m.clearedFields[apikey.FieldLastUsedAt] + return ok +} + +// ResetLastUsedAt resets all changes to the "last_used_at" field. +func (m *APIKeyMutation) ResetLastUsedAt() { + m.last_used_at = nil + delete(m.clearedFields, apikey.FieldLastUsedAt) +} + // SetIPWhitelist sets the "ip_whitelist" field. func (m *APIKeyMutation) SetIPWhitelist(s []string) { m.ip_whitelist = &s @@ -944,7 +996,7 @@ func (m *APIKeyMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *APIKeyMutation) Fields() []string { - fields := make([]string, 0, 13) + fields := make([]string, 0, 14) if m.created_at != nil { fields = append(fields, apikey.FieldCreatedAt) } @@ -969,6 +1021,9 @@ func (m *APIKeyMutation) Fields() []string { if m.status != nil { fields = append(fields, apikey.FieldStatus) } + if m.last_used_at != nil { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.ip_whitelist != nil { fields = append(fields, apikey.FieldIPWhitelist) } @@ -1008,6 +1063,8 @@ func (m *APIKeyMutation) Field(name string) (ent.Value, bool) { return m.GroupID() case apikey.FieldStatus: return m.Status() + case apikey.FieldLastUsedAt: + return m.LastUsedAt() case apikey.FieldIPWhitelist: return m.IPWhitelist() case apikey.FieldIPBlacklist: @@ -1043,6 +1100,8 @@ func (m *APIKeyMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldGroupID(ctx) case apikey.FieldStatus: return m.OldStatus(ctx) + case apikey.FieldLastUsedAt: + return m.OldLastUsedAt(ctx) case apikey.FieldIPWhitelist: return m.OldIPWhitelist(ctx) case apikey.FieldIPBlacklist: @@ -1118,6 +1177,13 @@ func (m *APIKeyMutation) SetField(name string, value ent.Value) error { } 
m.SetStatus(v) return nil + case apikey.FieldLastUsedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLastUsedAt(v) + return nil case apikey.FieldIPWhitelist: v, ok := value.([]string) if !ok { @@ -1216,6 +1282,9 @@ func (m *APIKeyMutation) ClearedFields() []string { if m.FieldCleared(apikey.FieldGroupID) { fields = append(fields, apikey.FieldGroupID) } + if m.FieldCleared(apikey.FieldLastUsedAt) { + fields = append(fields, apikey.FieldLastUsedAt) + } if m.FieldCleared(apikey.FieldIPWhitelist) { fields = append(fields, apikey.FieldIPWhitelist) } @@ -1245,6 +1314,9 @@ func (m *APIKeyMutation) ClearField(name string) error { case apikey.FieldGroupID: m.ClearGroupID() return nil + case apikey.FieldLastUsedAt: + m.ClearLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ClearIPWhitelist() return nil @@ -1286,6 +1358,9 @@ func (m *APIKeyMutation) ResetField(name string) error { case apikey.FieldStatus: m.ResetStatus() return nil + case apikey.FieldLastUsedAt: + m.ResetLastUsedAt() + return nil case apikey.FieldIPWhitelist: m.ResetIPWhitelist() return nil @@ -7103,6 +7178,14 @@ type GroupMutation struct { addimage_price_2k *float64 image_price_4k *float64 addimage_price_4k *float64 + sora_image_price_360 *float64 + addsora_image_price_360 *float64 + sora_image_price_540 *float64 + addsora_image_price_540 *float64 + sora_video_price_per_request *float64 + addsora_video_price_per_request *float64 + sora_video_price_per_request_hd *float64 + addsora_video_price_per_request_hd *float64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -8119,6 +8202,286 @@ func (m *GroupMutation) ResetImagePrice4k() { delete(m.clearedFields, group.FieldImagePrice4k) } +// SetSoraImagePrice360 sets the "sora_image_price_360" field. 
+func (m *GroupMutation) SetSoraImagePrice360(f float64) { + m.sora_image_price_360 = &f + m.addsora_image_price_360 = nil +} + +// SoraImagePrice360 returns the value of the "sora_image_price_360" field in the mutation. +func (m *GroupMutation) SoraImagePrice360() (r float64, exists bool) { + v := m.sora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice360 returns the old "sora_image_price_360" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraImagePrice360(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice360 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice360 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice360: %w", err) + } + return oldValue.SoraImagePrice360, nil +} + +// AddSoraImagePrice360 adds f to the "sora_image_price_360" field. +func (m *GroupMutation) AddSoraImagePrice360(f float64) { + if m.addsora_image_price_360 != nil { + *m.addsora_image_price_360 += f + } else { + m.addsora_image_price_360 = &f + } +} + +// AddedSoraImagePrice360 returns the value that was added to the "sora_image_price_360" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice360() (r float64, exists bool) { + v := m.addsora_image_price_360 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice360 clears the value of the "sora_image_price_360" field. 
+func (m *GroupMutation) ClearSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + m.clearedFields[group.FieldSoraImagePrice360] = struct{}{} +} + +// SoraImagePrice360Cleared returns if the "sora_image_price_360" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice360Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice360] + return ok +} + +// ResetSoraImagePrice360 resets all changes to the "sora_image_price_360" field. +func (m *GroupMutation) ResetSoraImagePrice360() { + m.sora_image_price_360 = nil + m.addsora_image_price_360 = nil + delete(m.clearedFields, group.FieldSoraImagePrice360) +} + +// SetSoraImagePrice540 sets the "sora_image_price_540" field. +func (m *GroupMutation) SetSoraImagePrice540(f float64) { + m.sora_image_price_540 = &f + m.addsora_image_price_540 = nil +} + +// SoraImagePrice540 returns the value of the "sora_image_price_540" field in the mutation. +func (m *GroupMutation) SoraImagePrice540() (r float64, exists bool) { + v := m.sora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// OldSoraImagePrice540 returns the old "sora_image_price_540" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldSoraImagePrice540(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraImagePrice540 is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraImagePrice540 requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraImagePrice540: %w", err) + } + return oldValue.SoraImagePrice540, nil +} + +// AddSoraImagePrice540 adds f to the "sora_image_price_540" field. +func (m *GroupMutation) AddSoraImagePrice540(f float64) { + if m.addsora_image_price_540 != nil { + *m.addsora_image_price_540 += f + } else { + m.addsora_image_price_540 = &f + } +} + +// AddedSoraImagePrice540 returns the value that was added to the "sora_image_price_540" field in this mutation. +func (m *GroupMutation) AddedSoraImagePrice540() (r float64, exists bool) { + v := m.addsora_image_price_540 + if v == nil { + return + } + return *v, true +} + +// ClearSoraImagePrice540 clears the value of the "sora_image_price_540" field. +func (m *GroupMutation) ClearSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + m.clearedFields[group.FieldSoraImagePrice540] = struct{}{} +} + +// SoraImagePrice540Cleared returns if the "sora_image_price_540" field was cleared in this mutation. +func (m *GroupMutation) SoraImagePrice540Cleared() bool { + _, ok := m.clearedFields[group.FieldSoraImagePrice540] + return ok +} + +// ResetSoraImagePrice540 resets all changes to the "sora_image_price_540" field. +func (m *GroupMutation) ResetSoraImagePrice540() { + m.sora_image_price_540 = nil + m.addsora_image_price_540 = nil + delete(m.clearedFields, group.FieldSoraImagePrice540) +} + +// SetSoraVideoPricePerRequest sets the "sora_video_price_per_request" field. 
+func (m *GroupMutation) SetSoraVideoPricePerRequest(f float64) { + m.sora_video_price_per_request = &f + m.addsora_video_price_per_request = nil +} + +// SoraVideoPricePerRequest returns the value of the "sora_video_price_per_request" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequest() (r float64, exists bool) { + v := m.sora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequest returns the old "sora_video_price_per_request" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraVideoPricePerRequest(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequest is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequest requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequest: %w", err) + } + return oldValue.SoraVideoPricePerRequest, nil +} + +// AddSoraVideoPricePerRequest adds f to the "sora_video_price_per_request" field. +func (m *GroupMutation) AddSoraVideoPricePerRequest(f float64) { + if m.addsora_video_price_per_request != nil { + *m.addsora_video_price_per_request += f + } else { + m.addsora_video_price_per_request = &f + } +} + +// AddedSoraVideoPricePerRequest returns the value that was added to the "sora_video_price_per_request" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequest() (r float64, exists bool) { + v := m.addsora_video_price_per_request + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequest clears the value of the "sora_video_price_per_request" field. 
+func (m *GroupMutation) ClearSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + m.clearedFields[group.FieldSoraVideoPricePerRequest] = struct{}{} +} + +// SoraVideoPricePerRequestCleared returns if the "sora_video_price_per_request" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequest] + return ok +} + +// ResetSoraVideoPricePerRequest resets all changes to the "sora_video_price_per_request" field. +func (m *GroupMutation) ResetSoraVideoPricePerRequest() { + m.sora_video_price_per_request = nil + m.addsora_video_price_per_request = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequest) +} + +// SetSoraVideoPricePerRequestHd sets the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) SetSoraVideoPricePerRequestHd(f float64) { + m.sora_video_price_per_request_hd = &f + m.addsora_video_price_per_request_hd = nil +} + +// SoraVideoPricePerRequestHd returns the value of the "sora_video_price_per_request_hd" field in the mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.sora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// OldSoraVideoPricePerRequestHd returns the old "sora_video_price_per_request_hd" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *GroupMutation) OldSoraVideoPricePerRequestHd(ctx context.Context) (v *float64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraVideoPricePerRequestHd is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraVideoPricePerRequestHd requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraVideoPricePerRequestHd: %w", err) + } + return oldValue.SoraVideoPricePerRequestHd, nil +} + +// AddSoraVideoPricePerRequestHd adds f to the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) AddSoraVideoPricePerRequestHd(f float64) { + if m.addsora_video_price_per_request_hd != nil { + *m.addsora_video_price_per_request_hd += f + } else { + m.addsora_video_price_per_request_hd = &f + } +} + +// AddedSoraVideoPricePerRequestHd returns the value that was added to the "sora_video_price_per_request_hd" field in this mutation. +func (m *GroupMutation) AddedSoraVideoPricePerRequestHd() (r float64, exists bool) { + v := m.addsora_video_price_per_request_hd + if v == nil { + return + } + return *v, true +} + +// ClearSoraVideoPricePerRequestHd clears the value of the "sora_video_price_per_request_hd" field. +func (m *GroupMutation) ClearSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + m.clearedFields[group.FieldSoraVideoPricePerRequestHd] = struct{}{} +} + +// SoraVideoPricePerRequestHdCleared returns if the "sora_video_price_per_request_hd" field was cleared in this mutation. +func (m *GroupMutation) SoraVideoPricePerRequestHdCleared() bool { + _, ok := m.clearedFields[group.FieldSoraVideoPricePerRequestHd] + return ok +} + +// ResetSoraVideoPricePerRequestHd resets all changes to the "sora_video_price_per_request_hd" field. 
+func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { + m.sora_video_price_per_request_hd = nil + m.addsora_video_price_per_request_hd = nil + delete(m.clearedFields, group.FieldSoraVideoPricePerRequestHd) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -8881,7 +9244,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 29) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -8933,6 +9296,18 @@ func (m *GroupMutation) Fields() []string { if m.image_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.sora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.sora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.sora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.sora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -8999,6 +9374,14 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.ImagePrice2k() case group.FieldImagePrice4k: return m.ImagePrice4k() + case group.FieldSoraImagePrice360: + return m.SoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.SoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.SoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.SoraVideoPricePerRequestHd() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -9058,6 +9441,14 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) 
(ent.Value, e return m.OldImagePrice2k(ctx) case group.FieldImagePrice4k: return m.OldImagePrice4k(ctx) + case group.FieldSoraImagePrice360: + return m.OldSoraImagePrice360(ctx) + case group.FieldSoraImagePrice540: + return m.OldSoraImagePrice540(ctx) + case group.FieldSoraVideoPricePerRequest: + return m.OldSoraVideoPricePerRequest(ctx) + case group.FieldSoraVideoPricePerRequestHd: + return m.OldSoraVideoPricePerRequestHd(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case group.FieldFallbackGroupID: @@ -9202,6 +9593,34 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraVideoPricePerRequestHd(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -9290,6 +9709,18 @@ func (m *GroupMutation) AddedFields() []string { if m.addimage_price_4k != nil { fields = append(fields, group.FieldImagePrice4k) } + if m.addsora_image_price_360 != nil { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.addsora_image_price_540 != nil { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.addsora_video_price_per_request != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if 
m.addsora_video_price_per_request_hd != nil { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -9323,6 +9754,14 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedImagePrice2k() case group.FieldImagePrice4k: return m.AddedImagePrice4k() + case group.FieldSoraImagePrice360: + return m.AddedSoraImagePrice360() + case group.FieldSoraImagePrice540: + return m.AddedSoraImagePrice540() + case group.FieldSoraVideoPricePerRequest: + return m.AddedSoraVideoPricePerRequest() + case group.FieldSoraVideoPricePerRequestHd: + return m.AddedSoraVideoPricePerRequestHd() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -9394,6 +9833,34 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddImagePrice4k(v) return nil + case group.FieldSoraImagePrice360: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice360(v) + return nil + case group.FieldSoraImagePrice540: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraImagePrice540(v) + return nil + case group.FieldSoraVideoPricePerRequest: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequest(v) + return nil + case group.FieldSoraVideoPricePerRequestHd: + v, ok := value.(float64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraVideoPricePerRequestHd(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -9447,6 +9914,18 @@ func (m *GroupMutation) ClearedFields() []string { if m.FieldCleared(group.FieldImagePrice4k) { fields = append(fields, group.FieldImagePrice4k) } + if 
m.FieldCleared(group.FieldSoraImagePrice360) { + fields = append(fields, group.FieldSoraImagePrice360) + } + if m.FieldCleared(group.FieldSoraImagePrice540) { + fields = append(fields, group.FieldSoraImagePrice540) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequest) { + fields = append(fields, group.FieldSoraVideoPricePerRequest) + } + if m.FieldCleared(group.FieldSoraVideoPricePerRequestHd) { + fields = append(fields, group.FieldSoraVideoPricePerRequestHd) + } if m.FieldCleared(group.FieldFallbackGroupID) { fields = append(fields, group.FieldFallbackGroupID) } @@ -9494,6 +9973,18 @@ func (m *GroupMutation) ClearField(name string) error { case group.FieldImagePrice4k: m.ClearImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ClearSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ClearSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ClearSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ClearSoraVideoPricePerRequestHd() + return nil case group.FieldFallbackGroupID: m.ClearFallbackGroupID() return nil @@ -9562,6 +10053,18 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldImagePrice4k: m.ResetImagePrice4k() return nil + case group.FieldSoraImagePrice360: + m.ResetSoraImagePrice360() + return nil + case group.FieldSoraImagePrice540: + m.ResetSoraImagePrice540() + return nil + case group.FieldSoraVideoPricePerRequest: + m.ResetSoraVideoPricePerRequest() + return nil + case group.FieldSoraVideoPricePerRequestHd: + m.ResetSoraVideoPricePerRequestHd() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -13496,6 +13999,494 @@ func (m *RedeemCodeMutation) ResetEdge(name string) error { return fmt.Errorf("unknown RedeemCode edge %s", name) } +// SecuritySecretMutation represents an operation that mutates the SecuritySecret nodes in the graph. 
+type SecuritySecretMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + key *string + value *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*SecuritySecret, error) + predicates []predicate.SecuritySecret +} + +var _ ent.Mutation = (*SecuritySecretMutation)(nil) + +// securitysecretOption allows management of the mutation configuration using functional options. +type securitysecretOption func(*SecuritySecretMutation) + +// newSecuritySecretMutation creates new mutation for the SecuritySecret entity. +func newSecuritySecretMutation(c config, op Op, opts ...securitysecretOption) *SecuritySecretMutation { + m := &SecuritySecretMutation{ + config: c, + op: op, + typ: TypeSecuritySecret, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withSecuritySecretID sets the ID field of the mutation. +func withSecuritySecretID(id int64) securitysecretOption { + return func(m *SecuritySecretMutation) { + var ( + err error + once sync.Once + value *SecuritySecret + ) + m.oldValue = func(ctx context.Context) (*SecuritySecret, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().SecuritySecret.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withSecuritySecret sets the old SecuritySecret of the mutation. +func withSecuritySecret(node *SecuritySecret) securitysecretOption { + return func(m *SecuritySecretMutation) { + m.oldValue = func(context.Context) (*SecuritySecret, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. 
+func (m SecuritySecretMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m SecuritySecretMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *SecuritySecretMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *SecuritySecretMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().SecuritySecret.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *SecuritySecretMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *SecuritySecretMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the SecuritySecret entity. 
+// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *SecuritySecretMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *SecuritySecretMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *SecuritySecretMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *SecuritySecretMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *SecuritySecretMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetKey sets the "key" field. +func (m *SecuritySecretMutation) SetKey(s string) { + m.key = &s +} + +// Key returns the value of the "key" field in the mutation. +func (m *SecuritySecretMutation) Key() (r string, exists bool) { + v := m.key + if v == nil { + return + } + return *v, true +} + +// OldKey returns the old "key" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldKey(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKey is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKey requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKey: %w", err) + } + return oldValue.Key, nil +} + +// ResetKey resets all changes to the "key" field. +func (m *SecuritySecretMutation) ResetKey() { + m.key = nil +} + +// SetValue sets the "value" field. +func (m *SecuritySecretMutation) SetValue(s string) { + m.value = &s +} + +// Value returns the value of the "value" field in the mutation. 
+func (m *SecuritySecretMutation) Value() (r string, exists bool) { + v := m.value + if v == nil { + return + } + return *v, true +} + +// OldValue returns the old "value" field's value of the SecuritySecret entity. +// If the SecuritySecret object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *SecuritySecretMutation) OldValue(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldValue is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldValue requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldValue: %w", err) + } + return oldValue.Value, nil +} + +// ResetValue resets all changes to the "value" field. +func (m *SecuritySecretMutation) ResetValue() { + m.value = nil +} + +// Where appends a list predicates to the SecuritySecretMutation builder. +func (m *SecuritySecretMutation) Where(ps ...predicate.SecuritySecret) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the SecuritySecretMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *SecuritySecretMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.SecuritySecret, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *SecuritySecretMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *SecuritySecretMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (SecuritySecret). 
+func (m *SecuritySecretMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *SecuritySecretMutation) Fields() []string { + fields := make([]string, 0, 4) + if m.created_at != nil { + fields = append(fields, securitysecret.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, securitysecret.FieldUpdatedAt) + } + if m.key != nil { + fields = append(fields, securitysecret.FieldKey) + } + if m.value != nil { + fields = append(fields, securitysecret.FieldValue) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *SecuritySecretMutation) Field(name string) (ent.Value, bool) { + switch name { + case securitysecret.FieldCreatedAt: + return m.CreatedAt() + case securitysecret.FieldUpdatedAt: + return m.UpdatedAt() + case securitysecret.FieldKey: + return m.Key() + case securitysecret.FieldValue: + return m.Value() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *SecuritySecretMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case securitysecret.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case securitysecret.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case securitysecret.FieldKey: + return m.OldKey(ctx) + case securitysecret.FieldValue: + return m.OldValue(ctx) + } + return nil, fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *SecuritySecretMutation) SetField(name string, value ent.Value) error { + switch name { + case securitysecret.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case securitysecret.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case securitysecret.FieldKey: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKey(v) + return nil + case securitysecret.FieldValue: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetValue(v) + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *SecuritySecretMutation) AddedFields() []string { + return nil +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *SecuritySecretMutation) AddedField(name string) (ent.Value, bool) { + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *SecuritySecretMutation) AddField(name string, value ent.Value) error { + switch name { + } + return fmt.Errorf("unknown SecuritySecret numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *SecuritySecretMutation) ClearedFields() []string { + return nil +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *SecuritySecretMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ClearField(name string) error { + return fmt.Errorf("unknown SecuritySecret nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *SecuritySecretMutation) ResetField(name string) error { + switch name { + case securitysecret.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case securitysecret.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case securitysecret.FieldKey: + m.ResetKey() + return nil + case securitysecret.FieldValue: + m.ResetValue() + return nil + } + return fmt.Errorf("unknown SecuritySecret field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *SecuritySecretMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *SecuritySecretMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *SecuritySecretMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. 
+func (m *SecuritySecretMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *SecuritySecretMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *SecuritySecretMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *SecuritySecretMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *SecuritySecretMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown SecuritySecret edge %s", name) +} + // SettingMutation represents an operation that mutates the Setting nodes in the graph. type SettingMutation struct { config @@ -15061,6 +16052,7 @@ type UsageLogMutation struct { image_count *int addimage_count *int image_size *string + media_type *string cache_ttl_overridden *bool created_at *time.Time clearedFields map[string]struct{} @@ -16688,6 +17680,55 @@ func (m *UsageLogMutation) ResetImageSize() { delete(m.clearedFields, usagelog.FieldImageSize) } +// SetMediaType sets the "media_type" field. +func (m *UsageLogMutation) SetMediaType(s string) { + m.media_type = &s +} + +// MediaType returns the value of the "media_type" field in the mutation. +func (m *UsageLogMutation) MediaType() (r string, exists bool) { + v := m.media_type + if v == nil { + return + } + return *v, true +} + +// OldMediaType returns the old "media_type" field's value of the UsageLog entity. 
+// If the UsageLog object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UsageLogMutation) OldMediaType(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMediaType is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMediaType requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMediaType: %w", err) + } + return oldValue.MediaType, nil +} + +// ClearMediaType clears the value of the "media_type" field. +func (m *UsageLogMutation) ClearMediaType() { + m.media_type = nil + m.clearedFields[usagelog.FieldMediaType] = struct{}{} +} + +// MediaTypeCleared returns if the "media_type" field was cleared in this mutation. +func (m *UsageLogMutation) MediaTypeCleared() bool { + _, ok := m.clearedFields[usagelog.FieldMediaType] + return ok +} + +// ResetMediaType resets all changes to the "media_type" field. +func (m *UsageLogMutation) ResetMediaType() { + m.media_type = nil + delete(m.clearedFields, usagelog.FieldMediaType) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (m *UsageLogMutation) SetCacheTTLOverridden(b bool) { m.cache_ttl_overridden = &b @@ -16929,7 +17970,7 @@ func (m *UsageLogMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UsageLogMutation) Fields() []string { - fields := make([]string, 0, 31) + fields := make([]string, 0, 32) if m.user != nil { fields = append(fields, usagelog.FieldUserID) } @@ -17017,6 +18058,9 @@ func (m *UsageLogMutation) Fields() []string { if m.image_size != nil { fields = append(fields, usagelog.FieldImageSize) } + if m.media_type != nil { + fields = append(fields, usagelog.FieldMediaType) + } if m.cache_ttl_overridden != nil { fields = append(fields, usagelog.FieldCacheTTLOverridden) } @@ -17089,6 +18133,8 @@ func (m *UsageLogMutation) Field(name string) (ent.Value, bool) { return m.ImageCount() case usagelog.FieldImageSize: return m.ImageSize() + case usagelog.FieldMediaType: + return m.MediaType() case usagelog.FieldCacheTTLOverridden: return m.CacheTTLOverridden() case usagelog.FieldCreatedAt: @@ -17160,6 +18206,8 @@ func (m *UsageLogMutation) OldField(ctx context.Context, name string) (ent.Value return m.OldImageCount(ctx) case usagelog.FieldImageSize: return m.OldImageSize(ctx) + case usagelog.FieldMediaType: + return m.OldMediaType(ctx) case usagelog.FieldCacheTTLOverridden: return m.OldCacheTTLOverridden(ctx) case usagelog.FieldCreatedAt: @@ -17376,6 +18424,13 @@ func (m *UsageLogMutation) SetField(name string, value ent.Value) error { } m.SetImageSize(v) return nil + case usagelog.FieldMediaType: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMediaType(v) + return nil case usagelog.FieldCacheTTLOverridden: v, ok := value.(bool) if !ok { @@ -17663,6 +18718,9 @@ func (m *UsageLogMutation) ClearedFields() []string { if m.FieldCleared(usagelog.FieldImageSize) { fields = append(fields, usagelog.FieldImageSize) } + if m.FieldCleared(usagelog.FieldMediaType) { + fields = append(fields, usagelog.FieldMediaType) + } return fields } @@ -17701,6 +18759,9 @@ func (m *UsageLogMutation) ClearField(name string) error { case usagelog.FieldImageSize: m.ClearImageSize() return nil + case 
usagelog.FieldMediaType: + m.ClearMediaType() + return nil } return fmt.Errorf("unknown UsageLog nullable field %s", name) } @@ -17796,6 +18857,9 @@ func (m *UsageLogMutation) ResetField(name string) error { case usagelog.FieldImageSize: m.ResetImageSize() return nil + case usagelog.FieldMediaType: + m.ResetMediaType() + return nil case usagelog.FieldCacheTTLOverridden: m.ResetCacheTTLOverridden() return nil diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index c12955ef..584b9606 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -39,6 +39,9 @@ type Proxy func(*sql.Selector) // RedeemCode is the predicate function for redeemcode builders. type RedeemCode func(*sql.Selector) +// SecuritySecret is the predicate function for securitysecret builders. +type SecuritySecret func(*sql.Selector) + // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index d96f9a00..ff3f8f26 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/proxy" "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" "github.com/Wei-Shaw/sub2api/ent/setting" "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" @@ -93,11 +94,11 @@ func init() { // apikey.StatusValidator is a validator for the "status" field. It is called by the builders before save. apikey.StatusValidator = apikeyDescStatus.Validators[0].(func(string) error) // apikeyDescQuota is the schema descriptor for quota field. - apikeyDescQuota := apikeyFields[7].Descriptor() + apikeyDescQuota := apikeyFields[8].Descriptor() // apikey.DefaultQuota holds the default value on creation for the quota field. 
apikey.DefaultQuota = apikeyDescQuota.Default.(float64) // apikeyDescQuotaUsed is the schema descriptor for quota_used field. - apikeyDescQuotaUsed := apikeyFields[8].Descriptor() + apikeyDescQuotaUsed := apikeyFields[9].Descriptor() // apikey.DefaultQuotaUsed holds the default value on creation for the quota_used field. apikey.DefaultQuotaUsed = apikeyDescQuotaUsed.Default.(float64) accountMixin := schema.Account{}.Mixin() @@ -398,23 +399,23 @@ func init() { // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[14].Descriptor() + groupDescClaudeCodeOnly := groupFields[18].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[18].Descriptor() + groupDescModelRoutingEnabled := groupFields[22].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[19].Descriptor() + groupDescMcpXMLInject := groupFields[23].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. 
- groupDescSupportedModelScopes := groupFields[20].Descriptor() + groupDescSupportedModelScopes := groupFields[24].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[21].Descriptor() + groupDescSortOrder := groupFields[25].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) promocodeFields := schema.PromoCode{}.Fields() @@ -602,6 +603,43 @@ func init() { redeemcodeDescValidityDays := redeemcodeFields[9].Descriptor() // redeemcode.DefaultValidityDays holds the default value on creation for the validity_days field. redeemcode.DefaultValidityDays = redeemcodeDescValidityDays.Default.(int) + securitysecretMixin := schema.SecuritySecret{}.Mixin() + securitysecretMixinFields0 := securitysecretMixin[0].Fields() + _ = securitysecretMixinFields0 + securitysecretFields := schema.SecuritySecret{}.Fields() + _ = securitysecretFields + // securitysecretDescCreatedAt is the schema descriptor for created_at field. + securitysecretDescCreatedAt := securitysecretMixinFields0[0].Descriptor() + // securitysecret.DefaultCreatedAt holds the default value on creation for the created_at field. + securitysecret.DefaultCreatedAt = securitysecretDescCreatedAt.Default.(func() time.Time) + // securitysecretDescUpdatedAt is the schema descriptor for updated_at field. + securitysecretDescUpdatedAt := securitysecretMixinFields0[1].Descriptor() + // securitysecret.DefaultUpdatedAt holds the default value on creation for the updated_at field. + securitysecret.DefaultUpdatedAt = securitysecretDescUpdatedAt.Default.(func() time.Time) + // securitysecret.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ securitysecret.UpdateDefaultUpdatedAt = securitysecretDescUpdatedAt.UpdateDefault.(func() time.Time) + // securitysecretDescKey is the schema descriptor for key field. + securitysecretDescKey := securitysecretFields[0].Descriptor() + // securitysecret.KeyValidator is a validator for the "key" field. It is called by the builders before save. + securitysecret.KeyValidator = func() func(string) error { + validators := securitysecretDescKey.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(key string) error { + for _, fn := range fns { + if err := fn(key); err != nil { + return err + } + } + return nil + } + }() + // securitysecretDescValue is the schema descriptor for value field. + securitysecretDescValue := securitysecretFields[1].Descriptor() + // securitysecret.ValueValidator is a validator for the "value" field. It is called by the builders before save. + securitysecret.ValueValidator = securitysecretDescValue.Validators[0].(func(string) error) settingFields := schema.Setting{}.Fields() _ = settingFields // settingDescKey is the schema descriptor for key field. @@ -779,12 +817,16 @@ func init() { usagelogDescImageSize := usagelogFields[28].Descriptor() // usagelog.ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. usagelog.ImageSizeValidator = usagelogDescImageSize.Validators[0].(func(string) error) + // usagelogDescMediaType is the schema descriptor for media_type field. + usagelogDescMediaType := usagelogFields[29].Descriptor() + // usagelog.MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + usagelog.MediaTypeValidator = usagelogDescMediaType.Validators[0].(func(string) error) // usagelogDescCacheTTLOverridden is the schema descriptor for cache_ttl_overridden field. 
- usagelogDescCacheTTLOverridden := usagelogFields[29].Descriptor() + usagelogDescCacheTTLOverridden := usagelogFields[30].Descriptor() // usagelog.DefaultCacheTTLOverridden holds the default value on creation for the cache_ttl_overridden field. usagelog.DefaultCacheTTLOverridden = usagelogDescCacheTTLOverridden.Default.(bool) // usagelogDescCreatedAt is the schema descriptor for created_at field. - usagelogDescCreatedAt := usagelogFields[30].Descriptor() + usagelogDescCreatedAt := usagelogFields[31].Descriptor() // usagelog.DefaultCreatedAt holds the default value on creation for the created_at field. usagelog.DefaultCreatedAt = usagelogDescCreatedAt.Default.(func() time.Time) userMixin := schema.User{}.Mixin() diff --git a/backend/ent/schema/api_key.go b/backend/ent/schema/api_key.go index 26d52cb0..c1ac7ac3 100644 --- a/backend/ent/schema/api_key.go +++ b/backend/ent/schema/api_key.go @@ -47,6 +47,10 @@ func (APIKey) Fields() []ent.Field { field.String("status"). MaxLen(20). Default(domain.StatusActive), + field.Time("last_used_at"). + Optional(). + Nillable(). + Comment("Last usage time of this API key"), field.JSON("ip_whitelist", []string{}). Optional(). Comment("Allowed IPs/CIDRs, e.g. [\"192.168.1.100\", \"10.0.0.0/8\"]"), @@ -95,6 +99,7 @@ func (APIKey) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("deleted_at"), + index.Fields("last_used_at"), // Index for quota queries index.Fields("quota", "quota_used"), index.Fields("expires_at"), diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index c36ca770..fddf23ce 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -87,6 +87,24 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 按次计费配置(阶段 1) + field.Float("sora_image_price_360"). + Optional(). + Nillable(). 
+ SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_image_price_540"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + field.Float("sora_video_price_per_request_hd"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/idempotency_record.go b/backend/ent/schema/idempotency_record.go new file mode 100644 index 00000000..ed09ad65 --- /dev/null +++ b/backend/ent/schema/idempotency_record.go @@ -0,0 +1,50 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// IdempotencyRecord 幂等请求记录表。 +type IdempotencyRecord struct { + ent.Schema +} + +func (IdempotencyRecord) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "idempotency_records"}, + } +} + +func (IdempotencyRecord) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (IdempotencyRecord) Fields() []ent.Field { + return []ent.Field{ + field.String("scope").MaxLen(128), + field.String("idempotency_key_hash").MaxLen(64), + field.String("request_fingerprint").MaxLen(64), + field.String("status").MaxLen(32), + field.Int("response_status").Optional().Nillable(), + field.String("response_body").Optional().Nillable(), + field.String("error_reason").MaxLen(128).Optional().Nillable(), + field.Time("locked_until").Optional().Nillable(), + field.Time("expires_at"), + } +} + +func (IdempotencyRecord) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("scope", 
"idempotency_key_hash").Unique(), + index.Fields("expires_at"), + index.Fields("status", "locked_until"), + } +} diff --git a/backend/ent/schema/security_secret.go b/backend/ent/schema/security_secret.go new file mode 100644 index 00000000..ffe6d348 --- /dev/null +++ b/backend/ent/schema/security_secret.go @@ -0,0 +1,42 @@ +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" +) + +// SecuritySecret 存储系统级安全密钥(如 JWT 签名密钥、TOTP 加密密钥)。 +type SecuritySecret struct { + ent.Schema +} + +func (SecuritySecret) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "security_secrets"}, + } +} + +func (SecuritySecret) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +func (SecuritySecret) Fields() []ent.Field { + return []ent.Field{ + field.String("key"). + MaxLen(100). + NotEmpty(). + Unique(), + field.String("value"). + NotEmpty(). + SchemaType(map[string]string{ + dialect.Postgres: "text", + }), + } +} diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index a5032605..ffcae840 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -118,6 +118,11 @@ func (UsageLog) Fields() []ent.Field { MaxLen(10). Optional(). Nillable(), + // 媒体类型字段(sora 使用) + field.String("media_type"). + MaxLen(16). + Optional(). + Nillable(), // Cache TTL Override 标记(管理员强制替换了缓存 TTL 计费) field.Bool("cache_ttl_overridden"). diff --git a/backend/ent/securitysecret.go b/backend/ent/securitysecret.go new file mode 100644 index 00000000..e0e93c91 --- /dev/null +++ b/backend/ent/securitysecret.go @@ -0,0 +1,139 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecret is the model entity for the SecuritySecret schema. +type SecuritySecret struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Key holds the value of the "key" field. + Key string `json:"key,omitempty"` + // Value holds the value of the "value" field. + Value string `json:"value,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*SecuritySecret) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + values[i] = new(sql.NullInt64) + case securitysecret.FieldKey, securitysecret.FieldValue: + values[i] = new(sql.NullString) + case securitysecret.FieldCreatedAt, securitysecret.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the SecuritySecret fields. 
+func (_m *SecuritySecret) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case securitysecret.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case securitysecret.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case securitysecret.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case securitysecret.FieldKey: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field key", values[i]) + } else if value.Valid { + _m.Key = value.String + } + case securitysecret.FieldValue: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field value", values[i]) + } else if value.Valid { + _m.Value = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// GetValue returns the ent.Value that was dynamically selected and assigned to the SecuritySecret. +// This includes values selected through modifiers, order, etc. +func (_m *SecuritySecret) GetValue(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this SecuritySecret. +// Note that you need to call SecuritySecret.Unwrap() before calling this method if this SecuritySecret +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *SecuritySecret) Update() *SecuritySecretUpdateOne { + return NewSecuritySecretClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the SecuritySecret entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *SecuritySecret) Unwrap() *SecuritySecret { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: SecuritySecret is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *SecuritySecret) String() string { + var builder strings.Builder + builder.WriteString("SecuritySecret(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("key=") + builder.WriteString(_m.Key) + builder.WriteString(", ") + builder.WriteString("value=") + builder.WriteString(_m.Value) + builder.WriteByte(')') + return builder.String() +} + +// SecuritySecrets is a parsable slice of SecuritySecret. +type SecuritySecrets []*SecuritySecret diff --git a/backend/ent/securitysecret/securitysecret.go b/backend/ent/securitysecret/securitysecret.go new file mode 100644 index 00000000..4c5d9ef6 --- /dev/null +++ b/backend/ent/securitysecret/securitysecret.go @@ -0,0 +1,86 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the securitysecret type in the database. + Label = "security_secret" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. 
+ FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldKey holds the string denoting the key field in the database. + FieldKey = "key" + // FieldValue holds the string denoting the value field in the database. + FieldValue = "value" + // Table holds the table name of the securitysecret in the database. + Table = "security_secrets" +) + +// Columns holds all SQL columns for securitysecret fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldKey, + FieldValue, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // KeyValidator is a validator for the "key" field. It is called by the builders before save. + KeyValidator func(string) error + // ValueValidator is a validator for the "value" field. It is called by the builders before save. + ValueValidator func(string) error +) + +// OrderOption defines the ordering options for the SecuritySecret queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. 
+func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByKey orders the results by the key field. +func ByKey(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldKey, opts...).ToFunc() +} + +// ByValue orders the results by the value field. +func ByValue(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldValue, opts...).ToFunc() +} diff --git a/backend/ent/securitysecret/where.go b/backend/ent/securitysecret/where.go new file mode 100644 index 00000000..34f50752 --- /dev/null +++ b/backend/ent/securitysecret/where.go @@ -0,0 +1,300 @@ +// Code generated by ent, DO NOT EDIT. + +package securitysecret + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. 
+func IDGTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Key applies equality check predicate on the "key" field. It's identical to KeyEQ. +func Key(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// Value applies equality check predicate on the "value" field. It's identical to ValueEQ. +func Value(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. 
+func CreatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. 
+func UpdatedAtGT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// KeyEQ applies the EQ predicate on the "key" field. +func KeyEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldKey, v)) +} + +// KeyNEQ applies the NEQ predicate on the "key" field. +func KeyNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldKey, v)) +} + +// KeyIn applies the In predicate on the "key" field. +func KeyIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldKey, vs...)) +} + +// KeyNotIn applies the NotIn predicate on the "key" field. +func KeyNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldKey, vs...)) +} + +// KeyGT applies the GT predicate on the "key" field. +func KeyGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldKey, v)) +} + +// KeyGTE applies the GTE predicate on the "key" field. +func KeyGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldKey, v)) +} + +// KeyLT applies the LT predicate on the "key" field. 
+func KeyLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldKey, v)) +} + +// KeyLTE applies the LTE predicate on the "key" field. +func KeyLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldKey, v)) +} + +// KeyContains applies the Contains predicate on the "key" field. +func KeyContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldKey, v)) +} + +// KeyHasPrefix applies the HasPrefix predicate on the "key" field. +func KeyHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldKey, v)) +} + +// KeyHasSuffix applies the HasSuffix predicate on the "key" field. +func KeyHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldKey, v)) +} + +// KeyEqualFold applies the EqualFold predicate on the "key" field. +func KeyEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldKey, v)) +} + +// KeyContainsFold applies the ContainsFold predicate on the "key" field. +func KeyContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldKey, v)) +} + +// ValueEQ applies the EQ predicate on the "value" field. +func ValueEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEQ(FieldValue, v)) +} + +// ValueNEQ applies the NEQ predicate on the "value" field. +func ValueNEQ(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNEQ(FieldValue, v)) +} + +// ValueIn applies the In predicate on the "value" field. +func ValueIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldIn(FieldValue, vs...)) +} + +// ValueNotIn applies the NotIn predicate on the "value" field. 
+func ValueNotIn(vs ...string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldNotIn(FieldValue, vs...)) +} + +// ValueGT applies the GT predicate on the "value" field. +func ValueGT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGT(FieldValue, v)) +} + +// ValueGTE applies the GTE predicate on the "value" field. +func ValueGTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldGTE(FieldValue, v)) +} + +// ValueLT applies the LT predicate on the "value" field. +func ValueLT(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLT(FieldValue, v)) +} + +// ValueLTE applies the LTE predicate on the "value" field. +func ValueLTE(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldLTE(FieldValue, v)) +} + +// ValueContains applies the Contains predicate on the "value" field. +func ValueContains(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContains(FieldValue, v)) +} + +// ValueHasPrefix applies the HasPrefix predicate on the "value" field. +func ValueHasPrefix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasPrefix(FieldValue, v)) +} + +// ValueHasSuffix applies the HasSuffix predicate on the "value" field. +func ValueHasSuffix(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldHasSuffix(FieldValue, v)) +} + +// ValueEqualFold applies the EqualFold predicate on the "value" field. +func ValueEqualFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldEqualFold(FieldValue, v)) +} + +// ValueContainsFold applies the ContainsFold predicate on the "value" field. +func ValueContainsFold(v string) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.FieldContainsFold(FieldValue, v)) +} + +// And groups predicates with the AND operator between them. 
+func And(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.SecuritySecret) predicate.SecuritySecret { + return predicate.SecuritySecret(sql.NotPredicates(p)) +} diff --git a/backend/ent/securitysecret_create.go b/backend/ent/securitysecret_create.go new file mode 100644 index 00000000..397503be --- /dev/null +++ b/backend/ent/securitysecret_create.go @@ -0,0 +1,626 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretCreate is the builder for creating a SecuritySecret entity. +type SecuritySecretCreate struct { + config + mutation *SecuritySecretMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *SecuritySecretCreate) SetCreatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *SecuritySecretCreate) SetNillableCreatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *SecuritySecretCreate) SetUpdatedAt(v time.Time) *SecuritySecretCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. 
+func (_c *SecuritySecretCreate) SetNillableUpdatedAt(v *time.Time) *SecuritySecretCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetKey sets the "key" field. +func (_c *SecuritySecretCreate) SetKey(v string) *SecuritySecretCreate { + _c.mutation.SetKey(v) + return _c +} + +// SetValue sets the "value" field. +func (_c *SecuritySecretCreate) SetValue(v string) *SecuritySecretCreate { + _c.mutation.SetValue(v) + return _c +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_c *SecuritySecretCreate) Mutation() *SecuritySecretMutation { + return _c.mutation +} + +// Save creates the SecuritySecret in the database. +func (_c *SecuritySecretCreate) Save(ctx context.Context) (*SecuritySecret, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *SecuritySecretCreate) SaveX(ctx context.Context) *SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *SecuritySecretCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SecuritySecretCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *SecuritySecretCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := securitysecret.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := securitysecret.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_c *SecuritySecretCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "SecuritySecret.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "SecuritySecret.updated_at"`)} + } + if _, ok := _c.mutation.Key(); !ok { + return &ValidationError{Name: "key", err: errors.New(`ent: missing required field "SecuritySecret.key"`)} + } + if v, ok := _c.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if _, ok := _c.mutation.Value(); !ok { + return &ValidationError{Name: "value", err: errors.New(`ent: missing required field "SecuritySecret.value"`)} + } + if v, ok := _c.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_c *SecuritySecretCreate) sqlSave(ctx context.Context) (*SecuritySecret, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *SecuritySecretCreate) createSpec() (*SecuritySecret, *sqlgraph.CreateSpec) { + var ( + _node = &SecuritySecret{config: _c.config} + _spec = sqlgraph.NewCreateSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := 
_c.mutation.CreatedAt(); ok { + _spec.SetField(securitysecret.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + _node.Key = value + } + if value, ok := _c.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + _node.Value = value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertOne { + _c.conflict = opts + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreate) OnConflictColumns(columns ...string) *SecuritySecretUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertOne{ + create: _c, + } +} + +type ( + // SecuritySecretUpsertOne is the builder for "upsert"-ing + // one SecuritySecret node. + SecuritySecretUpsertOne struct { + create *SecuritySecretCreate + } + + // SecuritySecretUpsert is the "OnConflict" setter. 
+ SecuritySecretUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsert) SetUpdatedAt(v time.Time) *SecuritySecretUpsert { + u.Set(securitysecret.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateUpdatedAt() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldUpdatedAt) + return u +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsert) SetKey(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldKey, v) + return u +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateKey() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldKey) + return u +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsert) SetValue(v string) *SecuritySecretUpsert { + u.Set(securitysecret.FieldValue, v) + return u +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsert) UpdateValue() *SecuritySecretUpsert { + u.SetExcluded(securitysecret.FieldValue) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *SecuritySecretUpsertOne) UpdateNewValues() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). 
+// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertOne) Ignore() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertOne) DoNothing() *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreate.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertOne) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertOne) SetUpdatedAt(v time.Time) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateUpdatedAt() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. +func (u *SecuritySecretUpsertOne) SetKey(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateKey() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. 
+func (u *SecuritySecretUpsertOne) SetValue(v string) *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertOne) UpdateValue() *SecuritySecretUpsertOne { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *SecuritySecretUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *SecuritySecretUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// SecuritySecretCreateBulk is the builder for creating many SecuritySecret entities in bulk. +type SecuritySecretCreateBulk struct { + config + err error + builders []*SecuritySecretCreate + conflict []sql.ConflictOption +} + +// Save creates the SecuritySecret entities in the database. 
+func (_c *SecuritySecretCreateBulk) Save(ctx context.Context) ([]*SecuritySecret, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*SecuritySecret, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*SecuritySecretMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) SaveX(ctx context.Context) []*SecuritySecret { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. 
+func (_c *SecuritySecretCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *SecuritySecretCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.SecuritySecret.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.SecuritySecretUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflict(opts ...sql.ConflictOption) *SecuritySecretUpsertBulk { + _c.conflict = opts + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *SecuritySecretCreateBulk) OnConflictColumns(columns ...string) *SecuritySecretUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &SecuritySecretUpsertBulk{ + create: _c, + } +} + +// SecuritySecretUpsertBulk is the builder for "upsert"-ing +// a bulk of SecuritySecret nodes. +type SecuritySecretUpsertBulk struct { + create *SecuritySecretCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). 
+// Exec(ctx) +func (u *SecuritySecretUpsertBulk) UpdateNewValues() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(securitysecret.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.SecuritySecret.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *SecuritySecretUpsertBulk) Ignore() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *SecuritySecretUpsertBulk) DoNothing() *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the SecuritySecretCreateBulk.OnConflict +// documentation for more info. +func (u *SecuritySecretUpsertBulk) Update(set func(*SecuritySecretUpsert)) *SecuritySecretUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&SecuritySecretUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *SecuritySecretUpsertBulk) SetUpdatedAt(v time.Time) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateUpdatedAt() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetKey sets the "key" field. 
+func (u *SecuritySecretUpsertBulk) SetKey(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetKey(v) + }) +} + +// UpdateKey sets the "key" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateKey() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateKey() + }) +} + +// SetValue sets the "value" field. +func (u *SecuritySecretUpsertBulk) SetValue(v string) *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.SetValue(v) + }) +} + +// UpdateValue sets the "value" field to the value that was provided on create. +func (u *SecuritySecretUpsertBulk) UpdateValue() *SecuritySecretUpsertBulk { + return u.Update(func(s *SecuritySecretUpsert) { + s.UpdateValue() + }) +} + +// Exec executes the query. +func (u *SecuritySecretUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the SecuritySecretCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for SecuritySecretCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *SecuritySecretUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_delete.go b/backend/ent/securitysecret_delete.go new file mode 100644 index 00000000..66757138 --- /dev/null +++ b/backend/ent/securitysecret_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretDelete is the builder for deleting a SecuritySecret entity. +type SecuritySecretDelete struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDelete) Where(ps ...predicate.SecuritySecret) *SecuritySecretDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *SecuritySecretDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *SecuritySecretDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(securitysecret.Table, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// SecuritySecretDeleteOne is the builder for deleting a single SecuritySecret entity. +type SecuritySecretDeleteOne struct { + _d *SecuritySecretDelete +} + +// Where appends a list predicates to the SecuritySecretDelete builder. +func (_d *SecuritySecretDeleteOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretDeleteOne { + _d._d.mutation.Where(ps...) 
+ return _d +} + +// Exec executes the deletion query. +func (_d *SecuritySecretDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{securitysecret.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *SecuritySecretDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/securitysecret_query.go b/backend/ent/securitysecret_query.go new file mode 100644 index 00000000..fe53adf1 --- /dev/null +++ b/backend/ent/securitysecret_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretQuery is the builder for querying SecuritySecret entities. +type SecuritySecretQuery struct { + config + ctx *QueryContext + order []securitysecret.OrderOption + inters []Interceptor + predicates []predicate.SecuritySecret + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the SecuritySecretQuery builder. +func (_q *SecuritySecretQuery) Where(ps ...predicate.SecuritySecret) *SecuritySecretQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *SecuritySecretQuery) Limit(limit int) *SecuritySecretQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. 
+func (_q *SecuritySecretQuery) Offset(offset int) *SecuritySecretQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *SecuritySecretQuery) Unique(unique bool) *SecuritySecretQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *SecuritySecretQuery) Order(o ...securitysecret.OrderOption) *SecuritySecretQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first SecuritySecret entity from the query. +// Returns a *NotFoundError when no SecuritySecret was found. +func (_q *SecuritySecretQuery) First(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{securitysecret.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *SecuritySecretQuery) FirstX(ctx context.Context) *SecuritySecret { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first SecuritySecret ID from the query. +// Returns a *NotFoundError when no SecuritySecret ID was found. +func (_q *SecuritySecretQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{securitysecret.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. 
+func (_q *SecuritySecretQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single SecuritySecret entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one SecuritySecret entity is found. +// Returns a *NotFoundError when no SecuritySecret entities are found. +func (_q *SecuritySecretQuery) Only(ctx context.Context) (*SecuritySecret, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{securitysecret.Label} + default: + return nil, &NotSingularError{securitysecret.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyX(ctx context.Context) *SecuritySecret { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only SecuritySecret ID in the query. +// Returns a *NotSingularError when more than one SecuritySecret ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *SecuritySecretQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{securitysecret.Label} + default: + err = &NotSingularError{securitysecret.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *SecuritySecretQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of SecuritySecrets. 
+func (_q *SecuritySecretQuery) All(ctx context.Context) ([]*SecuritySecret, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*SecuritySecret, *SecuritySecretQuery]() + return withInterceptors[[]*SecuritySecret](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *SecuritySecretQuery) AllX(ctx context.Context) []*SecuritySecret { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of SecuritySecret IDs. +func (_q *SecuritySecretQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(securitysecret.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *SecuritySecretQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *SecuritySecretQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*SecuritySecretQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *SecuritySecretQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. 
+func (_q *SecuritySecretQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *SecuritySecretQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the SecuritySecretQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *SecuritySecretQuery) Clone() *SecuritySecretQuery { + if _q == nil { + return nil + } + return &SecuritySecretQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]securitysecret.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.SecuritySecret{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// GroupBy(securitysecret.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) GroupBy(field string, fields ...string) *SecuritySecretGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) 
+ grbuild := &SecuritySecretGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = securitysecret.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.SecuritySecret.Query(). +// Select(securitysecret.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *SecuritySecretQuery) Select(fields ...string) *SecuritySecretSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &SecuritySecretSelect{SecuritySecretQuery: _q} + sbuild.label = securitysecret.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a SecuritySecretSelect configured with the given aggregations. +func (_q *SecuritySecretQuery) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + return _q.Select().Aggregate(fns...) 
+} + +func (_q *SecuritySecretQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !securitysecret.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *SecuritySecretQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*SecuritySecret, error) { + var ( + nodes = []*SecuritySecret{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*SecuritySecret).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &SecuritySecret{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *SecuritySecretQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *SecuritySecretQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := 
_q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for i := range fields { + if fields[i] != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *SecuritySecretQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(securitysecret.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = securitysecret.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... 
for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *SecuritySecretQuery) ForUpdate(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *SecuritySecretQuery) ForShare(opts ...sql.LockOption) *SecuritySecretQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// SecuritySecretGroupBy is the group-by builder for SecuritySecret entities. +type SecuritySecretGroupBy struct { + selector + build *SecuritySecretQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *SecuritySecretGroupBy) Aggregate(fns ...AggregateFunc) *SecuritySecretGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *SecuritySecretGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *SecuritySecretGroupBy) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// SecuritySecretSelect is the builder for selecting fields of SecuritySecret entities. +type SecuritySecretSelect struct { + *SecuritySecretQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *SecuritySecretSelect) Aggregate(fns ...AggregateFunc) *SecuritySecretSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *SecuritySecretSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*SecuritySecretQuery, *SecuritySecretSelect](ctx, _s.SecuritySecretQuery, _s, _s.inters, v) +} + +func (_s *SecuritySecretSelect) sqlScan(ctx context.Context, root *SecuritySecretQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/securitysecret_update.go b/backend/ent/securitysecret_update.go new file mode 100644 index 00000000..ec3979af --- /dev/null +++ b/backend/ent/securitysecret_update.go @@ -0,0 +1,316 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/predicate" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" +) + +// SecuritySecretUpdate is the builder for updating SecuritySecret entities. +type SecuritySecretUpdate struct { + config + hooks []Hook + mutation *SecuritySecretMutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdate) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *SecuritySecretUpdate) SetUpdatedAt(v time.Time) *SecuritySecretUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdate) SetKey(v string) *SecuritySecretUpdate { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableKey(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdate) SetValue(v string) *SecuritySecretUpdate { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdate) SetNillableValue(v *string) *SecuritySecretUpdate { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdate) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *SecuritySecretUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *SecuritySecretUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_u *SecuritySecretUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *SecuritySecretUpdate) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// SecuritySecretUpdateOne is the builder for updating a single SecuritySecret 
entity. +type SecuritySecretUpdateOne struct { + config + fields []string + hooks []Hook + mutation *SecuritySecretMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *SecuritySecretUpdateOne) SetUpdatedAt(v time.Time) *SecuritySecretUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetKey sets the "key" field. +func (_u *SecuritySecretUpdateOne) SetKey(v string) *SecuritySecretUpdateOne { + _u.mutation.SetKey(v) + return _u +} + +// SetNillableKey sets the "key" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableKey(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetKey(*v) + } + return _u +} + +// SetValue sets the "value" field. +func (_u *SecuritySecretUpdateOne) SetValue(v string) *SecuritySecretUpdateOne { + _u.mutation.SetValue(v) + return _u +} + +// SetNillableValue sets the "value" field if the given value is not nil. +func (_u *SecuritySecretUpdateOne) SetNillableValue(v *string) *SecuritySecretUpdateOne { + if v != nil { + _u.SetValue(*v) + } + return _u +} + +// Mutation returns the SecuritySecretMutation object of the builder. +func (_u *SecuritySecretUpdateOne) Mutation() *SecuritySecretMutation { + return _u.mutation +} + +// Where appends a list predicates to the SecuritySecretUpdate builder. +func (_u *SecuritySecretUpdateOne) Where(ps ...predicate.SecuritySecret) *SecuritySecretUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *SecuritySecretUpdateOne) Select(field string, fields ...string) *SecuritySecretUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated SecuritySecret entity. 
+func (_u *SecuritySecretUpdateOne) Save(ctx context.Context) (*SecuritySecret, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) SaveX(ctx context.Context) *SecuritySecret { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *SecuritySecretUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *SecuritySecretUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *SecuritySecretUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := securitysecret.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *SecuritySecretUpdateOne) check() error { + if v, ok := _u.mutation.Key(); ok { + if err := securitysecret.KeyValidator(v); err != nil { + return &ValidationError{Name: "key", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.key": %w`, err)} + } + } + if v, ok := _u.mutation.Value(); ok { + if err := securitysecret.ValueValidator(v); err != nil { + return &ValidationError{Name: "value", err: fmt.Errorf(`ent: validator failed for field "SecuritySecret.value": %w`, err)} + } + } + return nil +} + +func (_u *SecuritySecretUpdateOne) sqlSave(ctx context.Context) (_node *SecuritySecret, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(securitysecret.Table, securitysecret.Columns, sqlgraph.NewFieldSpec(securitysecret.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "SecuritySecret.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, securitysecret.FieldID) + for _, f := range fields { + if !securitysecret.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != securitysecret.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(securitysecret.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Key(); ok { + _spec.SetField(securitysecret.FieldKey, field.TypeString, value) + } + if value, ok := _u.mutation.Value(); ok { + _spec.SetField(securitysecret.FieldValue, field.TypeString, value) + } + _node = &SecuritySecret{config: _u.config} 
+ _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{securitysecret.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 45d83428..4fbe9bb4 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -36,6 +36,8 @@ type Tx struct { Proxy *ProxyClient // RedeemCode is the client for interacting with the RedeemCode builders. RedeemCode *RedeemCodeClient + // SecuritySecret is the client for interacting with the SecuritySecret builders. + SecuritySecret *SecuritySecretClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. @@ -194,6 +196,7 @@ func (tx *Tx) init() { tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) + tx.SecuritySecret = NewSecuritySecretClient(tx.config) tx.Setting = NewSettingClient(tx.config) tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) diff --git a/backend/ent/usagelog.go b/backend/ent/usagelog.go index 59b25b99..f6968d0d 100644 --- a/backend/ent/usagelog.go +++ b/backend/ent/usagelog.go @@ -80,6 +80,8 @@ type UsageLog struct { ImageCount int `json:"image_count,omitempty"` // ImageSize holds the value of the "image_size" field. ImageSize *string `json:"image_size,omitempty"` + // MediaType holds the value of the "media_type" field. + MediaType *string `json:"media_type,omitempty"` // CacheTTLOverridden holds the value of the "cache_ttl_overridden" field. 
CacheTTLOverridden bool `json:"cache_ttl_overridden,omitempty"` // CreatedAt holds the value of the "created_at" field. @@ -173,7 +175,7 @@ func (*UsageLog) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case usagelog.FieldID, usagelog.FieldUserID, usagelog.FieldAPIKeyID, usagelog.FieldAccountID, usagelog.FieldGroupID, usagelog.FieldSubscriptionID, usagelog.FieldInputTokens, usagelog.FieldOutputTokens, usagelog.FieldCacheCreationTokens, usagelog.FieldCacheReadTokens, usagelog.FieldCacheCreation5mTokens, usagelog.FieldCacheCreation1hTokens, usagelog.FieldBillingType, usagelog.FieldDurationMs, usagelog.FieldFirstTokenMs, usagelog.FieldImageCount: values[i] = new(sql.NullInt64) - case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize: + case usagelog.FieldRequestID, usagelog.FieldModel, usagelog.FieldUserAgent, usagelog.FieldIPAddress, usagelog.FieldImageSize, usagelog.FieldMediaType: values[i] = new(sql.NullString) case usagelog.FieldCreatedAt: values[i] = new(sql.NullTime) @@ -380,6 +382,13 @@ func (_m *UsageLog) assignValues(columns []string, values []any) error { _m.ImageSize = new(string) *_m.ImageSize = value.String } + case usagelog.FieldMediaType: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field media_type", values[i]) + } else if value.Valid { + _m.MediaType = new(string) + *_m.MediaType = value.String + } case usagelog.FieldCacheTTLOverridden: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field cache_ttl_overridden", values[i]) @@ -556,6 +565,11 @@ func (_m *UsageLog) String() string { builder.WriteString(*v) } builder.WriteString(", ") + if v := _m.MediaType; v != nil { + builder.WriteString("media_type=") + builder.WriteString(*v) + } + builder.WriteString(", ") builder.WriteString("cache_ttl_overridden=") builder.WriteString(fmt.Sprintf("%v", 
_m.CacheTTLOverridden)) builder.WriteString(", ") diff --git a/backend/ent/usagelog/usagelog.go b/backend/ent/usagelog/usagelog.go index fca720d2..ba97b843 100644 --- a/backend/ent/usagelog/usagelog.go +++ b/backend/ent/usagelog/usagelog.go @@ -72,6 +72,8 @@ const ( FieldImageCount = "image_count" // FieldImageSize holds the string denoting the image_size field in the database. FieldImageSize = "image_size" + // FieldMediaType holds the string denoting the media_type field in the database. + FieldMediaType = "media_type" // FieldCacheTTLOverridden holds the string denoting the cache_ttl_overridden field in the database. FieldCacheTTLOverridden = "cache_ttl_overridden" // FieldCreatedAt holds the string denoting the created_at field in the database. @@ -157,6 +159,7 @@ var Columns = []string{ FieldIPAddress, FieldImageCount, FieldImageSize, + FieldMediaType, FieldCacheTTLOverridden, FieldCreatedAt, } @@ -214,6 +217,8 @@ var ( DefaultImageCount int // ImageSizeValidator is a validator for the "image_size" field. It is called by the builders before save. ImageSizeValidator func(string) error + // MediaTypeValidator is a validator for the "media_type" field. It is called by the builders before save. + MediaTypeValidator func(string) error // DefaultCacheTTLOverridden holds the default value on creation for the "cache_ttl_overridden" field. DefaultCacheTTLOverridden bool // DefaultCreatedAt holds the default value on creation for the "created_at" field. @@ -373,6 +378,11 @@ func ByImageSize(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldImageSize, opts...).ToFunc() } +// ByMediaType orders the results by the media_type field. +func ByMediaType(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMediaType, opts...).ToFunc() +} + // ByCacheTTLOverridden orders the results by the cache_ttl_overridden field. 
func ByCacheTTLOverridden(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldCacheTTLOverridden, opts...).ToFunc() diff --git a/backend/ent/usagelog/where.go b/backend/ent/usagelog/where.go index ae832959..af960335 100644 --- a/backend/ent/usagelog/where.go +++ b/backend/ent/usagelog/where.go @@ -200,6 +200,11 @@ func ImageSize(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldImageSize, v)) } +// MediaType applies equality check predicate on the "media_type" field. It's identical to MediaTypeEQ. +func MediaType(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + // CacheTTLOverridden applies equality check predicate on the "cache_ttl_overridden" field. It's identical to CacheTTLOverriddenEQ. func CacheTTLOverridden(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) @@ -1445,6 +1450,81 @@ func ImageSizeContainsFold(v string) predicate.UsageLog { return predicate.UsageLog(sql.FieldContainsFold(FieldImageSize, v)) } +// MediaTypeEQ applies the EQ predicate on the "media_type" field. +func MediaTypeEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEQ(FieldMediaType, v)) +} + +// MediaTypeNEQ applies the NEQ predicate on the "media_type" field. +func MediaTypeNEQ(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNEQ(FieldMediaType, v)) +} + +// MediaTypeIn applies the In predicate on the "media_type" field. +func MediaTypeIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldIn(FieldMediaType, vs...)) +} + +// MediaTypeNotIn applies the NotIn predicate on the "media_type" field. +func MediaTypeNotIn(vs ...string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotIn(FieldMediaType, vs...)) +} + +// MediaTypeGT applies the GT predicate on the "media_type" field. 
+func MediaTypeGT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGT(FieldMediaType, v)) +} + +// MediaTypeGTE applies the GTE predicate on the "media_type" field. +func MediaTypeGTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldGTE(FieldMediaType, v)) +} + +// MediaTypeLT applies the LT predicate on the "media_type" field. +func MediaTypeLT(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLT(FieldMediaType, v)) +} + +// MediaTypeLTE applies the LTE predicate on the "media_type" field. +func MediaTypeLTE(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldLTE(FieldMediaType, v)) +} + +// MediaTypeContains applies the Contains predicate on the "media_type" field. +func MediaTypeContains(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContains(FieldMediaType, v)) +} + +// MediaTypeHasPrefix applies the HasPrefix predicate on the "media_type" field. +func MediaTypeHasPrefix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasPrefix(FieldMediaType, v)) +} + +// MediaTypeHasSuffix applies the HasSuffix predicate on the "media_type" field. +func MediaTypeHasSuffix(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldHasSuffix(FieldMediaType, v)) +} + +// MediaTypeIsNil applies the IsNil predicate on the "media_type" field. +func MediaTypeIsNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldIsNull(FieldMediaType)) +} + +// MediaTypeNotNil applies the NotNil predicate on the "media_type" field. +func MediaTypeNotNil() predicate.UsageLog { + return predicate.UsageLog(sql.FieldNotNull(FieldMediaType)) +} + +// MediaTypeEqualFold applies the EqualFold predicate on the "media_type" field. +func MediaTypeEqualFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldEqualFold(FieldMediaType, v)) +} + +// MediaTypeContainsFold applies the ContainsFold predicate on the "media_type" field. 
+func MediaTypeContainsFold(v string) predicate.UsageLog { + return predicate.UsageLog(sql.FieldContainsFold(FieldMediaType, v)) +} + // CacheTTLOverriddenEQ applies the EQ predicate on the "cache_ttl_overridden" field. func CacheTTLOverriddenEQ(v bool) predicate.UsageLog { return predicate.UsageLog(sql.FieldEQ(FieldCacheTTLOverridden, v)) diff --git a/backend/ent/usagelog_create.go b/backend/ent/usagelog_create.go index 5b9cdf14..e0285a5e 100644 --- a/backend/ent/usagelog_create.go +++ b/backend/ent/usagelog_create.go @@ -393,6 +393,20 @@ func (_c *UsageLogCreate) SetNillableImageSize(v *string) *UsageLogCreate { return _c } +// SetMediaType sets the "media_type" field. +func (_c *UsageLogCreate) SetMediaType(v string) *UsageLogCreate { + _c.mutation.SetMediaType(v) + return _c +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_c *UsageLogCreate) SetNillableMediaType(v *string) *UsageLogCreate { + if v != nil { + _c.SetMediaType(*v) + } + return _c +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
func (_c *UsageLogCreate) SetCacheTTLOverridden(v bool) *UsageLogCreate { _c.mutation.SetCacheTTLOverridden(v) @@ -645,6 +659,11 @@ func (_c *UsageLogCreate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _c.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _, ok := _c.mutation.CacheTTLOverridden(); !ok { return &ValidationError{Name: "cache_ttl_overridden", err: errors.New(`ent: missing required field "UsageLog.cache_ttl_overridden"`)} } @@ -783,6 +802,10 @@ func (_c *UsageLogCreate) createSpec() (*UsageLog, *sqlgraph.CreateSpec) { _spec.SetField(usagelog.FieldImageSize, field.TypeString, value) _node.ImageSize = &value } + if value, ok := _c.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + _node.MediaType = &value + } if value, ok := _c.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) _node.CacheTTLOverridden = value @@ -1432,6 +1455,24 @@ func (u *UsageLogUpsert) ClearImageSize() *UsageLogUpsert { return u } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsert) SetMediaType(v string) *UsageLogUpsert { + u.Set(usagelog.FieldMediaType, v) + return u +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsert) UpdateMediaType() *UsageLogUpsert { + u.SetExcluded(usagelog.FieldMediaType) + return u +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsert) ClearMediaType() *UsageLogUpsert { + u.SetNull(usagelog.FieldMediaType) + return u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
func (u *UsageLogUpsert) SetCacheTTLOverridden(v bool) *UsageLogUpsert { u.Set(usagelog.FieldCacheTTLOverridden, v) @@ -2077,6 +2118,27 @@ func (u *UsageLogUpsertOne) ClearImageSize() *UsageLogUpsertOne { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertOne) SetMediaType(v string) *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertOne) UpdateMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertOne) ClearMediaType() *UsageLogUpsertOne { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. func (u *UsageLogUpsertOne) SetCacheTTLOverridden(v bool) *UsageLogUpsertOne { return u.Update(func(s *UsageLogUpsert) { @@ -2890,6 +2952,27 @@ func (u *UsageLogUpsertBulk) ClearImageSize() *UsageLogUpsertBulk { }) } +// SetMediaType sets the "media_type" field. +func (u *UsageLogUpsertBulk) SetMediaType(v string) *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.SetMediaType(v) + }) +} + +// UpdateMediaType sets the "media_type" field to the value that was provided on create. +func (u *UsageLogUpsertBulk) UpdateMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.UpdateMediaType() + }) +} + +// ClearMediaType clears the value of the "media_type" field. +func (u *UsageLogUpsertBulk) ClearMediaType() *UsageLogUpsertBulk { + return u.Update(func(s *UsageLogUpsert) { + s.ClearMediaType() + }) +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
func (u *UsageLogUpsertBulk) SetCacheTTLOverridden(v bool) *UsageLogUpsertBulk { return u.Update(func(s *UsageLogUpsert) { diff --git a/backend/ent/usagelog_update.go b/backend/ent/usagelog_update.go index 22f3cb31..b46e5b56 100644 --- a/backend/ent/usagelog_update.go +++ b/backend/ent/usagelog_update.go @@ -612,6 +612,26 @@ func (_u *UsageLogUpdate) ClearImageSize() *UsageLogUpdate { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdate) SetMediaType(v string) *UsageLogUpdate { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdate) SetNillableMediaType(v *string) *UsageLogUpdate { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdate) ClearMediaType() *UsageLogUpdate { + _u.mutation.ClearMediaType() + return _u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
func (_u *UsageLogUpdate) SetCacheTTLOverridden(v bool) *UsageLogUpdate { _u.mutation.SetCacheTTLOverridden(v) @@ -740,6 +760,11 @@ func (_u *UsageLogUpdate) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -908,6 +933,12 @@ func (_u *UsageLogUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } @@ -1656,6 +1687,26 @@ func (_u *UsageLogUpdateOne) ClearImageSize() *UsageLogUpdateOne { return _u } +// SetMediaType sets the "media_type" field. +func (_u *UsageLogUpdateOne) SetMediaType(v string) *UsageLogUpdateOne { + _u.mutation.SetMediaType(v) + return _u +} + +// SetNillableMediaType sets the "media_type" field if the given value is not nil. +func (_u *UsageLogUpdateOne) SetNillableMediaType(v *string) *UsageLogUpdateOne { + if v != nil { + _u.SetMediaType(*v) + } + return _u +} + +// ClearMediaType clears the value of the "media_type" field. +func (_u *UsageLogUpdateOne) ClearMediaType() *UsageLogUpdateOne { + _u.mutation.ClearMediaType() + return _u +} + // SetCacheTTLOverridden sets the "cache_ttl_overridden" field. 
func (_u *UsageLogUpdateOne) SetCacheTTLOverridden(v bool) *UsageLogUpdateOne { _u.mutation.SetCacheTTLOverridden(v) @@ -1797,6 +1848,11 @@ func (_u *UsageLogUpdateOne) check() error { return &ValidationError{Name: "image_size", err: fmt.Errorf(`ent: validator failed for field "UsageLog.image_size": %w`, err)} } } + if v, ok := _u.mutation.MediaType(); ok { + if err := usagelog.MediaTypeValidator(v); err != nil { + return &ValidationError{Name: "media_type", err: fmt.Errorf(`ent: validator failed for field "UsageLog.media_type": %w`, err)} + } + } if _u.mutation.UserCleared() && len(_u.mutation.UserIDs()) > 0 { return errors.New(`ent: clearing a required unique edge "UsageLog.user"`) } @@ -1982,6 +2038,12 @@ func (_u *UsageLogUpdateOne) sqlSave(ctx context.Context) (_node *UsageLog, err if _u.mutation.ImageSizeCleared() { _spec.ClearField(usagelog.FieldImageSize, field.TypeString) } + if value, ok := _u.mutation.MediaType(); ok { + _spec.SetField(usagelog.FieldMediaType, field.TypeString, value) + } + if _u.mutation.MediaTypeCleared() { + _spec.ClearField(usagelog.FieldMediaType, field.TypeString) + } if value, ok := _u.mutation.CacheTTLOverridden(); ok { _spec.SetField(usagelog.FieldCacheTTLOverridden, field.TypeBool, value) } diff --git a/backend/go.mod b/backend/go.mod index 08d54b91..94b6fcbb 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -5,6 +5,8 @@ go 1.25.7 require ( entgo.io/ent v0.14.5 github.com/DATA-DOG/go-sqlmock v1.5.2 + github.com/alitto/pond/v2 v2.6.2 + github.com/cespare/xxhash/v2 v2.3.0 github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -13,6 +15,7 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/imroc/req/v3 v3.57.0 github.com/lib/pq v1.10.9 + github.com/patrickmn/go-cache v2.1.0+incompatible github.com/pquerna/otp v1.5.0 github.com/redis/go-redis/v9 v9.17.2 github.com/refraction-networking/utls v1.8.1 @@ -25,10 +28,12 @@ require ( github.com/tidwall/gjson v1.18.0 
github.com/tidwall/sjson v1.2.5 github.com/zeromicro/go-zero v1.9.4 + go.uber.org/zap v1.24.0 golang.org/x/crypto v0.47.0 golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.39.0 + gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 ) @@ -45,7 +50,6 @@ require ( github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/bytedance/sonic v1.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/containerd/errdefs v1.0.0 // indirect github.com/containerd/errdefs/pkg v0.3.0 // indirect @@ -75,6 +79,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -103,7 +108,6 @@ require ( github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect - github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -144,6 +148,7 @@ require ( golang.org/x/mod v0.31.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect + golang.org/x/tools v0.40.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 71e8f504..f044c3a8 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -14,10 +14,14 @@ 
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERo github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/agext/levenshtein v1.2.3 h1:YB2fHEn0UJagG8T1rrWknE3ZQzWM06O8AMAatNn7lmo= github.com/agext/levenshtein v1.2.3/go.mod h1:JEDfjyjHDjOF/1e4FlBE/PkbqA9OfWu2ki2W0IB5558= +github.com/alitto/pond/v2 v2.6.2 h1:Sphe40g0ILeM1pA2c2K+Th0DGU+pt0A/Kprr+WB24Pw= +github.com/alitto/pond/v2 v2.6.2/go.mod h1:xkjYEgQ05RSpWdfSd1nM3OVv7TBhLdy7rMp3+2Nq+yE= github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwToPjQ= github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= github.com/bmatcuk/doublestar v1.3.4 h1:gPypJ5xD31uhX6Tf54sDPUOBXTqKH4c9aPY66CyQrS0= github.com/bmatcuk/doublestar v1.3.4/go.mod h1:wiQtGV+rzVYxB7WIlirSN++5HPtPlXEo9MEoZQC/PmE= github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc h1:biVzkmvwrH8WK8raXaxBx6fRVTlJILwEwQGL1I/ByEI= @@ -116,6 +120,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 
h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -135,8 +141,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -340,10 +344,14 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/automaxprocs v1.6.0 h1:O3y2/QNTOdbF+e/dpXNNW7Rx2hZ4sTIPyybbxyNqTUs= go.uber.org/automaxprocs v1.6.0/go.mod h1:ifeIMSnPZuznNm6jmdzmU3/bfk01Fe2fotchwEFJ8r8= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.6.0 h1:hyF9dfmbgIX5EfOdasqLsWD6xqpNZlXblLB/Dbnwv3Y= go.uber.org/mock v0.6.0/go.mod h1:KiVJ4BqZJaMj4svdfmHM0AUx4NJYO8ZNpPnZn1Z+BBU= go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod 
h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= @@ -391,6 +399,8 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 91437ba8..c4d4fdab 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -5,7 +5,7 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" + "log/slog" "net/url" "os" "strings" @@ -19,6 +19,13 @@ const ( RunModeSimple = "simple" ) +// 使用量记录队列溢出策略 +const ( + UsageRecordOverflowPolicyDrop = "drop" + UsageRecordOverflowPolicySample = "sample" + UsageRecordOverflowPolicySync = "sync" +) + // DefaultCSPPolicy is the default Content-Security-Policy with nonce support // __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' 
https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" @@ -38,31 +45,68 @@ const ( ) type Config struct { - Server ServerConfig `mapstructure:"server"` - CORS CORSConfig `mapstructure:"cors"` - Security SecurityConfig `mapstructure:"security"` - Billing BillingConfig `mapstructure:"billing"` - Turnstile TurnstileConfig `mapstructure:"turnstile"` - Database DatabaseConfig `mapstructure:"database"` - Redis RedisConfig `mapstructure:"redis"` - Ops OpsConfig `mapstructure:"ops"` - JWT JWTConfig `mapstructure:"jwt"` - Totp TotpConfig `mapstructure:"totp"` - LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` - Default DefaultConfig `mapstructure:"default"` - RateLimit RateLimitConfig `mapstructure:"rate_limit"` - Pricing PricingConfig `mapstructure:"pricing"` - Gateway GatewayConfig `mapstructure:"gateway"` - APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` - Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` - DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` - UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` - Concurrency ConcurrencyConfig `mapstructure:"concurrency"` - TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` - RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" - Gemini GeminiConfig `mapstructure:"gemini"` - Update UpdateConfig `mapstructure:"update"` + Server ServerConfig `mapstructure:"server"` + Log LogConfig `mapstructure:"log"` + CORS CORSConfig `mapstructure:"cors"` + Security SecurityConfig `mapstructure:"security"` + Billing BillingConfig `mapstructure:"billing"` + Turnstile TurnstileConfig `mapstructure:"turnstile"` + Database DatabaseConfig `mapstructure:"database"` + Redis RedisConfig `mapstructure:"redis"` + Ops OpsConfig `mapstructure:"ops"` + JWT JWTConfig `mapstructure:"jwt"` + Totp TotpConfig `mapstructure:"totp"` + LinuxDo LinuxDoConnectConfig `mapstructure:"linuxdo_connect"` + Default DefaultConfig `mapstructure:"default"` + RateLimit RateLimitConfig `mapstructure:"rate_limit"` + Pricing PricingConfig `mapstructure:"pricing"` + Gateway GatewayConfig `mapstructure:"gateway"` + APIKeyAuth APIKeyAuthCacheConfig `mapstructure:"api_key_auth_cache"` + SubscriptionCache SubscriptionCacheConfig `mapstructure:"subscription_cache"` + SubscriptionMaintenance SubscriptionMaintenanceConfig `mapstructure:"subscription_maintenance"` + Dashboard DashboardCacheConfig `mapstructure:"dashboard_cache"` + DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"` + UsageCleanup UsageCleanupConfig `mapstructure:"usage_cleanup"` + Concurrency ConcurrencyConfig `mapstructure:"concurrency"` + TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` + Sora SoraConfig `mapstructure:"sora"` + RunMode string `mapstructure:"run_mode" yaml:"run_mode"` + Timezone string `mapstructure:"timezone"` // e.g. 
"Asia/Shanghai", "UTC" + Gemini GeminiConfig `mapstructure:"gemini"` + Update UpdateConfig `mapstructure:"update"` + Idempotency IdempotencyConfig `mapstructure:"idempotency"` +} + +type LogConfig struct { + Level string `mapstructure:"level"` + Format string `mapstructure:"format"` + ServiceName string `mapstructure:"service_name"` + Environment string `mapstructure:"env"` + Caller bool `mapstructure:"caller"` + StacktraceLevel string `mapstructure:"stacktrace_level"` + Output LogOutputConfig `mapstructure:"output"` + Rotation LogRotationConfig `mapstructure:"rotation"` + Sampling LogSamplingConfig `mapstructure:"sampling"` +} + +type LogOutputConfig struct { + ToStdout bool `mapstructure:"to_stdout"` + ToFile bool `mapstructure:"to_file"` + FilePath string `mapstructure:"file_path"` +} + +type LogRotationConfig struct { + MaxSizeMB int `mapstructure:"max_size_mb"` + MaxBackups int `mapstructure:"max_backups"` + MaxAgeDays int `mapstructure:"max_age_days"` + Compress bool `mapstructure:"compress"` + LocalTime bool `mapstructure:"local_time"` +} + +type LogSamplingConfig struct { + Enabled bool `mapstructure:"enabled"` + Initial int `mapstructure:"initial"` + Thereafter int `mapstructure:"thereafter"` } type GeminiConfig struct { @@ -94,6 +138,25 @@ type UpdateConfig struct { ProxyURL string `mapstructure:"proxy_url"` } +type IdempotencyConfig struct { + // ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。 + ObserveOnly bool `mapstructure:"observe_only"` + // DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。 + DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` + // SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。 + SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` + // ProcessingTimeoutSeconds processing 状态锁超时(秒)。 + ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` + // FailedRetryBackoffSeconds 失败退避窗口(秒)。 + FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` + // MaxStoredResponseLen 
持久化响应体最大长度(字节)。 + MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` + // CleanupIntervalSeconds 过期记录清理周期(秒)。 + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` + // CleanupBatchSize 每次清理的最大记录数。 + CleanupBatchSize int `mapstructure:"cleanup_batch_size"` +} + type LinuxDoConnectConfig struct { Enabled bool `mapstructure:"enabled"` ClientID string `mapstructure:"client_id"` @@ -126,6 +189,8 @@ type TokenRefreshConfig struct { MaxRetries int `mapstructure:"max_retries"` // 重试退避基础时间(秒) RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` + // 是否允许 OpenAI 刷新器同步覆盖关联的 Sora 账号 token(默认关闭) + SyncLinkedSoraAccounts bool `mapstructure:"sync_linked_sora_accounts"` } type PricingConfig struct { @@ -147,6 +212,7 @@ type ServerConfig struct { Host string `mapstructure:"host"` Port int `mapstructure:"port"` Mode string `mapstructure:"mode"` // debug/release + FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) @@ -173,6 +239,7 @@ type SecurityConfig struct { URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` CSP CSPConfig `mapstructure:"csp"` + ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` } @@ -197,6 +264,12 @@ type CSPConfig struct { Policy string `mapstructure:"policy"` } +type ProxyFallbackConfig struct { + // AllowDirectOnError 当代理初始化失败时是否允许回退直连。 + // 默认 false:避免因代理配置错误导致 IP 泄露/关联。 + AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` +} + type ProxyProbeConfig struct { InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 } @@ -217,6 +290,59 @@ type ConcurrencyConfig struct { PingInterval int 
`mapstructure:"ping_interval"` } +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig 直连 Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` +} + +// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 +type SoraCurlCFFISidecarConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Impersonate string `mapstructure:"impersonate"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` + MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` + Debug bool `mapstructure:"debug"` + Cleanup 
SoraStorageCleanupConfig `mapstructure:"cleanup"` +} + +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` +} + // GatewayConfig API网关相关配置 type GatewayConfig struct { // 等待上游响应头的超时时间(秒),0表示无超时 @@ -224,8 +350,20 @@ type GatewayConfig struct { ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 请求体最大字节数,用于网关请求体大小限制 MaxBodySize int64 `mapstructure:"max_body_size"` + // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` + // 代理探测响应体读取上限(字节) + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` + // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` + // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 + // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 + ForceCodexCLI bool `mapstructure:"force_codex_cli"` + // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 + // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 + OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -271,6 +409,24 @@ type GatewayConfig struct { // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) FailoverOn400 bool `mapstructure:"failover_on_400"` + // Sora 专用配置 + // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) + SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) + SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` + // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) + SoraRequestTimeoutSeconds int 
`mapstructure:"sora_request_timeout_seconds"` + // SoraStreamMode: stream 强制策略(force/error) + SoraStreamMode string `mapstructure:"sora_stream_mode"` + // SoraModelFilters: 模型列表过滤配置 + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) MaxAccountSwitches int `mapstructure:"max_account_switches"` // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) @@ -284,6 +440,53 @@ type GatewayConfig struct { // TLSFingerprint: TLS指纹伪装配置 TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` + + // UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker) + UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` + + // UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒) + UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` + // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) + ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` +} + +// GatewayUsageRecordConfig 使用量记录异步队列配置 +type GatewayUsageRecordConfig struct { + // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) + WorkerCount int `mapstructure:"worker_count"` + // QueueSize: 队列容量(有界) + QueueSize int `mapstructure:"queue_size"` + // TaskTimeoutSeconds: 单个使用量记录任务超时(秒) + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` + // OverflowPolicy: 队列满时策略(drop/sample/sync) + OverflowPolicy string `mapstructure:"overflow_policy"` + // OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100) + OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` + + // AutoScaleEnabled: 是否启用 worker 自动扩缩容 + AutoScaleEnabled bool 
`mapstructure:"auto_scale_enabled"` + // AutoScaleMinWorkers: 自动扩缩容最小 worker 数 + AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"` + // AutoScaleMaxWorkers: 自动扩缩容最大 worker 数 + AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"` + // AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容 + AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"` + // AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容 + AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"` + // AutoScaleUpStep: 每次扩容步长 + AutoScaleUpStep int `mapstructure:"auto_scale_up_step"` + // AutoScaleDownStep: 每次缩容步长 + AutoScaleDownStep int `mapstructure:"auto_scale_down_step"` + // AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒) + AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"` + // AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒) + AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` +} + +// SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + // HidePromptEnhance 是否隐藏 prompt-enhance 模型 + HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` } // TLSFingerprintConfig TLS指纹伪装配置 @@ -479,8 +682,9 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` - // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 - // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + // AccessTokenExpireMinutes: Access Token有效期(分钟) + // - >0: 使用分钟配置(优先级高于 ExpireHour) + // - =0: 回退使用 ExpireHour(向后兼容旧配置) AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` @@ -525,6 +729,20 @@ type APIKeyAuthCacheConfig struct { Singleflight bool `mapstructure:"singleflight"` } +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int 
`mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + +// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。 +// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。 +type SubscriptionMaintenanceConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` +} + // DashboardCacheConfig 仪表盘统计缓存配置 type DashboardCacheConfig struct { // Enabled: 是否启用仪表盘缓存 @@ -588,7 +806,19 @@ func NormalizeRunMode(value string) string { } } +// Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。 func Load() (*Config, error) { + return load(false) +} + +// LoadForBootstrap 读取启动阶段配置。 +// +// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。 +func LoadForBootstrap() (*Config, error) { + return load(true) +} + +func load(allowMissingJWTSecret bool) (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") @@ -630,6 +860,7 @@ func Load() (*Config, error) { if cfg.Server.Mode == "" { cfg.Server.Mode = "debug" } + cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL) cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) @@ -648,15 +879,12 @@ func Load() (*Config, error) { cfg.Security.ResponseHeaders.AdditionalAllowed = normalizeStringSlice(cfg.Security.ResponseHeaders.AdditionalAllowed) cfg.Security.ResponseHeaders.ForceRemove = normalizeStringSlice(cfg.Security.ResponseHeaders.ForceRemove) cfg.Security.CSP.Policy = strings.TrimSpace(cfg.Security.CSP.Policy) - - if cfg.JWT.Secret == "" { - secret, err := generateJWTSecret(64) - if err != nil { - return nil, fmt.Errorf("generate jwt secret error: %w", err) - } - cfg.JWT.Secret = secret - log.Println("Warning: JWT secret auto-generated. 
Consider setting a fixed secret for production.") - } + cfg.Log.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + cfg.Log.Format = strings.ToLower(strings.TrimSpace(cfg.Log.Format)) + cfg.Log.ServiceName = strings.TrimSpace(cfg.Log.ServiceName) + cfg.Log.Environment = strings.TrimSpace(cfg.Log.Environment) + cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) @@ -667,29 +895,39 @@ func Load() (*Config, error) { } cfg.Totp.EncryptionKey = key cfg.Totp.EncryptionKeyConfigured = false - log.Println("Warning: TOTP encryption key auto-generated. Consider setting a fixed key for production.") + slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.") } else { cfg.Totp.EncryptionKeyConfigured = true } + originalJWTSecret := cfg.JWT.Secret + if allowMissingJWTSecret && originalJWTSecret == "" { + // 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。 + cfg.JWT.Secret = strings.Repeat("0", 32) + } + if err := cfg.Validate(); err != nil { return nil, fmt.Errorf("validate config error: %w", err) } + if allowMissingJWTSecret && originalJWTSecret == "" { + cfg.JWT.Secret = "" + } + if !cfg.Security.URLAllowlist.Enabled { - log.Println("Warning: security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") + slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") } if !cfg.Security.ResponseHeaders.Enabled { - log.Println("Warning: security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") + slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") } if cfg.JWT.Secret != "" && 
isWeakJWTSecret(cfg.JWT.Secret) { - log.Println("Warning: JWT secret appears weak; use a 32+ character random secret in production.") + slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.") } if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { - log.Printf("AUDIT: response header policy configured additional_allowed=%v force_remove=%v", - cfg.Security.ResponseHeaders.AdditionalAllowed, - cfg.Security.ResponseHeaders.ForceRemove, + slog.Info("response header policy configured", + "additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed, + "force_remove", cfg.Security.ResponseHeaders.ForceRemove, ) } @@ -702,7 +940,8 @@ func setDefaults() { // Server viper.SetDefault("server.host", "0.0.0.0") viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "debug") + viper.SetDefault("server.mode", "release") + viper.SetDefault("server.frontend_url", "") viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) @@ -715,6 +954,25 @@ func setDefaults() { viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB + // Log + viper.SetDefault("log.level", "info") + viper.SetDefault("log.format", "console") + viper.SetDefault("log.service_name", "sub2api") + viper.SetDefault("log.env", "production") + viper.SetDefault("log.caller", true) + viper.SetDefault("log.stacktrace_level", "error") + viper.SetDefault("log.output.to_stdout", true) + viper.SetDefault("log.output.to_file", true) + viper.SetDefault("log.output.file_path", "") + viper.SetDefault("log.rotation.max_size_mb", 100) + viper.SetDefault("log.rotation.max_backups", 10) + viper.SetDefault("log.rotation.max_age_days", 7) + viper.SetDefault("log.rotation.compress", true) + 
viper.SetDefault("log.rotation.local_time", true) + viper.SetDefault("log.sampling.enabled", false) + viper.SetDefault("log.sampling.initial", 100) + viper.SetDefault("log.sampling.thereafter", 100) + // CORS viper.SetDefault("cors.allowed_origins", []string{}) viper.SetDefault("cors.allow_credentials", true) @@ -737,7 +995,7 @@ func setDefaults() { viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) viper.SetDefault("security.url_allowlist.allow_private_hosts", true) viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", false) + viper.SetDefault("security.response_headers.enabled", true) viper.SetDefault("security.response_headers.additional_allowed", []string{}) viper.SetDefault("security.response_headers.force_remove", []string{}) viper.SetDefault("security.csp.enabled", true) @@ -775,9 +1033,9 @@ func setDefaults() { viper.SetDefault("database.user", "postgres") viper.SetDefault("database.password", "postgres") viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "disable") - viper.SetDefault("database.max_open_conns", 50) - viper.SetDefault("database.max_idle_conns", 10) + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) viper.SetDefault("database.conn_max_lifetime_minutes", 30) viper.SetDefault("database.conn_max_idle_time_minutes", 5) @@ -789,8 +1047,8 @@ func setDefaults() { viper.SetDefault("redis.dial_timeout_seconds", 5) viper.SetDefault("redis.read_timeout_seconds", 3) viper.SetDefault("redis.write_timeout_seconds", 3) - viper.SetDefault("redis.pool_size", 128) - viper.SetDefault("redis.min_idle_conns", 10) + viper.SetDefault("redis.pool_size", 1024) + viper.SetDefault("redis.min_idle_conns", 128) viper.SetDefault("redis.enable_tls", false) // Ops (vNext) @@ -810,9 +1068,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") 
viper.SetDefault("jwt.expire_hour", 24) - viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 - viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 - viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 + viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -849,6 +1107,11 @@ func setDefaults() { viper.SetDefault("api_key_auth_cache.jitter_percent", 10) viper.SetDefault("api_key_auth_cache.singleflight", true) + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + // Dashboard cache viper.SetDefault("dashboard_cache.enabled", true) viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") @@ -874,6 +1137,16 @@ func setDefaults() { viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + // Idempotency + viper.SetDefault("idempotency.observe_only", true) + viper.SetDefault("idempotency.default_ttl_seconds", 86400) + viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) + viper.SetDefault("idempotency.processing_timeout_seconds", 30) + viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) + viper.SetDefault("idempotency.max_stored_response_len", 64*1024) + viper.SetDefault("idempotency.cleanup_interval_seconds", 60) + viper.SetDefault("idempotency.cleanup_batch_size", 500) + // Gateway viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 viper.SetDefault("gateway.log_upstream_error_body", true) @@ -882,13 +1155,25 @@ func setDefaults() { viper.SetDefault("gateway.failover_on_400", false) 
viper.SetDefault("gateway.max_account_switches", 10) viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) + viper.SetDefault("gateway.sora_max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.sora_stream_timeout_seconds", 900) + viper.SetDefault("gateway.sora_request_timeout_seconds", 180) + viper.SetDefault("gateway.sora_stream_mode", "force") + viper.SetDefault("gateway.sora_model_filters.hide_prompt_enhance", true) + viper.SetDefault("gateway.sora_media_require_api_key", true) + viper.SetDefault("gateway.sora_media_signed_url_ttl_seconds", 900) viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 240) // 最大空闲连接总数(HTTP/2 场景默认) + viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 240) // 每主机最大连接数(含活跃,HTTP/2 场景默认) + viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+) viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) viper.SetDefault("gateway.max_upstream_clients", 5000) viper.SetDefault("gateway.client_idle_ttl_seconds", 900) @@ -912,16 +1197,65 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) 
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) + viper.SetDefault("gateway.usage_record.worker_count", 128) + viper.SetDefault("gateway.usage_record.queue_size", 16384) + viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) + viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) + viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) + viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512) + viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70) + viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15) + viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32) + viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) + viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) + viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) + viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) + viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) + // Sora 直连配置 + viper.SetDefault("sora.client.base_url", "https://sora.chatgpt.com/backend") + viper.SetDefault("sora.client.timeout_seconds", 120) + viper.SetDefault("sora.client.max_retries", 3) + viper.SetDefault("sora.client.cloudflare_challenge_cooldown_seconds", 900) + viper.SetDefault("sora.client.poll_interval_seconds", 2) + viper.SetDefault("sora.client.max_poll_attempts", 600) + viper.SetDefault("sora.client.recent_task_limit", 50) + viper.SetDefault("sora.client.recent_task_limit_max", 200) + viper.SetDefault("sora.client.debug", false) + viper.SetDefault("sora.client.use_openai_token_provider", false) + 
viper.SetDefault("sora.client.headers", map[string]string{}) + viper.SetDefault("sora.client.user_agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + viper.SetDefault("sora.client.disable_tls_fingerprint", false) + viper.SetDefault("sora.client.curl_cffi_sidecar.enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.base_url", "http://sora-curl-cffi-sidecar:8080") + viper.SetDefault("sora.client.curl_cffi_sidecar.impersonate", "chrome131") + viper.SetDefault("sora.client.curl_cffi_sidecar.timeout_seconds", 60) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_reuse_enabled", true) + viper.SetDefault("sora.client.curl_cffi_sidecar.session_ttl_seconds", 3600) + + viper.SetDefault("sora.storage.type", "local") + viper.SetDefault("sora.storage.local_path", "") + viper.SetDefault("sora.storage.fallback_to_upstream", true) + viper.SetDefault("sora.storage.max_concurrent_downloads", 4) + viper.SetDefault("sora.storage.download_timeout_seconds", 120) + viper.SetDefault("sora.storage.max_download_bytes", int64(200<<20)) + viper.SetDefault("sora.storage.debug", false) + viper.SetDefault("sora.storage.cleanup.enabled", true) + viper.SetDefault("sora.storage.cleanup.retention_days", 7) + viper.SetDefault("sora.storage.cleanup.schedule", "0 3 * * *") + // TokenRefresh viper.SetDefault("token_refresh.enabled", true) viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 + viper.SetDefault("token_refresh.sync_linked_sora_accounts", false) // 默认不跨平台覆盖 Sora token // Gemini OAuth - configure via environment variables or config file // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET @@ -930,9 +1264,106 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_secret", "") 
viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + + // Security - proxy fallback + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + + // Subscription Maintenance (bounded queue + worker pool) + viper.SetDefault("subscription_maintenance.worker_count", 2) + viper.SetDefault("subscription_maintenance.queue_size", 1024) + } func (c *Config) Validate() error { + jwtSecret := strings.TrimSpace(c.JWT.Secret) + if jwtSecret == "" { + return fmt.Errorf("jwt.secret is required") + } + // NOTE: 按 UTF-8 编码后的字节长度计算。 + // 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。 + if len([]byte(jwtSecret)) < 32 { + return fmt.Errorf("jwt.secret must be at least 32 bytes") + } + switch c.Log.Level { + case "debug", "info", "warn", "error": + case "": + return fmt.Errorf("log.level is required") + default: + return fmt.Errorf("log.level must be one of: debug/info/warn/error") + } + switch c.Log.Format { + case "json", "console": + case "": + return fmt.Errorf("log.format is required") + default: + return fmt.Errorf("log.format must be one of: json/console") + } + switch c.Log.StacktraceLevel { + case "none", "error", "fatal": + case "": + return fmt.Errorf("log.stacktrace_level is required") + default: + return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal") + } + if !c.Log.Output.ToStdout && !c.Log.Output.ToFile { + return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false") + } + if c.Log.Rotation.MaxSizeMB <= 0 { + return fmt.Errorf("log.rotation.max_size_mb must be positive") + } + if c.Log.Rotation.MaxBackups < 0 { + return fmt.Errorf("log.rotation.max_backups must be non-negative") + } + if c.Log.Rotation.MaxAgeDays < 0 { + return fmt.Errorf("log.rotation.max_age_days must be non-negative") + } + if c.Log.Sampling.Enabled { + if c.Log.Sampling.Initial <= 0 { + return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") + } + if 
c.Log.Sampling.Thereafter <= 0 { + return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") + } + } else { + if c.Log.Sampling.Initial < 0 { + return fmt.Errorf("log.sampling.initial must be non-negative") + } + if c.Log.Sampling.Thereafter < 0 { + return fmt.Errorf("log.sampling.thereafter must be non-negative") + } + } + + if c.SubscriptionMaintenance.WorkerCount < 0 { + return fmt.Errorf("subscription_maintenance.worker_count must be non-negative") + } + if c.SubscriptionMaintenance.QueueSize < 0 { + return fmt.Errorf("subscription_maintenance.queue_size must be non-negative") + } + + // Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。 + // 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。 + geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID) + geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret) + if (geminiClientID == "") != (geminiClientSecret == "") { + return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty") + } + + if strings.TrimSpace(c.Server.FrontendURL) != "" { + if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL)) + if err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + if u.RawQuery != "" || u.ForceQuery { + return fmt.Errorf("server.frontend_url invalid: must not include query") + } + if u.User != nil { + return fmt.Errorf("server.frontend_url invalid: must not include userinfo") + } + warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL) + } if c.JWT.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") } @@ -940,20 +1371,20 @@ func (c *Config) Validate() error { return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") } if c.JWT.ExpireHour > 24 { - log.Printf("Warning: jwt.expire_hour is %d hours (> 24). 
Consider shorter expiration for security.", c.JWT.ExpireHour) + slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour) } // JWT Refresh Token配置验证 - if c.JWT.AccessTokenExpireMinutes <= 0 { - return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + if c.JWT.AccessTokenExpireMinutes < 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") } if c.JWT.AccessTokenExpireMinutes > 720 { - log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes) } if c.JWT.RefreshTokenExpireDays <= 0 { return fmt.Errorf("jwt.refresh_token_expire_days must be positive") } if c.JWT.RefreshTokenExpireDays > 90 { - log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). 
Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays) } if c.JWT.RefreshWindowMinutes < 0 { return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") @@ -1159,9 +1590,116 @@ func (c *Config) Validate() error { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") } } + if c.Idempotency.DefaultTTLSeconds <= 0 { + return fmt.Errorf("idempotency.default_ttl_seconds must be positive") + } + if c.Idempotency.SystemOperationTTLSeconds <= 0 { + return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive") + } + if c.Idempotency.ProcessingTimeoutSeconds <= 0 { + return fmt.Errorf("idempotency.processing_timeout_seconds must be positive") + } + if c.Idempotency.FailedRetryBackoffSeconds <= 0 { + return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive") + } + if c.Idempotency.MaxStoredResponseLen <= 0 { + return fmt.Errorf("idempotency.max_stored_response_len must be positive") + } + if c.Idempotency.CleanupIntervalSeconds <= 0 { + return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive") + } + if c.Idempotency.CleanupBatchSize <= 0 { + return fmt.Errorf("idempotency.cleanup_batch_size must be positive") + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") + } + if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { + return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") + } + if c.Gateway.SoraMaxBodySize < 0 { + return fmt.Errorf("gateway.sora_max_body_size must be non-negative") + } + if c.Gateway.SoraStreamTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_stream_timeout_seconds must be non-negative") + } + if 
c.Gateway.SoraRequestTimeoutSeconds < 0 { + return fmt.Errorf("gateway.sora_request_timeout_seconds must be non-negative") + } + if c.Gateway.SoraMediaSignedURLTTLSeconds < 0 { + return fmt.Errorf("gateway.sora_media_signed_url_ttl_seconds must be non-negative") + } + if mode := strings.TrimSpace(strings.ToLower(c.Gateway.SoraStreamMode)); mode != "" { + switch mode { + case "force", "error": + default: + return fmt.Errorf("gateway.sora_stream_mode must be one of: force/error") + } + } + if c.Sora.Client.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.timeout_seconds must be non-negative") + } + if c.Sora.Client.MaxRetries < 0 { + return fmt.Errorf("sora.client.max_retries must be non-negative") + } + if c.Sora.Client.CloudflareChallengeCooldownSeconds < 0 { + return fmt.Errorf("sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") + } + if c.Sora.Client.PollIntervalSeconds < 0 { + return fmt.Errorf("sora.client.poll_interval_seconds must be non-negative") + } + if c.Sora.Client.MaxPollAttempts < 0 { + return fmt.Errorf("sora.client.max_poll_attempts must be non-negative") + } + if c.Sora.Client.RecentTaskLimit < 0 { + return fmt.Errorf("sora.client.recent_task_limit must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax < 0 { + return fmt.Errorf("sora.client.recent_task_limit_max must be non-negative") + } + if c.Sora.Client.RecentTaskLimitMax > 0 && c.Sora.Client.RecentTaskLimit > 0 && + c.Sora.Client.RecentTaskLimitMax < c.Sora.Client.RecentTaskLimit { + c.Sora.Client.RecentTaskLimitMax = c.Sora.Client.RecentTaskLimit + } + if c.Sora.Client.CurlCFFISidecar.TimeoutSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.timeout_seconds must be non-negative") + } + if c.Sora.Client.CurlCFFISidecar.SessionTTLSeconds < 0 { + return fmt.Errorf("sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") + } + if !c.Sora.Client.CurlCFFISidecar.Enabled { + return fmt.Errorf("sora.client.curl_cffi_sidecar.enabled 
must be true") + } + if strings.TrimSpace(c.Sora.Client.CurlCFFISidecar.BaseURL) == "" { + return fmt.Errorf("sora.client.curl_cffi_sidecar.base_url is required") + } + if c.Sora.Storage.MaxConcurrentDownloads < 0 { + return fmt.Errorf("sora.storage.max_concurrent_downloads must be non-negative") + } + if c.Sora.Storage.DownloadTimeoutSeconds < 0 { + return fmt.Errorf("sora.storage.download_timeout_seconds must be non-negative") + } + if c.Sora.Storage.MaxDownloadBytes < 0 { + return fmt.Errorf("sora.storage.max_download_bytes must be non-negative") + } + if c.Sora.Storage.Cleanup.Enabled { + if c.Sora.Storage.Cleanup.RetentionDays <= 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be positive") + } + if strings.TrimSpace(c.Sora.Storage.Cleanup.Schedule) == "" { + return fmt.Errorf("sora.storage.cleanup.schedule is required when cleanup is enabled") + } + } else { + if c.Sora.Storage.Cleanup.RetentionDays < 0 { + return fmt.Errorf("sora.storage.cleanup.retention_days must be non-negative") + } + } + if storageType := strings.TrimSpace(strings.ToLower(c.Sora.Storage.Type)); storageType != "" && storageType != "local" { + return fmt.Errorf("sora.storage.type must be 'local'") + } if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { switch c.Gateway.ConnectionPoolIsolation { case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: @@ -1183,7 +1721,7 @@ func (c *Config) Validate() error { return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") } if c.Gateway.IdleConnTimeoutSeconds > 180 { - log.Printf("Warning: gateway.idle_conn_timeout_seconds is %d (> 180). 
Consider 60-120 seconds for better connection reuse.", c.Gateway.IdleConnTimeoutSeconds) + slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds) } if c.Gateway.MaxUpstreamClients <= 0 { return fmt.Errorf("gateway.max_upstream_clients must be positive") @@ -1214,6 +1752,70 @@ func (c *Config) Validate() error { if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { return fmt.Errorf("gateway.max_line_size must be at least 1MB") } + if c.Gateway.UsageRecord.WorkerCount <= 0 { + return fmt.Errorf("gateway.usage_record.worker_count must be positive") + } + if c.Gateway.UsageRecord.QueueSize <= 0 { + return fmt.Errorf("gateway.usage_record.queue_size must be positive") + } + if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive") + } + switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) { + case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync: + default: + return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s", + UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync) + } + if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100") + } + if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) && + c.Gateway.UsageRecord.OverflowSamplePercent <= 0 { + return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample") + } + if c.Gateway.UsageRecord.AutoScaleEnabled { + if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be 
positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive") + } + if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers { + return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers") + } + if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers || + c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers { + return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers") + } + if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99") + } + if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent { + return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent") + } + if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive") + } + if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 { + return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative") + } + } + if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return 
fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive") + } + if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 { + return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30") + } if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") } @@ -1420,6 +2022,6 @@ func warnIfInsecureURL(field, raw string) { return } if strings.EqualFold(u.Scheme, "http") { - log.Printf("Warning: %s uses http scheme; use https in production to avoid token leakage.", field) + slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field) } } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index f734619f..b0402a3b 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -8,6 +8,25 @@ import ( "github.com/spf13/viper" ) +func resetViperWithJWTSecret(t *testing.T) { + t.Helper() + viper.Reset() + t.Setenv("JWT_SECRET", strings.Repeat("x", 32)) +} + +func TestLoadForBootstrapAllowsMissingJWTSecret(t *testing.T) { + viper.Reset() + t.Setenv("JWT_SECRET", "") + + cfg, err := LoadForBootstrap() + if err != nil { + t.Fatalf("LoadForBootstrap() error: %v", err) + } + if cfg.JWT.Secret != "" { + t.Fatalf("LoadForBootstrap() should keep empty jwt.secret during bootstrap") + } +} + func TestNormalizeRunMode(t *testing.T) { tests := []struct { input string @@ -29,7 +48,7 @@ func TestNormalizeRunMode(t *testing.T) { } func TestLoadDefaultSchedulingConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -56,8 +75,44 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultIdempotencyConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Idempotency.ObserveOnly { + 
t.Fatalf("Idempotency.ObserveOnly = false, want true") + } + if cfg.Idempotency.DefaultTTLSeconds != 86400 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 86400", cfg.Idempotency.DefaultTTLSeconds) + } + if cfg.Idempotency.SystemOperationTTLSeconds != 3600 { + t.Fatalf("Idempotency.SystemOperationTTLSeconds = %d, want 3600", cfg.Idempotency.SystemOperationTTLSeconds) + } +} + +func TestLoadIdempotencyConfigFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("IDEMPOTENCY_OBSERVE_ONLY", "false") + t.Setenv("IDEMPOTENCY_DEFAULT_TTL_SECONDS", "600") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Idempotency.ObserveOnly { + t.Fatalf("Idempotency.ObserveOnly = true, want false") + } + if cfg.Idempotency.DefaultTTLSeconds != 600 { + t.Fatalf("Idempotency.DefaultTTLSeconds = %d, want 600", cfg.Idempotency.DefaultTTLSeconds) + } +} + func TestLoadSchedulingConfigFromEnv(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) t.Setenv("GATEWAY_SCHEDULING_STICKY_SESSION_MAX_WAITING", "5") cfg, err := Load() @@ -71,7 +126,7 @@ func TestLoadSchedulingConfigFromEnv(t *testing.T) { } func TestLoadDefaultSecurityToggles(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -87,13 +142,69 @@ func TestLoadDefaultSecurityToggles(t *testing.T) { if !cfg.Security.URLAllowlist.AllowPrivateHosts { t.Fatalf("URLAllowlist.AllowPrivateHosts = false, want true") } - if cfg.Security.ResponseHeaders.Enabled { - t.Fatalf("ResponseHeaders.Enabled = true, want false") + if !cfg.Security.ResponseHeaders.Enabled { + t.Fatalf("ResponseHeaders.Enabled = false, want true") + } +} + +func TestLoadDefaultServerMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Mode != "release" { + t.Fatalf("Server.Mode = %q, want %q", cfg.Server.Mode, "release") + } +} + +func 
TestLoadDefaultJWTAccessTokenExpireMinutes(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.ExpireHour != 24 { + t.Fatalf("JWT.ExpireHour = %d, want 24", cfg.JWT.ExpireHour) + } + if cfg.JWT.AccessTokenExpireMinutes != 0 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 0", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadJWTAccessTokenExpireMinutesFromEnv(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("JWT_ACCESS_TOKEN_EXPIRE_MINUTES", "90") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.JWT.AccessTokenExpireMinutes != 90 { + t.Fatalf("JWT.AccessTokenExpireMinutes = %d, want 90", cfg.JWT.AccessTokenExpireMinutes) + } +} + +func TestLoadDefaultDatabaseSSLMode(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Database.SSLMode != "prefer" { + t.Fatalf("Database.SSLMode = %q, want %q", cfg.Database.SSLMode, "prefer") } } func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -118,7 +229,7 @@ func TestValidateLinuxDoFrontendRedirectURL(t *testing.T) { } func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -143,7 +254,7 @@ func TestValidateLinuxDoPKCERequiredForPublicClient(t *testing.T) { } func TestLoadDefaultDashboardCacheConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -168,7 +279,7 @@ func TestLoadDefaultDashboardCacheConfig(t *testing.T) { } func TestValidateDashboardCacheConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -188,7 +299,7 @@ func TestValidateDashboardCacheConfigEnabled(t *testing.T) { } func 
TestValidateDashboardCacheConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -207,7 +318,7 @@ func TestValidateDashboardCacheConfigDisabled(t *testing.T) { } func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -244,7 +355,7 @@ func TestLoadDefaultDashboardAggregationConfig(t *testing.T) { } func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -263,7 +374,7 @@ func TestValidateDashboardAggregationConfigDisabled(t *testing.T) { } func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -282,7 +393,7 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { } func TestLoadDefaultUsageCleanupConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -307,7 +418,7 @@ func TestLoadDefaultUsageCleanupConfig(t *testing.T) { } func TestValidateUsageCleanupConfigEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -326,7 +437,7 @@ func TestValidateUsageCleanupConfigEnabled(t *testing.T) { } func TestValidateUsageCleanupConfigDisabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -424,6 +535,40 @@ func TestValidateAbsoluteHTTPURL(t *testing.T) { } } +func TestValidateServerFrontendURL(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com" + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() frontend_url valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com/path" + if err := cfg.Validate(); err != nil { + 
t.Fatalf("Validate() frontend_url with path valid error: %v", err) + } + + cfg.Server.FrontendURL = "https://example.com?utm=1" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with query") + } + + cfg.Server.FrontendURL = "https://user:pass@example.com" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject server.frontend_url with userinfo") + } + + cfg.Server.FrontendURL = "/relative" + if err := cfg.Validate(); err == nil { + t.Fatalf("Validate() should reject relative server.frontend_url") + } +} + func TestValidateFrontendRedirectURL(t *testing.T) { if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) @@ -445,6 +590,7 @@ func TestValidateFrontendRedirectURL(t *testing.T) { func TestWarnIfInsecureURL(t *testing.T) { warnIfInsecureURL("test", "http://example.com") warnIfInsecureURL("test", "bad://url") + warnIfInsecureURL("test", "://invalid") } func TestGenerateJWTSecretDefaultLength(t *testing.T) { @@ -458,7 +604,7 @@ func TestGenerateJWTSecretDefaultLength(t *testing.T) { } func TestValidateOpsCleanupScheduleRequired(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -476,7 +622,7 @@ func TestValidateOpsCleanupScheduleRequired(t *testing.T) { } func TestValidateConcurrencyPingInterval(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -493,14 +639,14 @@ func TestValidateConcurrencyPingInterval(t *testing.T) { } func TestProvideConfig(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) if _, err := ProvideConfig(); err != nil { t.Fatalf("ProvideConfig() error: %v", err) } } func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { @@ -544,6 +690,24 @@ func TestGenerateJWTSecretWithLength(t *testing.T) { } } +func 
TestDatabaseDSNWithTimezone_WithPassword(t *testing.T) { + d := &DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "u", + Password: "p", + DBName: "db", + SSLMode: "prefer", + } + got := d.DSNWithTimezone("UTC") + if !strings.Contains(got, "password=p") { + t.Fatalf("DSNWithTimezone should include password: %q", got) + } + if !strings.Contains(got, "TimeZone=UTC") { + t.Fatalf("DSNWithTimezone should include TimeZone=UTC: %q", got) + } +} + func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { if err := ValidateAbsoluteHTTPURL("https://"); err == nil { t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") @@ -566,10 +730,35 @@ func TestWarnIfInsecureURLHTTPS(t *testing.T) { warnIfInsecureURL("secure", "https://example.com") } +func TestValidateJWTSecret_UTF8Bytes(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // 31 bytes (< 32) even though it's 31 characters. + cfg.JWT.Secret = strings.Repeat("a", 31) + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() should reject 31-byte secret") + } + if !strings.Contains(err.Error(), "at least 32 bytes") { + t.Fatalf("Validate() error = %v", err) + } + + // 32 bytes OK. 
+ cfg.JWT.Secret = strings.Repeat("a", 32) + err = cfg.Validate() + if err != nil { + t.Fatalf("Validate() should accept 32-byte secret: %v", err) + } +} + func TestValidateConfigErrors(t *testing.T) { buildValid := func(t *testing.T) *Config { t.Helper() - viper.Reset() + resetViperWithJWTSecret(t) cfg, err := Load() if err != nil { t.Fatalf("Load() error: %v", err) @@ -582,6 +771,26 @@ func TestValidateConfigErrors(t *testing.T) { mutate func(*Config) wantErr string }{ + { + name: "jwt secret required", + mutate: func(c *Config) { c.JWT.Secret = "" }, + wantErr: "jwt.secret is required", + }, + { + name: "jwt secret min bytes", + mutate: func(c *Config) { c.JWT.Secret = strings.Repeat("a", 31) }, + wantErr: "jwt.secret must be at least 32 bytes", + }, + { + name: "subscription maintenance worker_count non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.WorkerCount = -1 }, + wantErr: "subscription_maintenance.worker_count", + }, + { + name: "subscription maintenance queue_size non-negative", + mutate: func(c *Config) { c.SubscriptionMaintenance.QueueSize = -1 }, + wantErr: "subscription_maintenance.queue_size", + }, { name: "jwt expire hour positive", mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, @@ -592,6 +801,11 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, wantErr: "jwt.expire_hour must be <= 168", }, + { + name: "jwt access token expire minutes non-negative", + mutate: func(c *Config) { c.JWT.AccessTokenExpireMinutes = -1 }, + wantErr: "jwt.access_token_expire_minutes must be non-negative", + }, { name: "csp policy required", mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, @@ -799,6 +1013,84 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, wantErr: "gateway.max_line_size must be non-negative", }, + { + name: "gateway usage record worker count", + mutate: func(c *Config) { 
c.Gateway.UsageRecord.WorkerCount = 0 }, + wantErr: "gateway.usage_record.worker_count", + }, + { + name: "gateway usage record queue size", + mutate: func(c *Config) { c.Gateway.UsageRecord.QueueSize = 0 }, + wantErr: "gateway.usage_record.queue_size", + }, + { + name: "gateway usage record timeout", + mutate: func(c *Config) { c.Gateway.UsageRecord.TaskTimeoutSeconds = 0 }, + wantErr: "gateway.usage_record.task_timeout_seconds", + }, + { + name: "gateway usage record overflow policy", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowPolicy = "invalid" }, + wantErr: "gateway.usage_record.overflow_policy", + }, + { + name: "gateway usage record sample percent range", + mutate: func(c *Config) { c.Gateway.UsageRecord.OverflowSamplePercent = 101 }, + wantErr: "gateway.usage_record.overflow_sample_percent", + }, + { + name: "gateway usage record sample percent required for sample policy", + mutate: func(c *Config) { + c.Gateway.UsageRecord.OverflowPolicy = UsageRecordOverflowPolicySample + c.Gateway.UsageRecord.OverflowSamplePercent = 0 + }, + wantErr: "gateway.usage_record.overflow_sample_percent must be positive", + }, + { + name: "gateway usage record auto scale max gte min", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 256 + c.Gateway.UsageRecord.AutoScaleMaxWorkers = 128 + }, + wantErr: "gateway.usage_record.auto_scale_max_workers", + }, + { + name: "gateway usage record worker in auto scale range", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleMinWorkers = 200 + c.Gateway.UsageRecord.AutoScaleMaxWorkers = 300 + c.Gateway.UsageRecord.WorkerCount = 128 + }, + wantErr: "gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers", + }, + { + name: "gateway usage record auto scale queue thresholds order", + mutate: func(c *Config) { + c.Gateway.UsageRecord.AutoScaleUpQueuePercent = 50 + c.Gateway.UsageRecord.AutoScaleDownQueuePercent = 50 + }, + wantErr: 
"gateway.usage_record.auto_scale_down_queue_percent must be less", + }, + { + name: "gateway usage record auto scale up step", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleUpStep = 0 }, + wantErr: "gateway.usage_record.auto_scale_up_step", + }, + { + name: "gateway usage record auto scale interval", + mutate: func(c *Config) { c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 }, + wantErr: "gateway.usage_record.auto_scale_check_interval_seconds", + }, + { + name: "gateway user group rate cache ttl", + mutate: func(c *Config) { c.Gateway.UserGroupRateCacheTTLSeconds = 0 }, + wantErr: "gateway.user_group_rate_cache_ttl_seconds", + }, + { + name: "gateway models list cache ttl range", + mutate: func(c *Config) { c.Gateway.ModelsListCacheTTLSeconds = 31 }, + wantErr: "gateway.models_list_cache_ttl_seconds", + }, { name: "gateway scheduling sticky waiting", mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, @@ -822,6 +1114,37 @@ func TestValidateConfigErrors(t *testing.T) { }, wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", }, + { + name: "log level invalid", + mutate: func(c *Config) { c.Log.Level = "trace" }, + wantErr: "log.level", + }, + { + name: "log format invalid", + mutate: func(c *Config) { c.Log.Format = "plain" }, + wantErr: "log.format", + }, + { + name: "log output disabled", + mutate: func(c *Config) { + c.Log.Output.ToStdout = false + c.Log.Output.ToFile = false + }, + wantErr: "log.output.to_stdout and log.output.to_file cannot both be false", + }, + { + name: "log rotation size", + mutate: func(c *Config) { c.Log.Rotation.MaxSizeMB = 0 }, + wantErr: "log.rotation.max_size_mb", + }, + { + name: "log sampling enabled invalid", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = true + c.Log.Sampling.Initial = 0 + }, + wantErr: "log.sampling.initial", + }, { name: "ops metrics collector ttl", mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, @@ -850,3 +1173,234 @@ func 
TestValidateConfigErrors(t *testing.T) { }) } } + +func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Gateway.UsageRecord.AutoScaleEnabled = false + cfg.Gateway.UsageRecord.WorkerCount = 64 + + // 自动扩缩容关闭时,这些字段应被忽略,不应导致校验失败。 + cfg.Gateway.UsageRecord.AutoScaleMinWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleMaxWorkers = 0 + cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent = 0 + cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent = 100 + cfg.Gateway.UsageRecord.AutoScaleUpStep = 0 + cfg.Gateway.UsageRecord.AutoScaleDownStep = 0 + cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds = 0 + cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds = -1 + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() should ignore auto scale fields when disabled: %v", err) + } +} + +func TestValidateConfig_LogRequiredAndRotationBounds(t *testing.T) { + resetViperWithJWTSecret(t) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "log level required", + mutate: func(c *Config) { + c.Log.Level = "" + }, + wantErr: "log.level is required", + }, + { + name: "log format required", + mutate: func(c *Config) { + c.Log.Format = "" + }, + wantErr: "log.format is required", + }, + { + name: "log stacktrace required", + mutate: func(c *Config) { + c.Log.StacktraceLevel = "" + }, + wantErr: "log.stacktrace_level is required", + }, + { + name: "log max backups non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxBackups = -1 + }, + wantErr: "log.rotation.max_backups must be non-negative", + }, + { + name: "log max age non-negative", + mutate: func(c *Config) { + c.Log.Rotation.MaxAgeDays = -1 + }, + wantErr: "log.rotation.max_age_days must be non-negative", + }, + { + name: "sampling thereafter non-negative when disabled", + mutate: func(c *Config) { + c.Log.Sampling.Enabled = false + 
c.Log.Sampling.Thereafter = -1 + }, + wantErr: "log.sampling.thereafter must be non-negative", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + tt.mutate(cfg) + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr) + } + }) + } +} + +func TestSoraCurlCFFISidecarDefaults(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Sora.Client.CurlCFFISidecar.Enabled { + t.Fatalf("Sora curl_cffi sidecar should be enabled by default") + } + if cfg.Sora.Client.CloudflareChallengeCooldownSeconds <= 0 { + t.Fatalf("Sora cloudflare challenge cooldown should be positive by default") + } + if cfg.Sora.Client.CurlCFFISidecar.BaseURL == "" { + t.Fatalf("Sora curl_cffi sidecar base_url should not be empty by default") + } + if cfg.Sora.Client.CurlCFFISidecar.Impersonate == "" { + t.Fatalf("Sora curl_cffi sidecar impersonate should not be empty by default") + } + if !cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled { + t.Fatalf("Sora curl_cffi sidecar session reuse should be enabled by default") + } + if cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds <= 0 { + t.Fatalf("Sora curl_cffi sidecar session ttl should be positive by default") + } +} + +func TestValidateSoraCurlCFFISidecarRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.Enabled = false + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.enabled must be true") { + t.Fatalf("Validate() error = %v, want sidecar enabled error", err) + } +} + +func TestValidateSoraCurlCFFISidecarBaseURLRequired(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != 
nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.BaseURL = " " + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.base_url is required") { + t.Fatalf("Validate() error = %v, want sidecar base_url required error", err) + } +} + +func TestValidateSoraCurlCFFISidecarSessionTTLNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.curl_cffi_sidecar.session_ttl_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want sidecar session ttl error", err) + } +} + +func TestValidateSoraCloudflareChallengeCooldownNonNegative(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Sora.Client.CloudflareChallengeCooldownSeconds = -1 + err = cfg.Validate() + if err == nil || !strings.Contains(err.Error(), "sora.client.cloudflare_challenge_cooldown_seconds must be non-negative") { + t.Fatalf("Validate() error = %v, want cloudflare cooldown error", err) + } +} + +func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { + resetViperWithJWTSecret(t) + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + if cfg.Gateway.UsageRecord.WorkerCount != 128 { + t.Fatalf("worker_count = %d, want 128", cfg.Gateway.UsageRecord.WorkerCount) + } + if cfg.Gateway.UsageRecord.QueueSize != 16384 { + t.Fatalf("queue_size = %d, want 16384", cfg.Gateway.UsageRecord.QueueSize) + } + if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 { + t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds) + } + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { + t.Fatalf("overflow_policy = %s, want %s", 
cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample) + } + if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 { + t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent) + } + if !cfg.Gateway.UsageRecord.AutoScaleEnabled { + t.Fatalf("auto_scale_enabled = false, want true") + } + if cfg.Gateway.UsageRecord.AutoScaleMinWorkers != 128 { + t.Fatalf("auto_scale_min_workers = %d, want 128", cfg.Gateway.UsageRecord.AutoScaleMinWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleMaxWorkers != 512 { + t.Fatalf("auto_scale_max_workers = %d, want 512", cfg.Gateway.UsageRecord.AutoScaleMaxWorkers) + } + if cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent != 70 { + t.Fatalf("auto_scale_up_queue_percent = %d, want 70", cfg.Gateway.UsageRecord.AutoScaleUpQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent != 15 { + t.Fatalf("auto_scale_down_queue_percent = %d, want 15", cfg.Gateway.UsageRecord.AutoScaleDownQueuePercent) + } + if cfg.Gateway.UsageRecord.AutoScaleUpStep != 32 { + t.Fatalf("auto_scale_up_step = %d, want 32", cfg.Gateway.UsageRecord.AutoScaleUpStep) + } + if cfg.Gateway.UsageRecord.AutoScaleDownStep != 16 { + t.Fatalf("auto_scale_down_step = %d, want 16", cfg.Gateway.UsageRecord.AutoScaleDownStep) + } + if cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds != 3 { + t.Fatalf("auto_scale_check_interval_seconds = %d, want 3", cfg.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds) + } + if cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds != 10 { + t.Fatalf("auto_scale_cooldown_seconds = %d, want 10", cfg.Gateway.UsageRecord.AutoScaleCooldownSeconds) + } +} diff --git a/backend/internal/config/wire.go b/backend/internal/config/wire.go index ec26c401..bf6b3bd6 100644 --- a/backend/internal/config/wire.go +++ b/backend/internal/config/wire.go @@ -9,5 +9,5 @@ var ProviderSet = wire.NewSet( // ProvideConfig 提供应用配置 func ProvideConfig() (*Config, error) { - return Load() + return 
LoadForBootstrap() } diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index 05b5adc1..27972d01 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -22,6 +22,7 @@ const ( PlatformOpenAI = "openai" PlatformGemini = "gemini" PlatformAntigravity = "antigravity" + PlatformSora = "sora" ) // Account type constants diff --git a/backend/internal/handler/admin/account_data.go b/backend/internal/handler/admin/account_data.go index 34397696..4ce17219 100644 --- a/backend/internal/handler/admin/account_data.go +++ b/backend/internal/handler/admin/account_data.go @@ -175,22 +175,28 @@ func (h *AccountHandler) ImportData(c *gin.Context) { return } - dataPayload := req.Data - if err := validateDataHeader(dataPayload); err != nil { + if err := validateDataHeader(req.Data); err != nil { response.BadRequest(c, err.Error()) return } + executeAdminIdempotentJSON(c, "admin.accounts.import_data", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + return h.importData(ctx, req) + }) +} + +func (h *AccountHandler) importData(ctx context.Context, req DataImportRequest) (DataImportResult, error) { skipDefaultGroupBind := true if req.SkipDefaultGroupBind != nil { skipDefaultGroupBind = *req.SkipDefaultGroupBind } + dataPayload := req.Data result := DataImportResult{} - existingProxies, err := h.listAllProxies(c.Request.Context()) + + existingProxies, err := h.listAllProxies(ctx) if err != nil { - response.ErrorFrom(c, err) - return + return result, err } proxyKeyToID := make(map[string]int64, len(existingProxies)) @@ -221,8 +227,8 @@ func (h *AccountHandler) ImportData(c *gin.Context) { proxyKeyToID[key] = existingID result.ProxyReused++ if normalizedStatus != "" { - if proxy, err := h.adminService.GetProxy(c.Request.Context(), existingID); err == nil && proxy != nil && proxy.Status != normalizedStatus { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), existingID, 
&service.UpdateProxyInput{ + if proxy, getErr := h.adminService.GetProxy(ctx, existingID); getErr == nil && proxy != nil && proxy.Status != normalizedStatus { + _, _ = h.adminService.UpdateProxy(ctx, existingID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -230,7 +236,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { continue } - created, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ + created, createErr := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ Name: defaultProxyName(item.Name), Protocol: item.Protocol, Host: item.Host, @@ -238,13 +244,13 @@ func (h *AccountHandler) ImportData(c *gin.Context) { Username: item.Username, Password: item.Password, }) - if err != nil { + if createErr != nil { result.ProxyFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "proxy", Name: item.Name, ProxyKey: key, - Message: err.Error(), + Message: createErr.Error(), }) continue } @@ -252,7 +258,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.ProxyCreated++ if normalizedStatus != "" && normalizedStatus != created.Status { - _, _ = h.adminService.UpdateProxy(c.Request.Context(), created.ID, &service.UpdateProxyInput{ + _, _ = h.adminService.UpdateProxy(ctx, created.ID, &service.UpdateProxyInput{ Status: normalizedStatus, }) } @@ -303,7 +309,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { SkipDefaultGroupBind: skipDefaultGroupBind, } - if _, err := h.adminService.CreateAccount(c.Request.Context(), accountInput); err != nil { + if _, err := h.adminService.CreateAccount(ctx, accountInput); err != nil { result.AccountFailed++ result.Errors = append(result.Errors, DataImportError{ Kind: "account", @@ -315,7 +321,7 @@ func (h *AccountHandler) ImportData(c *gin.Context) { result.AccountCreated++ } - response.Success(c, result) + return result, nil } func (h *AccountHandler) listAllProxies(ctx context.Context) ([]service.Proxy, error) { diff --git 
a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 0fae04ac..a2a8dd43 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -2,7 +2,13 @@ package admin import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "errors" + "fmt" + "net/http" "strconv" "strings" "sync" @@ -142,6 +148,44 @@ type AccountWithConcurrency struct { ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 } +func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { + item := AccountWithConcurrency{ + Account: dto.AccountFromService(account), + CurrentConcurrency: 0, + } + if account == nil { + return item + } + + if h.concurrencyService != nil { + if counts, err := h.concurrencyService.GetAccountConcurrencyBatch(ctx, []int64{account.ID}); err == nil { + item.CurrentConcurrency = counts[account.ID] + } + } + + if account.IsAnthropicOAuthOrSetupToken() { + if h.accountUsageService != nil && account.GetWindowCostLimit() > 0 { + startTime := account.GetCurrentWindowStartTime() + if stats, err := h.accountUsageService.GetAccountWindowStats(ctx, account.ID, startTime); err == nil && stats != nil { + cost := stats.StandardCost + item.CurrentWindowCost = &cost + } + } + + if h.sessionLimitCache != nil && account.GetMaxSessions() > 0 { + idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute + idleTimeouts := map[int64]time.Duration{account.ID: idleTimeout} + if sessions, err := h.sessionLimitCache.GetActiveSessionCountBatch(ctx, []int64{account.ID}, idleTimeouts); err == nil { + if count, ok := sessions[account.ID]; ok { + item.ActiveSessions = &count + } + } + } + } + + return item +} + // List handles listing all accounts with pagination // GET /api/v1/admin/accounts func (h *AccountHandler) List(c *gin.Context) { @@ -262,9 +306,71 @@ func (h 
*AccountHandler) List(c *gin.Context) { result[i] = item } + etag := buildAccountsListETag(result, total, page, pageSize, platform, accountType, status, search) + if etag != "" { + c.Header("ETag", etag) + c.Header("Vary", "If-None-Match") + if ifNoneMatchMatched(c.GetHeader("If-None-Match"), etag) { + c.Status(http.StatusNotModified) + return + } + } + response.Paginated(c, result, total, page, pageSize) } +func buildAccountsListETag( + items []AccountWithConcurrency, + total int64, + page, pageSize int, + platform, accountType, status, search string, +) string { + payload := struct { + Total int64 `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` + Platform string `json:"platform"` + AccountType string `json:"type"` + Status string `json:"status"` + Search string `json:"search"` + Items []AccountWithConcurrency `json:"items"` + }{ + Total: total, + Page: page, + PageSize: pageSize, + Platform: platform, + AccountType: accountType, + Status: status, + Search: search, + Items: items, + } + raw, err := json.Marshal(payload) + if err != nil { + return "" + } + sum := sha256.Sum256(raw) + return "\"" + hex.EncodeToString(sum[:]) + "\"" +} + +func ifNoneMatchMatched(ifNoneMatch, etag string) bool { + if etag == "" || ifNoneMatch == "" { + return false + } + for _, token := range strings.Split(ifNoneMatch, ",") { + candidate := strings.TrimSpace(token) + if candidate == "*" { + return true + } + if candidate == etag { + return true + } + if strings.HasPrefix(candidate, "W/") && strings.TrimPrefix(candidate, "W/") == etag { + return true + } + } + return false +} + // GetByID handles getting an account by ID // GET /api/v1/admin/accounts/:id func (h *AccountHandler) GetByID(c *gin.Context) { @@ -280,7 +386,7 @@ func (h *AccountHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // Create handles creating a new 
account @@ -299,21 +405,27 @@ func (h *AccountHandler) Create(c *gin.Context) { // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk - account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ - Name: req.Name, - Notes: req.Notes, - Platform: req.Platform, - Type: req.Type, - Credentials: req.Credentials, - Extra: req.Extra, - ProxyID: req.ProxyID, - Concurrency: req.Concurrency, - Priority: req.Priority, - RateMultiplier: req.RateMultiplier, - GroupIDs: req.GroupIDs, - ExpiresAt: req.ExpiresAt, - AutoPauseOnExpired: req.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, + result, err := executeAdminIdempotent(c, "admin.accounts.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + account, execErr := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: req.Name, + Notes: req.Notes, + Platform: req.Platform, + Type: req.Type, + Credentials: req.Credentials, + Extra: req.Extra, + ProxyID: req.ProxyID, + Concurrency: req.Concurrency, + Priority: req.Priority, + RateMultiplier: req.RateMultiplier, + GroupIDs: req.GroupIDs, + ExpiresAt: req.ExpiresAt, + AutoPauseOnExpired: req.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if execErr != nil { + return nil, execErr + } + return h.buildAccountResponseWithRuntime(ctx, account), nil }) if err != nil { // 检查是否为混合渠道错误 @@ -334,11 +446,17 @@ func (h *AccountHandler) Create(c *gin.Context) { return } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } response.ErrorFrom(c, err) return } - response.Success(c, dto.AccountFromService(account)) + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) } // Update handles updating an account @@ -402,7 +520,7 @@ func (h *AccountHandler) Update(c *gin.Context) { return } - 
response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // Delete handles deleting an account @@ -660,7 +778,7 @@ func (h *AccountHandler) Refresh(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(updatedAccount)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), updatedAccount)) } // GetStats handles getting account statistics @@ -718,7 +836,7 @@ func (h *AccountHandler) ClearError(c *gin.Context) { } } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // BatchCreate handles batch creating accounts @@ -732,61 +850,62 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { return } - ctx := c.Request.Context() - success := 0 - failed := 0 - results := make([]gin.H, 0, len(req.Accounts)) + executeAdminIdempotentJSON(c, "admin.accounts.batch_create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + success := 0 + failed := 0 + results := make([]gin.H, 0, len(req.Accounts)) - for _, item := range req.Accounts { - if item.RateMultiplier != nil && *item.RateMultiplier < 0 { - failed++ + for _, item := range req.Accounts { + if item.RateMultiplier != nil && *item.RateMultiplier < 0 { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": "rate_multiplier must be >= 0", + }) + continue + } + + skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk + + account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ + Name: item.Name, + Notes: item.Notes, + Platform: item.Platform, + Type: item.Type, + Credentials: item.Credentials, + Extra: item.Extra, + ProxyID: item.ProxyID, + Concurrency: item.Concurrency, + Priority: item.Priority, + RateMultiplier: item.RateMultiplier, + GroupIDs: item.GroupIDs, + ExpiresAt: item.ExpiresAt, + 
AutoPauseOnExpired: item.AutoPauseOnExpired, + SkipMixedChannelCheck: skipCheck, + }) + if err != nil { + failed++ + results = append(results, gin.H{ + "name": item.Name, + "success": false, + "error": err.Error(), + }) + continue + } + success++ results = append(results, gin.H{ "name": item.Name, - "success": false, - "error": "rate_multiplier must be >= 0", + "id": account.ID, + "success": true, }) - continue } - skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk - - account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ - Name: item.Name, - Notes: item.Notes, - Platform: item.Platform, - Type: item.Type, - Credentials: item.Credentials, - Extra: item.Extra, - ProxyID: item.ProxyID, - Concurrency: item.Concurrency, - Priority: item.Priority, - RateMultiplier: item.RateMultiplier, - GroupIDs: item.GroupIDs, - ExpiresAt: item.ExpiresAt, - AutoPauseOnExpired: item.AutoPauseOnExpired, - SkipMixedChannelCheck: skipCheck, - }) - if err != nil { - failed++ - results = append(results, gin.H{ - "name": item.Name, - "success": false, - "error": err.Error(), - }) - continue - } - success++ - results = append(results, gin.H{ - "name": item.Name, - "id": account.ID, - "success": true, - }) - } - - response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + return gin.H{ + "success": success, + "failed": failed, + "results": results, + }, nil }) } @@ -824,57 +943,58 @@ func (h *AccountHandler) BatchUpdateCredentials(c *gin.Context) { } ctx := c.Request.Context() - success := 0 - failed := 0 - results := []gin.H{} + // 阶段一:预验证所有账号存在,收集 credentials + type accountUpdate struct { + ID int64 + Credentials map[string]any + } + updates := make([]accountUpdate, 0, len(req.AccountIDs)) for _, accountID := range req.AccountIDs { - // Get account account, err := h.adminService.GetAccount(ctx, accountID) if err != nil { - failed++ - results = append(results, gin.H{ - "account_id": accountID, - 
"success": false, - "error": "Account not found", - }) - continue + response.Error(c, 404, fmt.Sprintf("Account %d not found", accountID)) + return } - - // Update credentials field if account.Credentials == nil { account.Credentials = make(map[string]any) } - account.Credentials[req.Field] = req.Value + updates = append(updates, accountUpdate{ID: accountID, Credentials: account.Credentials}) + } - // Update account - updateInput := &service.UpdateAccountInput{ - Credentials: account.Credentials, - } - - _, err = h.adminService.UpdateAccount(ctx, accountID, updateInput) - if err != nil { + // 阶段二:依次更新,返回每个账号的成功/失败明细,便于调用方重试 + success := 0 + failed := 0 + successIDs := make([]int64, 0, len(updates)) + failedIDs := make([]int64, 0, len(updates)) + results := make([]gin.H, 0, len(updates)) + for _, u := range updates { + updateInput := &service.UpdateAccountInput{Credentials: u.Credentials} + if _, err := h.adminService.UpdateAccount(ctx, u.ID, updateInput); err != nil { failed++ + failedIDs = append(failedIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": false, "error": err.Error(), }) continue } - success++ + successIDs = append(successIDs, u.ID) results = append(results, gin.H{ - "account_id": accountID, + "account_id": u.ID, "success": true, }) } response.Success(c, gin.H{ - "success": success, - "failed": failed, - "results": results, + "success": success, + "failed": failed, + "success_ids": successIDs, + "failed_ids": failedIDs, + "results": results, }) } @@ -1109,7 +1229,13 @@ func (h *AccountHandler) ClearRateLimit(c *gin.Context) { return } - response.Success(c, gin.H{"message": "Rate limit cleared successfully"}) + account, err := h.adminService.GetAccount(c.Request.Context(), accountID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetTempUnschedulable handles getting temporary unschedulable 
status @@ -1199,7 +1325,7 @@ func (h *AccountHandler) SetSchedulable(c *gin.Context) { return } - response.Success(c, dto.AccountFromService(account)) + response.Success(c, h.buildAccountResponseWithRuntime(c.Request.Context(), account)) } // GetAvailableModels handles getting available models for an account @@ -1325,6 +1451,12 @@ func (h *AccountHandler) GetAvailableModels(c *gin.Context) { return } + // Handle Sora accounts + if account.Platform == service.PlatformSora { + response.Success(c, service.DefaultSoraModels(nil)) + return + } + // Handle Claude/Anthropic accounts // For OAuth and Setup-Token accounts: return default models if account.IsOAuth() { diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go new file mode 100644 index 00000000..d09cccd6 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -0,0 +1,66 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testing.T) { + gin.SetMode(gin.TestMode) + + adminSvc := newStubAdminService() + handler := NewAccountHandler( + adminSvc, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + router := gin.New() + router.POST("/api/v1/admin/accounts", handler.Create) + + body := map[string]any{ + "name": "anthropic-key-1", + "platform": "anthropic", + "type": "apikey", + "credentials": map[string]any{ + "api_key": "sk-ant-xxx", + "base_url": "https://api.anthropic.com", + }, + "extra": map[string]any{ + "anthropic_passthrough": true, + }, + "concurrency": 1, + "priority": 1, + } + raw, err := json.Marshal(body) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts", 
bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Len(t, adminSvc.createdAccounts, 1) + + created := adminSvc.createdAccounts[0] + require.Equal(t, "anthropic", created.Platform) + require.Equal(t, "apikey", created.Type) + require.NotNil(t, created.Extra) + require.Equal(t, true, created.Extra["anthropic_passthrough"]) +} diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go index 20a25222..aeb4097f 100644 --- a/backend/internal/handler/admin/admin_basic_handlers_test.go +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -47,6 +47,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) { router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.POST("/api/v1/admin/proxies/:id/quality-check", proxyHandler.CheckQuality) router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) @@ -208,6 +209,11 @@ func TestProxyHandlerEndpoints(t *testing.T) { router.ServeHTTP(rec, req) require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/quality-check", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + rec = httptest.NewRecorder() req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) router.ServeHTTP(rec, req) diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go index 863c755c..3833d32e 100644 --- a/backend/internal/handler/admin/admin_helpers_test.go +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -58,6 
+58,96 @@ func TestParseOpsDuration(t *testing.T) { require.False(t, ok) } +func TestParseOpsOpenAITokenStatsDuration(t *testing.T) { + tests := []struct { + input string + want time.Duration + ok bool + }{ + {input: "30m", want: 30 * time.Minute, ok: true}, + {input: "1h", want: time.Hour, ok: true}, + {input: "1d", want: 24 * time.Hour, ok: true}, + {input: "15d", want: 15 * 24 * time.Hour, ok: true}, + {input: "30d", want: 30 * 24 * time.Hour, ok: true}, + {input: "7d", want: 0, ok: false}, + } + + for _, tt := range tests { + got, ok := parseOpsOpenAITokenStatsDuration(tt.input) + require.Equal(t, tt.ok, ok, "input=%s", tt.input) + require.Equal(t, tt.want, got, "input=%s", tt.input) + } +} + +func TestParseOpsOpenAITokenStatsFilter_Defaults(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + before := time.Now().UTC() + filter, err := parseOpsOpenAITokenStatsFilter(c) + after := time.Now().UTC() + + require.NoError(t, err) + require.NotNil(t, filter) + require.Equal(t, "30d", filter.TimeRange) + require.Equal(t, 1, filter.Page) + require.Equal(t, 20, filter.PageSize) + require.Equal(t, 0, filter.TopN) + require.Nil(t, filter.GroupID) + require.Equal(t, "", filter.Platform) + require.True(t, filter.StartTime.Before(filter.EndTime)) + require.WithinDuration(t, before.Add(-30*24*time.Hour), filter.StartTime, 2*time.Second) + require.WithinDuration(t, after, filter.EndTime, 2*time.Second) +} + +func TestParseOpsOpenAITokenStatsFilter_WithTopN(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest( + http.MethodGet, + "/?time_range=1h&platform=openai&group_id=12&top_n=50", + nil, + ) + + filter, err := parseOpsOpenAITokenStatsFilter(c) + require.NoError(t, err) + require.Equal(t, "1h", filter.TimeRange) + require.Equal(t, "openai", filter.Platform) + 
require.NotNil(t, filter.GroupID) + require.Equal(t, int64(12), *filter.GroupID) + require.Equal(t, 50, filter.TopN) + require.Equal(t, 0, filter.Page) + require.Equal(t, 0, filter.PageSize) +} + +func TestParseOpsOpenAITokenStatsFilter_InvalidParams(t *testing.T) { + tests := []string{ + "/?time_range=7d", + "/?group_id=0", + "/?group_id=abc", + "/?top_n=0", + "/?top_n=101", + "/?top_n=10&page=1", + "/?top_n=10&page_size=20", + "/?page=0", + "/?page_size=0", + "/?page_size=101", + } + + gin.SetMode(gin.TestMode) + for _, rawURL := range tests { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, rawURL, nil) + + _, err := parseOpsOpenAITokenStatsFilter(c) + require.Error(t, err, "url=%s", rawURL) + } +} + func TestParseOpsTimeRange(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index d44c99ea..9f3dcf80 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -327,6 +327,27 @@ func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.Pr return &service.ProxyTestResult{Success: true, Message: "ok"}, nil } +func (s *stubAdminService) CheckProxyQuality(ctx context.Context, id int64) (*service.ProxyQualityCheckResult, error) { + return &service.ProxyQualityCheckResult{ + ProxyID: id, + Score: 95, + Grade: "A", + Summary: "通过 5 项,告警 0 项,失败 0 项,挑战 0 项", + PassedCount: 5, + WarnCount: 0, + FailedCount: 0, + ChallengeCount: 0, + CheckedAt: time.Now().Unix(), + Items: []service.ProxyQualityCheckItem{ + {Target: "base_connectivity", Status: "pass", Message: "ok"}, + {Target: "openai", Status: "pass", HTTPStatus: 401}, + {Target: "anthropic", Status: "pass", HTTPStatus: 401}, + {Target: "gemini", Status: "pass", HTTPStatus: 200}, + {Target: "sora", Status: "pass", 
HTTPStatus: 401}, + }, + }, nil +} + func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { return s.redeems, int64(len(s.redeems)), nil } diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go new file mode 100644 index 00000000..c8185735 --- /dev/null +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -0,0 +1,208 @@ +//go:build unit + +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// failingAdminService 嵌入 stubAdminService,可配置 UpdateAccount 在指定 ID 时失败。 +type failingAdminService struct { + *stubAdminService + failOnAccountID int64 + updateCallCount atomic.Int64 +} + +func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + f.updateCallCount.Add(1) + if id == f.failOnAccountID { + return nil, errors.New("database error") + } + return f.stubAdminService.UpdateAccount(ctx, id, input) +} + +func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { + gin.SetMode(gin.TestMode) + router := gin.New() + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) + return router, handler +} + +func TestBatchUpdateCredentials_AllSuccess(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: 
"test-uuid", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "全部成功时应返回 200") + require.Equal(t, int64(3), svc.updateCallCount.Load(), "应调用 3 次 UpdateAccount") +} + +func TestBatchUpdateCredentials_PartialFailure(t *testing.T) { + // 让第 2 个账号(ID=2)更新时失败 + svc := &failingAdminService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 2, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "org_uuid", + Value: "test-org", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + // 实现采用"部分成功"模式:总是返回 200 + 成功/失败明细 + require.Equal(t, http.StatusOK, w.Code, "批量更新返回 200 + 成功/失败明细") + + var resp map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp)) + data := resp["data"].(map[string]any) + require.Equal(t, float64(2), data["success"], "应有 2 个成功") + require.Equal(t, float64(1), data["failed"], "应有 1 个失败") + + // 所有 3 个账号都会被尝试更新(非 fail-fast) + require.Equal(t, int64(3), svc.updateCallCount.Load(), + "应调用 3 次 UpdateAccount(逐个尝试,失败后继续)") +} + +func TestBatchUpdateCredentials_FirstAccountNotFound(t *testing.T) { + // GetAccount 在 stubAdminService 中总是成功的,需要创建一个 GetAccount 会失败的 stub + svc := &getAccountFailingService{ + stubAdminService: newStubAdminService(), + failOnAccountID: 1, + } + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(BatchUpdateCredentialsRequest{ + AccountIDs: []int64{1, 2, 3}, + Field: "account_uuid", + Value: "test", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", 
"/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusNotFound, w.Code, "第一阶段验证失败应返回 404") +} + +// getAccountFailingService 模拟 GetAccount 在特定 ID 时返回 not found。 +type getAccountFailingService struct { + *stubAdminService + failOnAccountID int64 +} + +func (f *getAccountFailingService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + if id == f.failOnAccountID { + return nil, errors.New("not found") + } + return f.stubAdminService.GetAccount(ctx, id) +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_NonBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // intercept_warmup_requests 传入非 bool 类型(string),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": "not-a-bool", + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "intercept_warmup_requests 传入非 bool 值应返回 400") +} + +func TestBatchUpdateCredentials_InterceptWarmupRequests_ValidBool(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "intercept_warmup_requests", + "value": true, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "intercept_warmup_requests 传入合法 bool 值应返回 
200") +} + +func TestBatchUpdateCredentials_AccountUUID_NonString(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入非 string 类型(number),应返回 400 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": 12345, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusBadRequest, w.Code, + "account_uuid 传入非 string 值应返回 400") +} + +func TestBatchUpdateCredentials_AccountUUID_NullValue(t *testing.T) { + svc := &failingAdminService{stubAdminService: newStubAdminService()} + router, _ := setupAccountHandlerWithService(svc) + + // account_uuid 传入 null(设置为空),应正常通过 + body, _ := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "field": "account_uuid", + "value": nil, + }) + + w := httptest.NewRecorder() + req, _ := http.NewRequest("POST", "/api/v1/admin/accounts/batch-update-credentials", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, + "account_uuid 传入 null 应返回 200") +} diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 18365186..fab66c04 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -379,7 +379,7 @@ func (h *DashboardHandler) GetBatchUsersUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs) + stats, err := h.dashboardService.GetBatchUserUsageStats(c.Request.Context(), req.UserIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get user usage stats") return @@ 
-407,7 +407,7 @@ func (h *DashboardHandler) GetBatchAPIKeysUsage(c *gin.Context) { return } - stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs) + stats, err := h.dashboardService.GetBatchAPIKeyUsageStats(c.Request.Context(), req.APIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.Error(c, 500, "Failed to get API key usage stats") return diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 7daaf281..25ff3c96 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -27,7 +27,7 @@ func NewGroupHandler(adminService service.AdminService) *GroupHandler { type CreateGroupRequest struct { Name string `json:"name" binding:"required"` Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier float64 `json:"rate_multiplier"` IsExclusive bool `json:"is_exclusive"` SubscriptionType string `json:"subscription_type" binding:"omitempty,oneof=standard subscription"` @@ -38,6 +38,10 @@ type CreateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -55,7 +59,7 @@ type CreateGroupRequest struct { type UpdateGroupRequest struct { Name string `json:"name"` 
Description string `json:"description"` - Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity"` + Platform string `json:"platform" binding:"omitempty,oneof=anthropic openai gemini antigravity sora"` RateMultiplier *float64 `json:"rate_multiplier"` IsExclusive *bool `json:"is_exclusive"` Status string `json:"status" binding:"omitempty,oneof=active inactive"` @@ -67,6 +71,10 @@ type UpdateGroupRequest struct { ImagePrice1K *float64 `json:"image_price_1k"` ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` ClaudeCodeOnly *bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` @@ -179,6 +187,10 @@ func (h *GroupHandler) Create(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, @@ -225,6 +237,10 @@ func (h *GroupHandler) Update(c *gin.Context) { ImagePrice1K: req.ImagePrice1K, ImagePrice2K: req.ImagePrice2K, ImagePrice4K: req.ImagePrice4K, + SoraImagePrice360: req.SoraImagePrice360, + SoraImagePrice540: req.SoraImagePrice540, + SoraVideoPricePerRequest: req.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: req.SoraVideoPricePerRequestHD, ClaudeCodeOnly: req.ClaudeCodeOnly, FallbackGroupID: 
req.FallbackGroupID, FallbackGroupIDOnInvalidRequest: req.FallbackGroupIDOnInvalidRequest, diff --git a/backend/internal/handler/admin/idempotency_helper.go b/backend/internal/handler/admin/idempotency_helper.go new file mode 100644 index 00000000..aa8eeaaf --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper.go @@ -0,0 +1,115 @@ +package admin + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +type idempotencyStoreUnavailableMode int + +const ( + idempotencyStoreUnavailableFailClose idempotencyStoreUnavailableMode = iota + idempotencyStoreUnavailableFailOpen +) + +func executeAdminIdempotent( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) (*service.IdempotencyExecuteResult, error) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + return nil, err + } + return &service.IdempotencyExecuteResult{Data: data}, nil + } + + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + + return coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) +} + +func executeAdminIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, 
scope, payload, ttl, idempotencyStoreUnavailableFailClose, execute) +} + +func executeAdminIdempotentJSONFailOpenOnStoreUnavailable( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + executeAdminIdempotentJSONWithMode(c, scope, payload, ttl, idempotencyStoreUnavailableFailOpen, execute) +} + +func executeAdminIdempotentJSONWithMode( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + mode idempotencyStoreUnavailableMode, + execute func(context.Context) (any, error), +) { + result, err := executeAdminIdempotent(c, scope, payload, ttl, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + strategy := "fail_close" + if mode == idempotencyStoreUnavailableFailOpen { + strategy = "fail_open" + } + service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_"+strategy) + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=%s", c.Request.Method, c.FullPath(), scope, strategy) + if mode == idempotencyStoreUnavailableFailOpen { + data, fallbackErr := execute(c.Request.Context()) + if fallbackErr != nil { + response.ErrorFrom(c, fallbackErr) + return + } + c.Header("X-Idempotency-Degraded", "store-unavailable") + response.Success(c, data) + return + } + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/admin/idempotency_helper_test.go b/backend/internal/handler/admin/idempotency_helper_test.go new file mode 100644 index 00000000..7dd86e16 --- /dev/null +++ b/backend/internal/handler/admin/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package admin + +import ( + 
"bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type storeUnavailableRepoStub struct{} + +func (storeUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (storeUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestExecuteAdminIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.high", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": 
true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed, "fail-close should block business execution when idempotency store is unavailable") +} + +func TestExecuteAdminIdempotentJSONFailOpenOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(storeUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSONFailOpenOnStoreUnavailable(c, "admin.test.medium", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "test-key-2") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "store-unavailable", rec.Header().Get("X-Idempotency-Degraded")) + require.Equal(t, 1, executed, "fail-open strategy should allow semantic idempotent path to continue") +} + +type memoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newMemoryIdempotencyRepoStub() *memoryIdempotencyRepoStub { + return &memoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *memoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash 
+} + +func (r *memoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *memoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *memoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *memoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != 
service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *memoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r *memoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func TestExecuteAdminIdempotentJSONConcurrentRetryOnlyOneSideEffect(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.POST("/idempotent", func(c *gin.Context) { + executeAdminIdempotentJSON(c, "admin.test.concurrent", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(120 * time.Millisecond) + return gin.H{"ok": true}, 
nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + status1, _ = call() + }() + go func() { + defer wg.Done() + status2, _ = call() + }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load(), "same idempotency key should execute side-effect only once") + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index ed86fea9..cf43f89e 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -2,6 +2,7 @@ package admin import ( "strconv" + "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -16,6 +17,13 @@ type OpenAIOAuthHandler struct { adminService service.AdminService } +func oauthPlatformFromPath(c *gin.Context) string { + if strings.Contains(c.FullPath(), "/admin/sora/") { + return service.PlatformSora + } + return service.PlatformOpenAI +} + // NewOpenAIOAuthHandler creates a new OpenAI OAuth handler func NewOpenAIOAuthHandler(openaiOAuthService *service.OpenAIOAuthService, adminService service.AdminService) *OpenAIOAuthHandler { return &OpenAIOAuthHandler{ @@ -52,6 +60,7 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) 
{ type OpenAIExchangeCodeRequest struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` } @@ -68,6 +77,7 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -81,18 +91,29 @@ func (h *OpenAIOAuthHandler) ExchangeCode(c *gin.Context) { // OpenAIRefreshTokenRequest represents the request for refreshing OpenAI token type OpenAIRefreshTokenRequest struct { - RefreshToken string `json:"refresh_token" binding:"required"` + RefreshToken string `json:"refresh_token"` + RT string `json:"rt"` + ClientID string `json:"client_id"` ProxyID *int64 `json:"proxy_id"` } // RefreshToken refreshes an OpenAI OAuth token // POST /api/v1/admin/openai/refresh-token +// POST /api/v1/admin/sora/rt2at func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { var req OpenAIRefreshTokenRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return } + refreshToken := strings.TrimSpace(req.RefreshToken) + if refreshToken == "" { + refreshToken = strings.TrimSpace(req.RT) + } + if refreshToken == "" { + response.BadRequest(c, "refresh_token is required") + return + } var proxyURL string if req.ProxyID != nil { @@ -102,7 +123,7 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshToken(c.Request.Context(), req.RefreshToken, proxyURL) + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, strings.TrimSpace(req.ClientID)) if err != nil { response.ErrorFrom(c, err) return @@ -111,8 +132,39 @@ func (h 
*OpenAIOAuthHandler) RefreshToken(c *gin.Context) { response.Success(c, tokenInfo) } -// RefreshAccountToken refreshes token for a specific OpenAI account +// ExchangeSoraSessionToken exchanges Sora session token to access token +// POST /api/v1/admin/sora/st2at +func (h *OpenAIOAuthHandler) ExchangeSoraSessionToken(c *gin.Context) { + var req struct { + SessionToken string `json:"session_token"` + ST string `json:"st"` + ProxyID *int64 `json:"proxy_id"` + } + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + sessionToken := strings.TrimSpace(req.SessionToken) + if sessionToken == "" { + sessionToken = strings.TrimSpace(req.ST) + } + if sessionToken == "" { + response.BadRequest(c, "session_token is required") + return + } + + tokenInfo, err := h.openaiOAuthService.ExchangeSoraSessionToken(c.Request.Context(), sessionToken, req.ProxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, tokenInfo) +} + +// RefreshAccountToken refreshes token for a specific OpenAI/Sora account // POST /api/v1/admin/openai/accounts/:id/refresh +// POST /api/v1/admin/sora/accounts/:id/refresh func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { accountID, err := strconv.ParseInt(c.Param("id"), 10, 64) if err != nil { @@ -127,9 +179,9 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { return } - // Ensure account is OpenAI platform - if !account.IsOpenAI() { - response.BadRequest(c, "Account is not an OpenAI account") + platform := oauthPlatformFromPath(c) + if account.Platform != platform { + response.BadRequest(c, "Account platform does not match OAuth endpoint") return } @@ -167,12 +219,14 @@ func (h *OpenAIOAuthHandler) RefreshAccountToken(c *gin.Context) { response.Success(c, dto.AccountFromService(updatedAccount)) } -// CreateAccountFromOAuth creates a new OpenAI OAuth account from token info +// CreateAccountFromOAuth creates a new OpenAI/Sora 
OAuth account from token info // POST /api/v1/admin/openai/create-from-oauth +// POST /api/v1/admin/sora/create-from-oauth func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { var req struct { SessionID string `json:"session_id" binding:"required"` Code string `json:"code" binding:"required"` + State string `json:"state" binding:"required"` RedirectURI string `json:"redirect_uri"` ProxyID *int64 `json:"proxy_id"` Name string `json:"name"` @@ -189,6 +243,7 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { tokenInfo, err := h.openaiOAuthService.ExchangeCode(c.Request.Context(), &service.OpenAIExchangeCodeInput{ SessionID: req.SessionID, Code: req.Code, + State: req.State, RedirectURI: req.RedirectURI, ProxyID: req.ProxyID, }) @@ -200,19 +255,25 @@ func (h *OpenAIOAuthHandler) CreateAccountFromOAuth(c *gin.Context) { // Build credentials from token info credentials := h.openaiOAuthService.BuildAccountCredentials(tokenInfo) + platform := oauthPlatformFromPath(c) + // Use email as default name if not provided name := req.Name if name == "" && tokenInfo.Email != "" { name = tokenInfo.Email } if name == "" { - name = "OpenAI OAuth Account" + if platform == service.PlatformSora { + name = "Sora OAuth Account" + } else { + name = "OpenAI OAuth Account" + } } // Create account account, err := h.adminService.CreateAccount(c.Request.Context(), &service.CreateAccountInput{ Name: name, - Platform: "openai", + Platform: platform, Type: "oauth", Credentials: credentials, ProxyID: req.ProxyID, diff --git a/backend/internal/handler/admin/ops_dashboard_handler.go b/backend/internal/handler/admin/ops_dashboard_handler.go index 2c87f734..01f7bc2b 100644 --- a/backend/internal/handler/admin/ops_dashboard_handler.go +++ b/backend/internal/handler/admin/ops_dashboard_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "net/http" "strconv" "strings" @@ -218,6 +219,115 @@ func (h *OpsHandler) GetDashboardErrorDistribution(c *gin.Context) { 
response.Success(c, data) } +// GetDashboardOpenAITokenStats returns OpenAI token efficiency stats grouped by model. +// GET /api/v1/admin/ops/dashboard/openai-token-stats +func (h *OpsHandler) GetDashboardOpenAITokenStats(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + filter, err := parseOpsOpenAITokenStatsFilter(c) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + data, err := h.opsService.GetOpenAITokenStats(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) +} + +func parseOpsOpenAITokenStatsFilter(c *gin.Context) (*service.OpsOpenAITokenStatsFilter, error) { + if c == nil { + return nil, fmt.Errorf("invalid request") + } + + timeRange := strings.TrimSpace(c.Query("time_range")) + if timeRange == "" { + timeRange = "30d" + } + dur, ok := parseOpsOpenAITokenStatsDuration(timeRange) + if !ok { + return nil, fmt.Errorf("invalid time_range") + } + end := time.Now().UTC() + start := end.Add(-dur) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: timeRange, + StartTime: start, + EndTime: end, + Platform: strings.TrimSpace(c.Query("platform")), + } + + if v := strings.TrimSpace(c.Query("group_id")); v != "" { + id, err := strconv.ParseInt(v, 10, 64) + if err != nil || id <= 0 { + return nil, fmt.Errorf("invalid group_id") + } + filter.GroupID = &id + } + + topNRaw := strings.TrimSpace(c.Query("top_n")) + pageRaw := strings.TrimSpace(c.Query("page")) + pageSizeRaw := strings.TrimSpace(c.Query("page_size")) + if topNRaw != "" && (pageRaw != "" || pageSizeRaw != "") { + return nil, fmt.Errorf("invalid query: top_n cannot be used with page/page_size") + } + + if topNRaw != "" { + topN, err := strconv.Atoi(topNRaw) + if err != nil || topN < 1 || topN > 
100 { + return nil, fmt.Errorf("invalid top_n") + } + filter.TopN = topN + return filter, nil + } + + filter.Page = 1 + filter.PageSize = 20 + if pageRaw != "" { + page, err := strconv.Atoi(pageRaw) + if err != nil || page < 1 { + return nil, fmt.Errorf("invalid page") + } + filter.Page = page + } + if pageSizeRaw != "" { + pageSize, err := strconv.Atoi(pageSizeRaw) + if err != nil || pageSize < 1 || pageSize > 100 { + return nil, fmt.Errorf("invalid page_size") + } + filter.PageSize = pageSize + } + return filter, nil +} + +func parseOpsOpenAITokenStatsDuration(v string) (time.Duration, bool) { + switch strings.TrimSpace(v) { + case "30m": + return 30 * time.Minute, true + case "1h": + return time.Hour, true + case "1d": + return 24 * time.Hour, true + case "15d": + return 15 * 24 * time.Hour, true + case "30d": + return 30 * 24 * time.Hour, true + default: + return 0, false + } +} + func pickThroughputBucketSeconds(window time.Duration) int { // Keep buckets predictable and avoid huge responses. 
switch { diff --git a/backend/internal/handler/admin/ops_runtime_logging_handler_test.go b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go new file mode 100644 index 00000000..0e84b4f9 --- /dev/null +++ b/backend/internal/handler/admin/ops_runtime_logging_handler_test.go @@ -0,0 +1,173 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type testSettingRepo struct { + values map[string]string +} + +func newTestSettingRepo() *testSettingRepo { + return &testSettingRepo{values: map[string]string{}} +} + +func (s *testSettingRepo) Get(ctx context.Context, key string) (*service.Setting, error) { + v, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &service.Setting{Key: key, Value: v}, nil +} +func (s *testSettingRepo) GetValue(ctx context.Context, key string) (string, error) { + v, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return v, nil +} +func (s *testSettingRepo) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} +func (s *testSettingRepo) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, k := range keys { + if v, ok := s.values[k]; ok { + out[k] = v + } + } + return out, nil +} +func (s *testSettingRepo) SetMultiple(ctx context.Context, settings map[string]string) error { + for k, v := range settings { + s.values[k] = v + } + return nil +} +func (s *testSettingRepo) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for k, v := range s.values { + out[k] = v + } + return out, nil +} +func (s 
*testSettingRepo) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +func newOpsRuntimeRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 7}) + c.Next() + }) + } + r.GET("/runtime/logging", handler.GetRuntimeLogConfig) + r.PUT("/runtime/logging", handler.UpdateRuntimeLogConfig) + r.POST("/runtime/logging/reset", handler.ResetRuntimeLogConfig) + return r +} + +func newRuntimeOpsService(t *testing.T) *service.OpsService { + t.Helper() + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + settingRepo := newTestSettingRepo() + cfg := &config.Config{ + Ops: config.OpsConfig{Enabled: true}, + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + } + return service.NewOpsService(nil, settingRepo, cfg, nil, nil, nil, nil, nil, nil, nil, nil) +} + +func TestOpsRuntimeLoggingHandler_GetConfig(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/runtime/logging", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateUnauthorized(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, false) + + body := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + w := 
httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsRuntimeLoggingHandler_UpdateAndResetSuccess(t *testing.T) { + h := NewOpsHandler(newRuntimeOpsService(t)) + r := newOpsRuntimeRouter(h, true) + + payload := map[string]any{ + "level": "debug", + "enable_sampling": false, + "sampling_initial": 100, + "sampling_thereafter": 100, + "caller": true, + "stacktrace_level": "error", + "retention_days": 30, + } + raw, _ := json.Marshal(payload) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/runtime/logging", bytes.NewReader(raw)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("update status=%d, want 200, body=%s", w.Code, w.Body.String()) + } + + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/runtime/logging/reset", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("reset status=%d, want 200, body=%s", w.Code, w.Body.String()) + } +} diff --git a/backend/internal/handler/admin/ops_settings_handler.go b/backend/internal/handler/admin/ops_settings_handler.go index ebc8bf49..226b89f3 100644 --- a/backend/internal/handler/admin/ops_settings_handler.go +++ b/backend/internal/handler/admin/ops_settings_handler.go @@ -4,6 +4,7 @@ import ( "net/http" "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) @@ -101,6 +102,84 @@ func (h *OpsHandler) UpdateAlertRuntimeSettings(c *gin.Context) { response.Success(c, updated) } +// GetRuntimeLogConfig returns runtime log config (DB-backed). 
+// GET /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) GetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + cfg, err := h.opsService.GetRuntimeLogConfig(c.Request.Context()) + if err != nil { + response.Error(c, http.StatusInternalServerError, "Failed to get runtime log config") + return + } + response.Success(c, cfg) +} + +// UpdateRuntimeLogConfig updates runtime log config and applies changes immediately. +// PUT /api/v1/admin/ops/runtime/logging +func (h *OpsHandler) UpdateRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + var req service.OpsRuntimeLogConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.UpdateRuntimeLogConfig(c.Request.Context(), &req, subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + +// ResetRuntimeLogConfig removes runtime override and falls back to env/yaml baseline. 
+// POST /api/v1/admin/ops/runtime/logging/reset +func (h *OpsHandler) ResetRuntimeLogConfig(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + updated, err := h.opsService.ResetRuntimeLogConfig(c.Request.Context(), subject.UserID) + if err != nil { + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + response.Success(c, updated) +} + // GetAdvancedSettings returns Ops advanced settings (DB-backed). // GET /api/v1/admin/ops/advanced-settings func (h *OpsHandler) GetAdvancedSettings(c *gin.Context) { diff --git a/backend/internal/handler/admin/ops_system_log_handler.go b/backend/internal/handler/admin/ops_system_log_handler.go new file mode 100644 index 00000000..31fd51eb --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler.go @@ -0,0 +1,174 @@ +package admin + +import ( + "net/http" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type opsSystemLogCleanupRequest struct { + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` + + Level string `json:"level"` + Component string `json:"component"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string `json:"model"` + Query string `json:"q"` +} + +// ListSystemLogs returns indexed system logs. 
+// GET /api/v1/admin/ops/system-logs +func (h *OpsHandler) ListSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + page, pageSize := response.ParsePagination(c) + if pageSize > 200 { + pageSize = 200 + } + + start, end, err := parseOpsTimeRange(c, "1h") + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + filter := &service.OpsSystemLogFilter{ + Page: page, + PageSize: pageSize, + StartTime: &start, + EndTime: &end, + Level: strings.TrimSpace(c.Query("level")), + Component: strings.TrimSpace(c.Query("component")), + RequestID: strings.TrimSpace(c.Query("request_id")), + ClientRequestID: strings.TrimSpace(c.Query("client_request_id")), + Platform: strings.TrimSpace(c.Query("platform")), + Model: strings.TrimSpace(c.Query("model")), + Query: strings.TrimSpace(c.Query("q")), + } + if v := strings.TrimSpace(c.Query("user_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid user_id") + return + } + filter.UserID = &id + } + if v := strings.TrimSpace(c.Query("account_id")); v != "" { + id, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr != nil || id <= 0 { + response.BadRequest(c, "Invalid account_id") + return + } + filter.AccountID = &id + } + + result, err := h.opsService.ListSystemLogs(c.Request.Context(), filter) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Paginated(c, result.Logs, int64(result.Total), result.Page, result.PageSize) +} + +// CleanupSystemLogs deletes indexed system logs by filter. 
+// POST /api/v1/admin/ops/system-logs/cleanup +func (h *OpsHandler) CleanupSystemLogs(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Error(c, http.StatusUnauthorized, "Unauthorized") + return + } + + var req opsSystemLogCleanupRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request body") + return + } + + parseTS := func(raw string) (*time.Time, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, nil + } + if t, err := time.Parse(time.RFC3339Nano, raw); err == nil { + return &t, nil + } + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return nil, err + } + return &t, nil + } + start, err := parseTS(req.StartTime) + if err != nil { + response.BadRequest(c, "Invalid start_time") + return + } + end, err := parseTS(req.EndTime) + if err != nil { + response.BadRequest(c, "Invalid end_time") + return + } + + filter := &service.OpsSystemLogCleanupFilter{ + StartTime: start, + EndTime: end, + Level: strings.TrimSpace(req.Level), + Component: strings.TrimSpace(req.Component), + RequestID: strings.TrimSpace(req.RequestID), + ClientRequestID: strings.TrimSpace(req.ClientRequestID), + UserID: req.UserID, + AccountID: req.AccountID, + Platform: strings.TrimSpace(req.Platform), + Model: strings.TrimSpace(req.Model), + Query: strings.TrimSpace(req.Query), + } + + deleted, err := h.opsService.CleanupSystemLogs(c.Request.Context(), filter, subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": deleted}) +} + +// GetSystemLogIngestionHealth returns sink health metrics. 
+// GET /api/v1/admin/ops/system-logs/health +func (h *OpsHandler) GetSystemLogIngestionHealth(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, h.opsService.GetSystemLogSinkHealth()) +} diff --git a/backend/internal/handler/admin/ops_system_log_handler_test.go b/backend/internal/handler/admin/ops_system_log_handler_test.go new file mode 100644 index 00000000..7528acd8 --- /dev/null +++ b/backend/internal/handler/admin/ops_system_log_handler_test.go @@ -0,0 +1,233 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type responseEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` +} + +func newOpsSystemLogTestRouter(handler *OpsHandler, withUser bool) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + if withUser { + r.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 99}) + c.Next() + }) + } + r.GET("/logs", handler.ListSystemLogs) + r.POST("/logs/cleanup", handler.CleanupSystemLogs) + r.GET("/logs/health", handler.GetSystemLogIngestionHealth) + return r +} + +func TestOpsSystemLogHandler_ListUnavailable(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidUserID(t *testing.T) { + svc := 
service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?user_id=abc", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListInvalidAccountID(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?account_id=-1", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_ListMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_ListSuccess(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs?time_range=30m&page=1&page_size=20", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } + + var resp responseEnvelope + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if resp.Code != 0 { + t.Fatalf("unexpected response code: %+v", resp) + } +} + +func TestOpsSystemLogHandler_CleanupUnauthorized(t 
*testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusUnauthorized { + t.Fatalf("status=%d, want 401", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidPayload(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{bad-json`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"start_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupInvalidEndTime(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"end_time":"bad","request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if 
w.Code != http.StatusBadRequest { + t.Fatalf("status=%d, want 400", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupServiceUnavailable(t *testing.T) { + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want 503", w.Code) + } +} + +func TestOpsSystemLogHandler_CleanupMonitoringDisabled(t *testing.T) { + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, true) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/logs/cleanup", bytes.NewBufferString(`{"request_id":"r1"}`)) + req.Header.Set("Content-Type", "application/json") + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} + +func TestOpsSystemLogHandler_Health(t *testing.T) { + sink := service.NewOpsSystemLogSink(nil) + svc := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + h := NewOpsHandler(svc) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want 200", w.Code) + } +} + +func TestOpsSystemLogHandler_HealthUnavailableAndMonitoringDisabled(t *testing.T) { + h := NewOpsHandler(nil) + r := newOpsSystemLogTestRouter(h, false) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + 
t.Fatalf("status=%d, want 503", w.Code) + } + + svc := service.NewOpsService(nil, nil, &config.Config{ + Ops: config.OpsConfig{Enabled: false}, + }, nil, nil, nil, nil, nil, nil, nil, nil) + h = NewOpsHandler(svc) + r = newOpsSystemLogTestRouter(h, false) + w = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/logs/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusNotFound { + t.Fatalf("status=%d, want 404", w.Code) + } +} diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go index db7442e5..c030d303 100644 --- a/backend/internal/handler/admin/ops_ws_handler.go +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -3,7 +3,6 @@ package admin import ( "context" "encoding/json" - "log" "math" "net" "net/http" @@ -16,6 +15,7 @@ import ( "sync/atomic" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/gorilla/websocket" @@ -252,7 +252,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { stats, err := opsService.GetWindowStats(ctx, now.Add(-c.requestCountWindow), now) if err != nil || stats == nil { if err != nil { - log.Printf("[OpsWS] refresh: get window stats failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: get window stats failed: %v", err) } return } @@ -278,7 +278,7 @@ func (c *opsWSQPSCache) refresh(parentCtx context.Context) { msg, err := json.Marshal(payload) if err != nil { - log.Printf("[OpsWS] refresh: marshal payload failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] refresh: marshal payload failed: %v", err) return } @@ -338,7 +338,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { // Reserve a global slot before upgrading the connection to keep the limit strict. 
if !tryAcquireOpsWSTotalSlot(opsWSLimits.MaxConns) { - log.Printf("[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] connection limit reached: %d/%d", wsConnCount.Load(), opsWSLimits.MaxConns) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -350,7 +350,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { if opsWSLimits.MaxConnsPerIP > 0 && clientIP != "" { if !tryAcquireOpsWSIPSlot(clientIP, opsWSLimits.MaxConnsPerIP) { - log.Printf("[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] per-ip connection limit reached: ip=%s limit=%d", clientIP, opsWSLimits.MaxConnsPerIP) c.JSON(http.StatusServiceUnavailable, gin.H{"error": "too many connections"}) return } @@ -359,7 +359,7 @@ func (h *OpsHandler) QPSWSHandler(c *gin.Context) { conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - log.Printf("[OpsWS] upgrade failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] upgrade failed: %v", err) return } @@ -452,7 +452,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { conn.SetReadLimit(qpsWSMaxReadBytes) if err := conn.SetReadDeadline(time.Now().Add(qpsWSPongWait)); err != nil { - log.Printf("[OpsWS] set read deadline failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] set read deadline failed: %v", err) return } conn.SetPongHandler(func(string) error { @@ -471,7 +471,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { _, _, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { - log.Printf("[OpsWS] read failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] read failed: %v", err) } return } @@ 
-508,7 +508,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { continue } if err := writeWithTimeout(websocket.TextMessage, msg); err != nil { - log.Printf("[OpsWS] write failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] write failed: %v", err) cancel() closeConn() wg.Wait() @@ -517,7 +517,7 @@ func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { case <-pingTicker.C: if err := writeWithTimeout(websocket.PingMessage, nil); err != nil { - log.Printf("[OpsWS] ping failed: %v", err) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] ping failed: %v", err) cancel() closeConn() wg.Wait() @@ -666,14 +666,14 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { if parsed, err := strconv.ParseBool(v); err == nil { cfg.TrustProxy = parsed } else { - log.Printf("[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected bool); using default=%v", envOpsWSTrustProxy, v, cfg.TrustProxy) } } if raw := strings.TrimSpace(os.Getenv(envOpsWSTrustedProxies)); raw != "" { prefixes, invalid := parseTrustedProxyList(raw) if len(invalid) > 0 { - log.Printf("[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s entries ignored: %s", envOpsWSTrustedProxies, strings.Join(invalid, ", ")) } cfg.TrustedProxies = prefixes } @@ -684,7 +684,7 @@ func loadOpsWSProxyConfigFromEnv() OpsWSProxyConfig { case OriginPolicyStrict, OriginPolicyPermissive: cfg.OriginPolicy = normalized default: - log.Printf("[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, OriginPolicyPermissive, cfg.OriginPolicy) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected %q or %q); using default=%q", envOpsWSOriginPolicy, v, OriginPolicyStrict, 
OriginPolicyPermissive, cfg.OriginPolicy) } } @@ -701,14 +701,14 @@ func loadOpsWSRuntimeLimitsFromEnv() opsWSRuntimeLimits { if parsed, err := strconv.Atoi(v); err == nil && parsed > 0 { cfg.MaxConns = int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>0); using default=%d", envOpsWSMaxConns, v, cfg.MaxConns) } } if v := strings.TrimSpace(os.Getenv(envOpsWSMaxConnsPerIP)); v != "" { if parsed, err := strconv.Atoi(v); err == nil && parsed >= 0 { cfg.MaxConnsPerIP = int32(parsed) } else { - log.Printf("[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) + logger.LegacyPrintf("handler.admin.ops_ws", "[OpsWS] invalid %s=%q (expected int>=0); using default=%d", envOpsWSMaxConnsPerIP, v, cfg.MaxConnsPerIP) } } return cfg diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index a6758f69..9fd187fc 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "strings" @@ -130,20 +131,20 @@ func (h *ProxyHandler) Create(c *gin.Context) { return } - proxy, err := h.adminService.CreateProxy(c.Request.Context(), &service.CreateProxyInput{ - Name: strings.TrimSpace(req.Name), - Protocol: strings.TrimSpace(req.Protocol), - Host: strings.TrimSpace(req.Host), - Port: req.Port, - Username: strings.TrimSpace(req.Username), - Password: strings.TrimSpace(req.Password), + executeAdminIdempotentJSON(c, "admin.proxies.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + proxy, err := h.adminService.CreateProxy(ctx, &service.CreateProxyInput{ + Name: strings.TrimSpace(req.Name), + Protocol: strings.TrimSpace(req.Protocol), + Host: strings.TrimSpace(req.Host), + 
Port: req.Port, + Username: strings.TrimSpace(req.Username), + Password: strings.TrimSpace(req.Password), + }) + if err != nil { + return nil, err + } + return dto.ProxyFromService(proxy), nil }) - if err != nil { - response.ErrorFrom(c, err) - return - } - - response.Success(c, dto.ProxyFromService(proxy)) } // Update handles updating a proxy @@ -236,6 +237,24 @@ func (h *ProxyHandler) Test(c *gin.Context) { response.Success(c, result) } +// CheckQuality handles checking proxy quality across common AI targets. +// POST /api/v1/admin/proxies/:id/quality-check +func (h *ProxyHandler) CheckQuality(c *gin.Context) { + proxyID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid proxy ID") + return + } + + result, err := h.adminService.CheckProxyQuality(c.Request.Context(), proxyID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, result) +} + // GetStats handles getting proxy statistics // GET /api/v1/admin/proxies/:id/stats func (h *ProxyHandler) GetStats(c *gin.Context) { diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 02752fea..7073061d 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -2,6 +2,7 @@ package admin import ( "bytes" + "context" "encoding/csv" "fmt" "strconv" @@ -88,23 +89,24 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } - codes, err := h.adminService.GenerateRedeemCodes(c.Request.Context(), &service.GenerateRedeemCodesInput{ - Count: req.Count, - Type: req.Type, - Value: req.Value, - GroupID: req.GroupID, - ValidityDays: req.ValidityDays, - }) - if err != nil { - response.ErrorFrom(c, err) - return - } + executeAdminIdempotentJSON(c, "admin.redeem_codes.generate", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + codes, execErr := h.adminService.GenerateRedeemCodes(ctx, 
&service.GenerateRedeemCodesInput{ + Count: req.Count, + Type: req.Type, + Value: req.Value, + GroupID: req.GroupID, + ValidityDays: req.ValidityDays, + }) + if execErr != nil { + return nil, execErr + } - out := make([]dto.AdminRedeemCode, 0, len(codes)) - for i := range codes { - out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) - } - response.Success(c, out) + out := make([]dto.AdminRedeemCode, 0, len(codes)) + for i := range codes { + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) + } + return out, nil + }) } // Delete handles deleting a redeem code diff --git a/backend/internal/handler/admin/search_truncate_test.go b/backend/internal/handler/admin/search_truncate_test.go new file mode 100644 index 00000000..ffd60e2a --- /dev/null +++ b/backend/internal/handler/admin/search_truncate_test.go @@ -0,0 +1,97 @@ +//go:build unit + +package admin + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// truncateSearchByRune 模拟 user_handler.go 中的 search 截断逻辑 +func truncateSearchByRune(search string, maxRunes int) string { + if runes := []rune(search); len(runes) > maxRunes { + return string(runes[:maxRunes]) + } + return search +} + +func TestTruncateSearchByRune(t *testing.T) { + tests := []struct { + name string + input string + maxRunes int + wantLen int // 期望的 rune 长度 + }{ + { + name: "纯中文超长", + input: string(make([]rune, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "纯 ASCII 超长", + input: string(make([]byte, 150)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "空字符串", + input: "", + maxRunes: 100, + wantLen: 0, + }, + { + name: "恰好 100 个字符", + input: string(make([]rune, 100)), + maxRunes: 100, + wantLen: 100, + }, + { + name: "不足 100 字符不截断", + input: "hello世界", + maxRunes: 100, + wantLen: 7, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := truncateSearchByRune(tc.input, tc.maxRunes) + require.Equal(t, tc.wantLen, len([]rune(result))) + }) + } +} + +func 
TestTruncateSearchByRune_PreservesMultibyte(t *testing.T) { + // 101 个中文字符,截断到 100 个后应该仍然是有效 UTF-8 + input := "" + for i := 0; i < 101; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + require.Equal(t, 100, len([]rune(result))) + // 验证截断结果是有效的 UTF-8(每个中文字符 3 字节) + require.Equal(t, 300, len(result)) +} + +func TestTruncateSearchByRune_MixedASCIIAndMultibyte(t *testing.T) { + // 50 个 ASCII + 51 个中文 = 101 个 rune + input := "" + for i := 0; i < 50; i++ { + input += "a" + } + for i := 0; i < 51; i++ { + input += "中" + } + result := truncateSearchByRune(input, 100) + + runes := []rune(result) + require.Equal(t, 100, len(runes)) + // 前 50 个应该是 'a',后 50 个应该是 '中' + require.Equal(t, 'a', runes[0]) + require.Equal(t, 'a', runes[49]) + require.Equal(t, '中', runes[50]) + require.Equal(t, '中', runes[99]) +} diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 51995ab1..e5b6db13 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -1,6 +1,7 @@ package admin import ( + "context" "strconv" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -199,13 +200,20 @@ func (h *SubscriptionHandler) Extend(c *gin.Context) { return } - subscription, err := h.subscriptionService.ExtendSubscription(c.Request.Context(), subscriptionID, req.Days) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + SubscriptionID int64 `json:"subscription_id"` + Body AdjustSubscriptionRequest `json:"body"` + }{ + SubscriptionID: subscriptionID, + Body: req, } - - response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) + executeAdminIdempotentJSON(c, "admin.subscriptions.extend", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + subscription, execErr := h.subscriptionService.ExtendSubscription(ctx, subscriptionID, req.Days) + if 
execErr != nil { + return nil, execErr + } + return dto.UserSubscriptionFromServiceAdmin(subscription), nil + }) } // Revoke handles revoking a subscription diff --git a/backend/internal/handler/admin/system_handler.go b/backend/internal/handler/admin/system_handler.go index 28c075aa..3e2022c7 100644 --- a/backend/internal/handler/admin/system_handler.go +++ b/backend/internal/handler/admin/system_handler.go @@ -1,11 +1,15 @@ package admin import ( + "context" "net/http" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/sysutil" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -14,12 +18,14 @@ import ( // SystemHandler handles system-related operations type SystemHandler struct { updateSvc *service.UpdateService + lockSvc *service.SystemOperationLockService } // NewSystemHandler creates a new SystemHandler -func NewSystemHandler(updateSvc *service.UpdateService) *SystemHandler { +func NewSystemHandler(updateSvc *service.UpdateService, lockSvc *service.SystemOperationLockService) *SystemHandler { return &SystemHandler{ updateSvc: updateSvc, + lockSvc: lockSvc, } } @@ -47,41 +53,125 @@ func (h *SystemHandler) CheckUpdates(c *gin.Context) { // PerformUpdate downloads and applies the update // POST /api/v1/admin/system/update func (h *SystemHandler) PerformUpdate(c *gin.Context) { - if err := h.updateSvc.PerformUpdate(c.Request.Context()); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Update completed. 
Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "update") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.update", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.PerformUpdate(ctx); err != nil { + releaseReason = "SYSTEM_UPDATE_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Update completed. Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // Rollback restores the previous version // POST /api/v1/admin/system/rollback func (h *SystemHandler) Rollback(c *gin.Context) { - if err := h.updateSvc.Rollback(); err != nil { - response.Error(c, http.StatusInternalServerError, err.Error()) - return - } - response.Success(c, gin.H{ - "message": "Rollback completed. Please restart the service.", - "need_restart": true, + operationID := buildSystemOperationID(c, "rollback") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.rollback", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + var releaseReason string + succeeded := false + defer func() { + release(releaseReason, succeeded) + }() + + if err := h.updateSvc.Rollback(); err != nil { + releaseReason = "SYSTEM_ROLLBACK_FAILED" + return nil, err + } + succeeded = true + + return gin.H{ + "message": "Rollback completed. 
Please restart the service.", + "need_restart": true, + "operation_id": lock.OperationID(), + }, nil }) } // RestartService restarts the systemd service // POST /api/v1/admin/system/restart func (h *SystemHandler) RestartService(c *gin.Context) { - // Schedule service restart in background after sending response - // This ensures the client receives the success response before the service restarts - go func() { - // Wait a moment to ensure the response is sent - time.Sleep(500 * time.Millisecond) - sysutil.RestartServiceAsync() - }() + operationID := buildSystemOperationID(c, "restart") + payload := gin.H{"operation_id": operationID} + executeAdminIdempotentJSON(c, "admin.system.restart", payload, service.DefaultSystemOperationIdempotencyTTL(), func(ctx context.Context) (any, error) { + lock, release, err := h.acquireSystemLock(ctx, operationID) + if err != nil { + return nil, err + } + succeeded := false + defer func() { + release("", succeeded) + }() - response.Success(c, gin.H{ - "message": "Service restart initiated", + // Schedule service restart in background after sending response + // This ensures the client receives the success response before the service restarts + go func() { + // Wait a moment to ensure the response is sent + time.Sleep(500 * time.Millisecond) + sysutil.RestartServiceAsync() + }() + succeeded = true + return gin.H{ + "message": "Service restart initiated", + "operation_id": lock.OperationID(), + }, nil }) } + +func (h *SystemHandler) acquireSystemLock( + ctx context.Context, + operationID string, +) (*service.SystemOperationLock, func(string, bool), error) { + if h.lockSvc == nil { + return nil, nil, service.ErrIdempotencyStoreUnavail + } + lock, err := h.lockSvc.Acquire(ctx, operationID) + if err != nil { + return nil, nil, err + } + release := func(reason string, succeeded bool) { + releaseCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _ = h.lockSvc.Release(releaseCtx, lock, succeeded, 
reason) + } + return lock, release, nil +} + +func buildSystemOperationID(c *gin.Context, operation string) string { + key := strings.TrimSpace(c.GetHeader("Idempotency-Key")) + if key == "" { + return "sysop-" + operation + "-" + strconv.FormatInt(time.Now().UnixNano(), 36) + } + actorScope := "admin:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + seed := operation + "|" + actorScope + "|" + c.FullPath() + "|" + key + hash := service.HashIdempotencyKey(seed) + if len(hash) > 24 { + hash = hash[:24] + } + return "sysop-" + hash +} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 3f3238dd..5cbf18e6 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,13 +1,14 @@ package admin import ( - "log" + "context" "net/http" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" @@ -378,11 +379,11 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { operator = subject.UserID } page, pageSize := response.ParsePagination(c) - log.Printf("[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求清理任务列表: operator=%d page=%d page_size=%d", operator, page, pageSize) params := pagination.PaginationParams{Page: page, PageSize: pageSize} tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) if err != nil { - log.Printf("[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 查询清理任务列表失败: operator=%d page=%d 
page_size=%d err=%v", operator, page, pageSize, err) response.ErrorFrom(c, err) return } @@ -390,7 +391,7 @@ func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { for i := range tasks { out = append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) } - log.Printf("[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 返回清理任务列表: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) response.Paginated(c, out, result.Total, page, pageSize) } @@ -472,29 +473,36 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { billingType = *filters.BillingType } - log.Printf("[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", - subject.UserID, - filters.StartTime.Format(time.RFC3339), - filters.EndTime.Format(time.RFC3339), - userID, - apiKeyID, - accountID, - groupID, - model, - stream, - billingType, - req.Timezone, - ) - - task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) - if err != nil { - log.Printf("[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + OperatorID int64 `json:"operator_id"` + Body CreateUsageCleanupTaskRequest `json:"body"` + }{ + OperatorID: subject.UserID, + Body: req, } + executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + 
accountID, + groupID, + model, + stream, + billingType, + req.Timezone, + ) - log.Printf("[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) - response.Success(c, dto.UsageCleanupTaskFromService(task)) + task, err := h.cleanupService.CreateTask(ctx, filters, subject.UserID) + if err != nil { + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 创建清理任务失败: operator=%d err=%v", subject.UserID, err) + return nil, err + } + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已创建: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + return dto.UsageCleanupTaskFromService(task), nil + }) } // CancelCleanupTask handles canceling a usage cleanup task @@ -515,12 +523,12 @@ func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { response.BadRequest(c, "Invalid task id") return } - log.Printf("[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求取消清理任务: task=%d operator=%d", taskID, subject.UserID) if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { - log.Printf("[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 取消清理任务失败: task=%d operator=%d err=%v", taskID, subject.UserID, err) response.ErrorFrom(c, err) return } - log.Printf("[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 清理任务已取消: task=%d operator=%d", taskID, subject.UserID) response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) } diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index efb9abb5..d85202e5 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -1,6 +1,7 @@ package admin import ( + 
"context" "strconv" "strings" @@ -78,8 +79,8 @@ func (h *UserHandler) List(c *gin.Context) { search := c.Query("search") // 标准化和验证 search 参数 search = strings.TrimSpace(search) - if len(search) > 100 { - search = search[:100] + if runes := []rune(search); len(runes) > 100 { + search = string(runes[:100]) } filters := service.UserListFilters{ @@ -257,13 +258,20 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { return } - user, err := h.adminService.UpdateUserBalance(c.Request.Context(), userID, req.Balance, req.Operation, req.Notes) - if err != nil { - response.ErrorFrom(c, err) - return + idempotencyPayload := struct { + UserID int64 `json:"user_id"` + Body UpdateBalanceRequest `json:"body"` + }{ + UserID: userID, + Body: req, } - - response.Success(c, dto.UserFromServiceAdmin(user)) + executeAdminIdempotentJSON(c, "admin.users.balance.update", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + user, execErr := h.adminService.UpdateUserBalance(ctx, userID, req.Balance, req.Operation, req.Notes) + if execErr != nil { + return nil, execErr + } + return dto.UserFromServiceAdmin(user), nil + }) } // GetUserAPIKeys handles getting user's API keys diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index f1a18ad2..61762744 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -2,6 +2,7 @@ package handler import ( + "context" "strconv" "time" @@ -130,13 +131,14 @@ func (h *APIKeyHandler) Create(c *gin.Context) { if req.Quota != nil { svcReq.Quota = *req.Quota } - key, err := h.apiKeyService.Create(c.Request.Context(), subject.UserID, svcReq) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, dto.APIKeyFromService(key)) + executeUserIdempotentJSON(c, "user.api_keys.create", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { + key, err := 
h.apiKeyService.Create(ctx, subject.UserID, svcReq) + if err != nil { + return nil, err + } + return dto.APIKeyFromService(key), nil + }) } // Update handles updating an API key diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 34ed63bc..e0078e14 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -2,6 +2,7 @@ package handler import ( "log/slog" + "strings" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -112,12 +113,11 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证(当提供了邮箱验证码时跳过,因为发送验证码时已验证过) - if req.VerifyCode == "" { - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { - response.ErrorFrom(c, err) - return - } + // Turnstile 验证 — 始终执行,防止绕过 + // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token + if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + response.ErrorFrom(c, err) + return } _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) @@ -448,17 +448,12 @@ func (h *AuthHandler) ForgotPassword(c *gin.Context) { return } - // Build frontend base URL from request - scheme := "https" - if c.Request.TLS == nil { - // Check X-Forwarded-Proto header (common in reverse proxy setups) - if proto := c.GetHeader("X-Forwarded-Proto"); proto != "" { - scheme = proto - } else { - scheme = "http" - } + frontendBaseURL := strings.TrimSpace(h.cfg.Server.FrontendURL) + if frontendBaseURL == "" { + slog.Error("server.frontend_url not configured; cannot build password reset link") + response.InternalError(c, "Password reset is not configured") + return } - frontendBaseURL := scheme + "://" + c.Request.Host // Request password reset (async) // Note: This returns success even if email 
doesn't exist (to prevent enumeration) diff --git a/backend/internal/handler/dto/api_key_mapper_last_used_test.go b/backend/internal/handler/dto/api_key_mapper_last_used_test.go new file mode 100644 index 00000000..99644ced --- /dev/null +++ b/backend/internal/handler/dto/api_key_mapper_last_used_test.go @@ -0,0 +1,40 @@ +package dto + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestAPIKeyFromService_MapsLastUsedAt(t *testing.T) { + lastUsed := time.Now().UTC().Truncate(time.Second) + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used", + Name: "Mapper", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.NotNil(t, out.LastUsedAt) + require.WithinDuration(t, lastUsed, *out.LastUsedAt, time.Second) +} + +func TestAPIKeyFromService_MapsNilLastUsedAt(t *testing.T) { + src := &service.APIKey{ + ID: 1, + UserID: 2, + Key: "sk-map-last-used-nil", + Name: "MapperNil", + Status: service.StatusActive, + } + + out := APIKeyFromService(src) + require.NotNil(t, out) + require.Nil(t, out.LastUsedAt) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index eee5910e..42ff4a84 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -2,6 +2,7 @@ package dto import ( + "strconv" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -77,6 +78,7 @@ func APIKeyFromService(k *service.APIKey) *APIKey { Status: k.Status, IPWhitelist: k.IPWhitelist, IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, Quota: k.Quota, QuotaUsed: k.QuotaUsed, ExpiresAt: k.ExpiresAt, @@ -129,23 +131,26 @@ func GroupFromServiceAdmin(g *service.Group) *AdminGroup { func groupFromServiceBase(g *service.Group) Group { return Group{ - ID: g.ID, - Name: g.Name, - Description: g.Description, - Platform: g.Platform, - RateMultiplier: 
g.RateMultiplier, - IsExclusive: g.IsExclusive, - Status: g.Status, - SubscriptionType: g.SubscriptionType, - DailyLimitUSD: g.DailyLimitUSD, - WeeklyLimitUSD: g.WeeklyLimitUSD, - MonthlyLimitUSD: g.MonthlyLimitUSD, - ImagePrice1K: g.ImagePrice1K, - ImagePrice2K: g.ImagePrice2K, - ImagePrice4K: g.ImagePrice4K, - ClaudeCodeOnly: g.ClaudeCodeOnly, - FallbackGroupID: g.FallbackGroupID, - // 无效请求兜底分组 + ID: g.ID, + Name: g.Name, + Description: g.Description, + Platform: g.Platform, + RateMultiplier: g.RateMultiplier, + IsExclusive: g.IsExclusive, + Status: g.Status, + SubscriptionType: g.SubscriptionType, + DailyLimitUSD: g.DailyLimitUSD, + WeeklyLimitUSD: g.WeeklyLimitUSD, + MonthlyLimitUSD: g.MonthlyLimitUSD, + ImagePrice1K: g.ImagePrice1K, + ImagePrice2K: g.ImagePrice2K, + ImagePrice4K: g.ImagePrice4K, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHD, + ClaudeCodeOnly: g.ClaudeCodeOnly, + FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, @@ -300,6 +305,11 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi CountryCode: p.CountryCode, Region: p.Region, City: p.City, + QualityStatus: p.QualityStatus, + QualityScore: p.QualityScore, + QualityGrade: p.QualityGrade, + QualitySummary: p.QualitySummary, + QualityChecked: p.QualityChecked, } } @@ -404,6 +414,7 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, ImageSize: l.ImageSize, + MediaType: l.MediaType, UserAgent: l.UserAgent, CacheTTLOverridden: l.CacheTTLOverridden, CreatedAt: l.CreatedAt, @@ -532,11 +543,18 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult for i := range r.Subscriptions { subs = append(subs, 
*UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) } + statuses := make(map[string]string, len(r.Statuses)) + for userID, status := range r.Statuses { + statuses[strconv.FormatInt(userID, 10)] = status + } return &BulkAssignResult{ SuccessCount: r.SuccessCount, + CreatedCount: r.CreatedCount, + ReusedCount: r.ReusedCount, FailedCount: r.FailedCount, Subscriptions: subs, Errors: r.Errors, + Statuses: statuses, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 0253caf7..0cd1b241 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -38,6 +38,7 @@ type APIKey struct { Status string `json:"status"` IPWhitelist []string `json:"ip_whitelist"` IPBlacklist []string `json:"ip_blacklist"` + LastUsedAt *time.Time `json:"last_used_at"` Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited) QuotaUsed float64 `json:"quota_used"` // Used quota amount in USD ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = never expires) @@ -67,6 +68,12 @@ type Group struct { ImagePrice2K *float64 `json:"image_price_2k"` ImagePrice4K *float64 `json:"image_price_4k"` + // Sora 按次计费配置 + SoraImagePrice360 *float64 `json:"sora_image_price_360"` + SoraImagePrice540 *float64 `json:"sora_image_price_540"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request"` + SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd"` + // Claude Code 客户端限制 ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id"` @@ -196,6 +203,11 @@ type ProxyWithAccountCount struct { CountryCode string `json:"country_code,omitempty"` Region string `json:"region,omitempty"` City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int `json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + 
QualityChecked *int64 `json:"quality_checked,omitempty"` } type ProxyAccountSummary struct { @@ -274,6 +286,7 @@ type UsageLog struct { // 图片生成字段 ImageCount int `json:"image_count"` ImageSize *string `json:"image_size"` + MediaType *string `json:"media_type"` // User-Agent UserAgent *string `json:"user_agent"` @@ -382,9 +395,12 @@ type AdminUserSubscription struct { type BulkAssignResult struct { SuccessCount int `json:"success_count"` + CreatedCount int `json:"created_count"` + ReusedCount int `json:"reused_count"` FailedCount int `json:"failed_count"` Subscriptions []AdminUserSubscription `json:"subscriptions"` Errors []string `json:"errors"` + Statuses map[string]string `json:"statuses,omitempty"` } // PromoCode 注册优惠码 diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index c2b6bf09..4b32969f 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -19,11 +19,13 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // GatewayHandler handles API gateway requests @@ -35,10 +37,12 @@ type GatewayHandler struct { billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int + cfg *config.Config } // NewGatewayHandler creates a new GatewayHandler @@ -51,6 +55,7 @@ func NewGatewayHandler( billingCacheService 
*service.BillingCacheService, usageService *service.UsageService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *GatewayHandler { @@ -74,10 +79,12 @@ func NewGatewayHandler( billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, + cfg: cfg, } } @@ -96,6 +103,13 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.messages", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // 读取请求体 body, err := io.ReadAll(c.Request.Body) @@ -122,6 +136,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } reqModel := parsedReq.Model reqStream := parsedReq.Stream + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 @@ -161,9 +176,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gateway.user_wait_counter_increment_failed", zap.Error(err)) // On error, allow request to proceed } else if !canWait { + reqLog.Info("gateway.user_wait_queue_full", zap.Int("max_wait", maxWait)) h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many 
pending requests, please retry later") return } @@ -180,7 +196,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 1. 首先获取用户并发槽位 userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) if err != nil { - log.Printf("User concurrency acquire failed: %v", err) + reqLog.Warn("gateway.user_slot_acquire_failed", zap.Error(err)) h.handleConcurrencyError(c, err, "user", streamStarted) return } @@ -197,7 +213,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return @@ -227,6 +243,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + c.Request = c.Request.WithContext(ctx) + } } // 判断是否真的绑定了粘性会话:有 sessionKey 且已经绑定到某个账号 hasBoundSession := sessionKey != "" && sessionBoundAccountID > 0 @@ -250,7 +275,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 if err != nil { if len(failedAccountIDs) == 0 { - 
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } // Antigravity 单账号退避重试:分组内没有其他可用账号时, @@ -258,7 +284,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + reqLog.Warn("gateway.single_account_retrying", + zap.Int("retry_count", switchCount), + zap.Int("max_retries", maxAccountSwitches), + ) failedAccountIDs = make(map[int64]struct{}) // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) @@ -274,7 +303,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -302,21 +331,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + 
reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - // Ensure the wait counter is decremented if we exit before acquiring the slot. - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -327,17 +359,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() h.handleConcurrencyError(c, err, "account", streamStarted) return } // Slot acquired: no longer waiting in queue. 
- if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -387,7 +417,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("gateway.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) if account.Platform == service.PlatformAntigravity { if !sleepFailoverDelay(c.Request.Context(), switchCount) { return @@ -395,8 +430,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Forward request failed: %v", err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -404,24 +443,29 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, 
&service.RecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: forceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) return } } @@ -455,7 +499,8 @@ func (h *GatewayHandler) Messages(c *gin.Context) { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), currentAPIKey.GroupID, sessionKey, reqModel, failedAccountIDs, parsedReq.MetadataUserID) if err != nil { if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + reqLog.Warn("gateway.account_select_failed", zap.Error(err), zap.Int("excluded_account_count", len(failedAccountIDs))) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } // Antigravity 单账号退避重试:分组内没有其他可用账号时, @@ -463,7 +508,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + reqLog.Warn("gateway.single_account_retrying", 
+ zap.Int("retry_count", switchCount), + zap.Int("max_retries", maxAccountSwitches), + ) failedAccountIDs = make(map[int64]struct{}) // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) @@ -479,7 +527,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检查请求拦截(预热请求、SUGGESTION MODE等) if account.IsInterceptWarmupEnabled() { @@ -507,20 +555,24 @@ func (h *GatewayHandler) Messages(c *gin.Context) { accountWaitCounted := false canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gateway.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gateway.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) return } if err == nil && canWait { accountWaitCounted = true } - defer func() { + releaseWait := func() { if accountWaitCounted { h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false } - }() + } accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( c, @@ -531,16 +583,15 @@ func (h *GatewayHandler) Messages(c *gin.Context) { &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("gateway.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() h.handleConcurrencyError(c, err, 
"account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } + // Slot acquired: no longer waiting in queue. + releaseWait() if err := h.gatewayService.BindStickySession(c.Request.Context(), currentAPIKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gateway.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -563,18 +614,26 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if err != nil { var promptTooLongErr *service.PromptTooLongError if errors.As(err, &promptTooLongErr) { - log.Printf("Prompt too long from antigravity: group=%d fallback_group_id=%v fallback_used=%v", currentAPIKey.GroupID, fallbackGroupID, fallbackUsed) + reqLog.Warn("gateway.prompt_too_long_from_antigravity", + zap.Any("current_group_id", currentAPIKey.GroupID), + zap.Any("fallback_group_id", fallbackGroupID), + zap.Bool("fallback_used", fallbackUsed), + ) if !fallbackUsed && fallbackGroupID != nil && *fallbackGroupID > 0 { fallbackGroup, err := h.gatewayService.ResolveGroupByID(c.Request.Context(), *fallbackGroupID) if err != nil { - log.Printf("Resolve fallback group failed: %v", err) + reqLog.Warn("gateway.resolve_fallback_group_failed", zap.Int64("fallback_group_id", *fallbackGroupID), zap.Error(err)) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } if fallbackGroup.Platform != service.PlatformAnthropic || fallbackGroup.SubscriptionType == service.SubscriptionTypeSubscription || fallbackGroup.FallbackGroupIDOnInvalidRequest != nil { - log.Printf("Fallback group invalid: group=%d platform=%s subscription=%s", fallbackGroup.ID, fallbackGroup.Platform, fallbackGroup.SubscriptionType) + reqLog.Warn("gateway.fallback_group_invalid", + 
zap.Int64("fallback_group_id", fallbackGroup.ID), + zap.String("fallback_platform", fallbackGroup.Platform), + zap.String("fallback_subscription_type", fallbackGroup.SubscriptionType), + ) _ = h.antigravityGatewayService.WriteMappedClaudeError(c, account, promptTooLongErr.StatusCode, promptTooLongErr.RequestID, promptTooLongErr.Body) return } @@ -625,7 +684,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("gateway.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) if account.Platform == service.PlatformAntigravity { if !sleepFailoverDelay(c.Request.Context(), switchCount) { return @@ -633,8 +697,12 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } continue } - // 错误响应已在Forward中处理,这里只记录日志 - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + reqLog.Error("gateway.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -642,24 +710,34 @@ func (h *GatewayHandler) Messages(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // 异步记录使用量(subscription已在函数开头获取) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, clientIP string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, APIKey: currentAPIKey, User: currentAPIKey.User, - Account: 
usedAccount, + Account: account, Subscription: currentSubscription, - UserAgent: ua, + UserAgent: userAgent, IPAddress: clientIP, - ForceCacheBilling: fcb, + ForceCacheBilling: forceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", currentAPIKey.ID), + zap.Any("group_id", currentAPIKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("gateway.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) + reqLog.Debug("gateway.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + zap.Bool("fallback_used", fallbackUsed), + ) return } if !retryWithFallback { @@ -682,6 +760,17 @@ func (h *GatewayHandler) Models(c *gin.Context) { groupID = &apiKey.Group.ID platform = apiKey.Group.Platform } + if forcedPlatform, ok := middleware2.GetForcePlatformFromContext(c); ok && strings.TrimSpace(forcedPlatform) != "" { + platform = forcedPlatform + } + + if platform == service.PlatformSora { + c.JSON(http.StatusOK, gin.H{ + "object": "list", + "data": service.DefaultSoraModels(h.cfg), + }) + return + } // Get available models from account configurations (without platform filter) availableModels := h.gatewayService.GetAvailableModels(c.Request.Context(), groupID, "") @@ -942,7 +1031,11 @@ func sleepAntigravitySingleAccountBackoff(ctx context.Context, retryCount int) b // Handler 层只需短暂间隔后重新进入 Service 层即可。 const delay = 2 * time.Second - log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", delay, retryCount) + logger.L().With( + zap.String("component", "handler.gateway.failover"), + zap.Duration("delay", delay), + zap.Int("retry_count", retryCount), + ).Info("gateway.single_account_backoff_waiting") select { case 
<-ctx.Done(): @@ -1040,6 +1133,15 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -1067,6 +1169,12 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.gateway.count_tokens", + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // 读取请求体 body, err := io.ReadAll(c.Request.Body) @@ -1094,6 +1202,7 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) @@ -1127,14 +1236,15 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 选择支持该模型的账号 account, err := h.gatewayService.SelectAccountForModel(c.Request.Context(), apiKey.GroupID, sessionHash, parsedReq.Model) if err != nil { - h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error()) + reqLog.Warn("gateway.count_tokens_select_account_failed", zap.Error(err)) + h.errorResponse(c, http.StatusServiceUnavailable, 
"api_error", "Service temporarily unavailable") return } - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 转发请求(不记录使用量) if err := h.gatewayService.ForwardCountTokens(c.Request.Context(), c, account, parsedReq); err != nil { - log.Printf("Forward count_tokens request failed: %v", err) + reqLog.Error("gateway.count_tokens_forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) // 错误响应已在 ForwardCountTokens 中处理 return } @@ -1398,7 +1508,25 @@ func billingErrorDetails(err error) (status int, code, message string) { } msg := pkgerrors.Message(err) if msg == "" { - msg = err.Error() + logger.L().With( + zap.String("component", "handler.gateway.billing"), + zap.Error(err), + ).Warn("gateway.billing_error_missing_message") + msg = "Billing error" } return http.StatusForbidden, "billing_error", msg } + +func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + task(ctx) +} diff --git a/backend/internal/handler/gateway_handler_error_fallback_test.go b/backend/internal/handler/gateway_handler_error_fallback_test.go new file mode 100644 index 00000000..4fce5ec1 --- /dev/null +++ b/backend/internal/handler/gateway_handler_error_fallback_test.go @@ -0,0 +1,49 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGatewayEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &GatewayHandler{} + 
wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + assert.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestGatewayEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &GatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 0393f954..efff7997 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -4,8 +4,9 @@ import ( "context" "encoding/json" "fmt" - "math/rand" + "math/rand/v2" "net/http" + "strings" "sync" "time" @@ -20,14 +21,28 @@ var claudeCodeValidator = service.NewClaudeCodeValidator() // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context func SetClaudeCodeClientContext(c *gin.Context, body []byte) { - // 解析请求体为 map - var bodyMap map[string]any - if len(body) > 0 { - _ = json.Unmarshal(body, &bodyMap) + if c == nil || c.Request == nil { + return + } + // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 + if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) { + ctx := service.SetClaudeCodeClient(c.Request.Context(), false) + c.Request = c.Request.WithContext(ctx) + return } - // 验证是否为 Claude Code 客户端 - 
isClaudeCode := claudeCodeValidator.Validate(c.Request, bodyMap) + isClaudeCode := false + if !strings.Contains(c.Request.URL.Path, "messages") { + // 与 Validate 行为一致:非 messages 路径 UA 命中即可视为 Claude Code 客户端。 + isClaudeCode = true + } else { + // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 + var bodyMap map[string]any + if len(body) > 0 { + _ = json.Unmarshal(body, &bodyMap) + } + isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) + } // 更新 request context ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) @@ -104,31 +119,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // 用于避免客户端断开或上游超时导致的并发槽位泄漏。 -// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 +// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。 func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { if releaseFunc == nil { return nil } var once sync.Once - quit := make(chan struct{}) + var stop func() bool release := func() { once.Do(func() { + if stop != nil { + _ = stop() + } releaseFunc() - close(quit) // 通知监听 goroutine 退出 }) } - go func() { - select { - case <-ctx.Done(): - // Context 取消时释放资源 - release() - case <-quit: - // 正常释放已完成,goroutine 退出 - return - } - }() + stop = context.AfterFunc(ctx, release) return release } @@ -153,6 +161,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) } +// TryAcquireUserSlot 尝试立即获取用户并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + +// 
TryAcquireAccountSlot 尝试立即获取账号并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -160,13 +194,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64 ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -180,13 +214,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -196,27 +230,29 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). 
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, slotType, id, maxConcurrency, maxConcurrencyWait, isStream, streamStarted, false) } // waitForSlotWithPingTimeout waits for a concurrency slot with a custom timeout. -func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { +func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType string, id int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool, tryImmediate bool) (func(), error) { ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() - // Try immediate acquire first (avoid unnecessary wait) - var result *service.AcquireResult - var err error - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + acquireSlot := func() (*service.AcquireResult, error) { + if slotType == "user" { + return h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } + return h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) } - if err != nil { - return nil, err - } - if result.Acquired { - return result.ReleaseFunc, nil + + if tryImmediate { + result, err := acquireSlot() + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } } // Determine if ping is needed (streaming + ping format defined) @@ -242,7 +278,6 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType backoff := initialBackoff timer := time.NewTimer(backoff) defer 
timer.Stop() - rng := rand.New(rand.NewSource(time.Now().UnixNano())) for { select { @@ -268,15 +303,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType case <-timer.C: // Try to acquire slot - var result *service.AcquireResult - var err error - - if slotType == "user" { - result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) - } else { - result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) - } - + result, err := acquireSlot() if err != nil { return nil, err } @@ -284,7 +311,7 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType if result.Acquired { return result.ReleaseFunc, nil } - backoff = nextBackoff(backoff, rng) + backoff = nextBackoff(backoff) timer.Reset(backoff) } } @@ -292,26 +319,22 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType // AcquireAccountSlotWithWaitTimeout acquires an account slot with a custom timeout (keeps SSE ping). func (h *ConcurrencyHelper) AcquireAccountSlotWithWaitTimeout(c *gin.Context, accountID int64, maxConcurrency int, timeout time.Duration, isStream bool, streamStarted *bool) (func(), error) { - return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted) + return h.waitForSlotWithPingTimeout(c, "account", accountID, maxConcurrency, timeout, isStream, streamStarted, true) } // nextBackoff 计算下一次退避时间 // 性能优化:使用指数退避 + 随机抖动,避免惊群效应 // current: 当前退避时间 -// rng: 随机数生成器(可为 nil,此时不添加抖动) // 返回值:下一次退避时间(100ms ~ 2s 之间) -func nextBackoff(current time.Duration, rng *rand.Rand) time.Duration { +func nextBackoff(current time.Duration) time.Duration { // 指数退避:当前时间 * 1.5 next := time.Duration(float64(current) * backoffMultiplier) if next > maxBackoff { next = maxBackoff } - if rng == nil { - return next - } // 添加 ±20% 的随机抖动(jitter 范围 0.8 ~ 1.2) // 抖动可以分散多个请求的重试时间点,避免同时冲击 Redis - jitter := 0.8 + rng.Float64()*0.4 + jitter := 0.8 + rand.Float64()*0.4 
jittered := time.Duration(float64(next) * jitter) if jittered < initialBackoff { return initialBackoff diff --git a/backend/internal/handler/gateway_helper_backoff_test.go b/backend/internal/handler/gateway_helper_backoff_test.go new file mode 100644 index 00000000..a5056bbb --- /dev/null +++ b/backend/internal/handler/gateway_helper_backoff_test.go @@ -0,0 +1,106 @@ +package handler + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.2 验证: math/rand/v2 迁移后 nextBackoff 行为正确 --- + +func TestNextBackoff_ExponentialGrowth(t *testing.T) { + // 验证退避时间指数增长(乘数 1.5) + // 由于有随机抖动(±20%),需要验证范围 + current := initialBackoff // 100ms + + for i := 0; i < 10; i++ { + next := nextBackoff(current) + + // 退避结果应在 [initialBackoff, maxBackoff] 范围内 + assert.GreaterOrEqual(t, int64(next), int64(initialBackoff), + "第 %d 次退避不应低于初始值 %v", i, initialBackoff) + assert.LessOrEqual(t, int64(next), int64(maxBackoff), + "第 %d 次退避不应超过最大值 %v", i, maxBackoff) + + // 为下一轮提供当前退避值 + current = next + } +} + +func TestNextBackoff_BoundedByMaxBackoff(t *testing.T) { + // 即使输入非常大,输出也不超过 maxBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(10 * time.Second) + assert.LessOrEqual(t, int64(result), int64(maxBackoff), + "退避值不应超过 maxBackoff") + } +} + +func TestNextBackoff_BoundedByInitialBackoff(t *testing.T) { + // 即使输入非常小,输出也不低于 initialBackoff + for i := 0; i < 100; i++ { + result := nextBackoff(1 * time.Millisecond) + assert.GreaterOrEqual(t, int64(result), int64(initialBackoff), + "退避值不应低于 initialBackoff") + } +} + +func TestNextBackoff_HasJitter(t *testing.T) { + // 验证多次调用会产生不同的值(随机抖动生效) + // 使用相同的输入调用 50 次,收集结果 + results := make(map[time.Duration]bool) + current := 500 * time.Millisecond + + for i := 0; i < 50; i++ { + result := nextBackoff(current) + results[result] = true + } + + // 50 次调用应该至少有 2 个不同的值(抖动存在) + require.Greater(t, len(results), 1, + "nextBackoff 应产生随机抖动,但所有 50 次调用结果相同") +} + +func 
TestNextBackoff_InitialValueGrows(t *testing.T) { + // 验证从初始值开始,退避趋势是增长的 + current := initialBackoff + var sum time.Duration + + runs := 100 + for i := 0; i < runs; i++ { + next := nextBackoff(current) + sum += next + current = next + } + + avg := sum / time.Duration(runs) + // 平均退避时间应大于初始值(因为指数增长 + 上限) + assert.Greater(t, int64(avg), int64(initialBackoff), + "平均退避时间应大于初始退避值") +} + +func TestNextBackoff_ConvergesToMaxBackoff(t *testing.T) { + // 从初始值开始,经过多次退避后应收敛到 maxBackoff 附近 + current := initialBackoff + for i := 0; i < 20; i++ { + current = nextBackoff(current) + } + + // 经过 20 次迭代后,应该已经到达 maxBackoff 区间 + // 由于抖动,允许 ±20% 的范围 + lowerBound := time.Duration(float64(maxBackoff) * 0.8) + assert.GreaterOrEqual(t, int64(current), int64(lowerBound), + "经过多次退避后应收敛到 maxBackoff 附近") +} + +func BenchmarkNextBackoff(b *testing.B) { + current := initialBackoff + for i := 0; i < b.N; i++ { + current = nextBackoff(current) + if current > maxBackoff { + current = initialBackoff + } + } +} diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go new file mode 100644 index 00000000..3e6c376b --- /dev/null +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -0,0 +1,114 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +type concurrencyCacheMock struct { + acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) + acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) + releaseUserCalled int32 + releaseAccountCalled int32 +} + +func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireAccountSlotFn != nil { + return m.acquireAccountSlotFn(ctx, accountID, 
maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + atomic.AddInt32(&m.releaseAccountCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireUserSlotFn != nil { + return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + atomic.AddInt32(&m.releaseUserCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} + 
+func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2) + require.NoError(t, err) + require.True(t, acquired) + require.NotNil(t, release) + + release() + require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled)) +} + +func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1) + require.NoError(t, err) + require.False(t, acquired) + require.Nil(t, release) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled)) +} diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go new file mode 100644 index 00000000..3fdf1bfc --- /dev/null +++ b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -0,0 +1,269 @@ +package handler + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type helperConcurrencyCacheStub struct { + mu sync.Mutex + + accountSeq []bool + userSeq []bool + + accountAcquireCalls int + userAcquireCalls int + 
accountReleaseCalls int + userReleaseCalls int +} + +func (s *helperConcurrencyCacheStub) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.accountAcquireCalls++ + if len(s.accountSeq) == 0 { + return false, nil + } + v := s.accountSeq[0] + s.accountSeq = s.accountSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.accountReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.userAcquireCalls++ + if len(s.userSeq) == 0 { + return false, nil + } + v := s.userSeq[0] + s.userSeq = s.userSeq[1:] + return v, nil +} + +func (s *helperConcurrencyCacheStub) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.userReleaseCalls++ + return nil +} + +func (s *helperConcurrencyCacheStub) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (s *helperConcurrencyCacheStub) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (s *helperConcurrencyCacheStub) DecrementWaitCount(ctx 
context.Context, userID int64) error { + return nil +} + +func (s *helperConcurrencyCacheStub) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + out := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, acc := range accounts { + out[acc.ID] = &service.AccountLoadInfo{AccountID: acc.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + out := make(map[int64]*service.UserLoadInfo, len(users)) + for _, user := range users { + out[user.ID] = &service.UserLoadInfo{UserID: user.ID} + } + return out, nil +} + +func (s *helperConcurrencyCacheStub) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func newHelperTestContext(method, path string) (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(method, path, nil) + return c, rec +} + +func validClaudeCodeBodyJSON() []byte { + return []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "system":[{"text":"You are Claude Code, Anthropic's official CLI for Claude."}], + "metadata":{"user_id":"user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"} + }`) +} + +func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { + t.Run("non_cli_user_agent_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "curl/8.6.0") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_non_messages_path_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodGet, "/v1/models") + c.Request.Header.Set("User-Agent", 
"claude-cli/1.0.1") + + SetClaudeCodeClientContext(c, nil) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_valid_body_sets_true", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("cli_messages_path_invalid_body_sets_false", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + // 缺少严格校验所需 header + body 字段 + SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`)) + require.False(t, service.IsClaudeCodeClient(c.Request.Context())) + }) +} + +func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, true}, + userSeq: []bool{false, true}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + + t.Run("account_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, time.Second, false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + require.False(t, streamStarted) + release() + require.GreaterOrEqual(t, cache.accountAcquireCalls, 2) + require.GreaterOrEqual(t, cache.accountReleaseCalls, 1) + }) + + t.Run("user_slot_acquired_after_retry", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := 
helper.waitForSlotWithPingTimeout(c, "user", 202, 3, time.Second, false, &streamStarted, true) + require.NoError(t, err) + require.NotNil(t, release) + release() + require.GreaterOrEqual(t, cache.userAcquireCalls, 2) + require.GreaterOrEqual(t, cache.userReleaseCalls, 1) + }) +} + +func TestWaitForSlotWithPingTimeout_TimeoutAndStreamPing(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false, false, false}, + } + concurrency := service.NewConcurrencyService(cache) + + t.Run("timeout_returns_concurrency_error", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 130*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + }) + + t.Run("stream_mode_sends_ping_before_timeout", func(t *testing.T) { + helper := NewConcurrencyHelper(concurrency, SSEPingFormatComment, 10*time.Millisecond) + c, rec := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := helper.waitForSlotWithPingTimeout(c, "account", 101, 2, 70*time.Millisecond, true, &streamStarted, true) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.True(t, streamStarted) + require.Contains(t, rec.Body.String(), ":\n\n") + }) +} + +func TestWaitForSlotWithPingTimeout_AcquireError(t *testing.T) { + errCache := &helperConcurrencyCacheStubWithError{ + err: errors.New("redis unavailable"), + } + concurrency := service.NewConcurrencyService(errCache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + release, err := 
helper.waitForSlotWithPingTimeout(c, "account", 1, 1, 200*time.Millisecond, false, &streamStarted, true) + require.Nil(t, release) + require.Error(t, err) + require.Contains(t, err.Error(), "redis unavailable") +} + +func TestAcquireAccountSlotWithWaitTimeout_ImmediateAttemptBeforeBackoff(t *testing.T) { + cache := &helperConcurrencyCacheStub{ + accountSeq: []bool{false}, + } + concurrency := service.NewConcurrencyService(cache) + helper := NewConcurrencyHelper(concurrency, SSEPingFormatNone, 5*time.Millisecond) + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + streamStarted := false + + release, err := helper.AcquireAccountSlotWithWaitTimeout(c, 301, 1, 30*time.Millisecond, false, &streamStarted) + require.Nil(t, release) + var cErr *ConcurrencyError + require.ErrorAs(t, err, &cErr) + require.True(t, cErr.IsTimeout) + require.GreaterOrEqual(t, cache.accountAcquireCalls, 1) +} + +type helperConcurrencyCacheStubWithError struct { + helperConcurrencyCacheStub + err error +} + +func (s *helperConcurrencyCacheStubWithError) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, s.err +} diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 3d25505b..ea212088 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -8,11 +8,9 @@ import ( "encoding/json" "errors" "io" - "log" "net/http" "regexp" "strings" - "time" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -20,11 +18,13 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" 
"github.com/google/uuid" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // geminiCLITmpDirRegex 用于从 Gemini CLI 请求体中提取 tmp 目录的哈希值 @@ -143,6 +143,13 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusInternalServerError, "User context not found") return } + reqLog := requestLogger( + c, + "handler.gemini_v1beta.models", + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // 检查平台:优先使用强制平台(/antigravity 路由,中间件已设置 request.Context),否则要求 gemini 分组 if !middleware.HasForcePlatform(c) { @@ -159,6 +166,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } stream := action == "streamGenerateContent" + reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream)) body, err := io.ReadAll(c.Request.Body) if err != nil { @@ -187,8 +195,9 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { canWait, err := geminiConcurrency.IncrementWaitCount(c.Request.Context(), authSubject.UserID, maxWait) waitCounted := false if err != nil { - log.Printf("Increment wait count failed: %v", err) + reqLog.Warn("gemini.user_wait_counter_increment_failed", zap.Error(err)) } else if !canWait { + reqLog.Info("gemini.user_wait_queue_full", zap.Int("max_wait", maxWait)) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -208,6 +217,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } userReleaseFunc, err := geminiConcurrency.AcquireUserSlotWithWait(c, authSubject.UserID, authSubject.Concurrency, stream, &streamStarted) if err != nil { + reqLog.Warn("gemini.user_slot_acquire_failed", zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -223,6 +233,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := 
h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) status, _, message := billingErrorDetails(err) googleError(c, status, message) return @@ -252,6 +263,15 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var sessionBoundAccountID int64 if sessionKey != "" { sessionBoundAccountID, _ = h.gatewayService.GetCachedSessionAccountID(c.Request.Context(), apiKey.GroupID, sessionKey) + if sessionBoundAccountID > 0 { + prefetchedGroupID := int64(0) + if apiKey.GroupID != nil { + prefetchedGroupID = *apiKey.GroupID + } + ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + c.Request = c.Request.WithContext(ctx) + } } // === Gemini 内容摘要会话 Fallback 逻辑 === @@ -296,8 +316,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { matchedDigestChain = foundMatchedChain sessionBoundAccountID = foundAccountID geminiSessionUUID = foundUUID - log.Printf("[Gemini] Digest fallback matched: uuid=%s, accountID=%d, chain=%s", - safeShortPrefix(foundUUID, 8), foundAccountID, truncateDigestChain(geminiDigestChain)) + reqLog.Info("gemini.digest_fallback_matched", + zap.String("session_uuid_prefix", safeShortPrefix(foundUUID, 8)), + zap.Int64("account_id", foundAccountID), + zap.String("digest_chain", truncateDigestChain(geminiDigestChain)), + ) // 关键:如果原 sessionKey 为空,使用 prefixHash + uuid 作为 sessionKey // 这样 SelectAccountWithLoadAwareness 的粘性会话逻辑会优先使用匹配到的账号 @@ -346,7 +369,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 谷歌上游 503 (MODEL_CAPACITY_EXHAUSTED) 通常是暂时性的,等几秒就能恢复。 if lastFailoverErr != nil && lastFailoverErr.StatusCode == http.StatusServiceUnavailable && switchCount <= maxAccountSwitches { if sleepAntigravitySingleAccountBackoff(c.Request.Context(), 
switchCount) { - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", switchCount, maxAccountSwitches) + reqLog.Warn("gemini.single_account_retrying", + zap.Int("retry_count", switchCount), + zap.Int("max_retries", maxAccountSwitches), + ) failedAccountIDs = make(map[int64]struct{}) // 设置 context 标记,让 Service 层预检查等待限流过期而非直接切换 ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) @@ -358,18 +384,24 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { return } account := selection.Account - setOpsSelectedAccount(c, account.ID) + setOpsSelectedAccount(c, account.ID, account.Platform) // 检测账号切换:如果粘性会话绑定的账号与当前选择的账号不同,清除 thoughtSignature // 注意:Gemini 原生 API 的 thoughtSignature 与具体上游账号强相关;跨账号透传会导致 400。 if sessionBoundAccountID > 0 && sessionBoundAccountID != account.ID { - log.Printf("[Gemini] Sticky session account switched: %d -> %d, cleaning thoughtSignature", sessionBoundAccountID, account.ID) + reqLog.Info("gemini.sticky_session_account_switched", + zap.Int64("from_account_id", sessionBoundAccountID), + zap.Int64("to_account_id", account.ID), + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) sessionBoundAccountID = account.ID } else if sessionKey != "" && sessionBoundAccountID == 0 && !cleanedForUnknownBinding && bytes.Contains(body, []byte(`"thoughtSignature"`)) { // 无缓存绑定但请求里已有 thoughtSignature:常见于缓存丢失/TTL 过期后,客户端继续携带旧签名。 // 为避免第一次转发就 400,这里做一次确定性清理,让新账号重新生成签名链路。 - log.Printf("[Gemini] Sticky session binding missing, cleaning thoughtSignature proactively") + reqLog.Info("gemini.sticky_session_binding_missing", + zap.Bool("clean_thought_signature", true), + ) body = service.CleanGeminiNativeThoughtSignatures(body) cleanedForUnknownBinding = true sessionBoundAccountID = account.ID @@ -388,9 +420,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted := false canWait, err := 
geminiConcurrency.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) if err != nil { - log.Printf("Increment account wait count failed: %v", err) + reqLog.Warn("gemini.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) + reqLog.Info("gemini.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) googleError(c, http.StatusTooManyRequests, "Too many pending requests, please retry later") return } @@ -412,6 +447,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { &streamStarted, ) if err != nil { + reqLog.Warn("gemini.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) googleError(c, http.StatusTooManyRequests, err.Error()) return } @@ -420,7 +456,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { accountWaitCounted = false } if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionKey, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + reqLog.Warn("gemini.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } // 账号槽位/等待计数需要在超时或断开时安全回收 @@ -454,7 +490,12 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { } lastFailoverErr = failoverErr switchCount++ - log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("gemini.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) if account.Platform == service.PlatformAntigravity { if !sleepFailoverDelay(c.Request.Context(), switchCount) { return @@ -463,7 +504,7 @@ func (h 
*GatewayHandler) GeminiV1BetaModels(c *gin.Context) { continue } // ForwardNative already wrote the response - log.Printf("Gemini native forward failed: %v", err) + reqLog.Error("gemini.forward_failed", zap.Int64("account_id", account.ID), zap.Error(err)) return } @@ -482,31 +523,39 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { account.ID, matchedDigestChain, ); err != nil { - log.Printf("[Gemini] Failed to save digest session: %v", err) + reqLog.Warn("gemini.digest_session_save_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } } - // 6) record usage async (Gemini 使用长上下文双倍计费) - go func(result *service.ForwardResult, usedAccount *service.Account, ua, ip string, fcb bool) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsageWithLongContext(ctx, &service.RecordUsageLongContextInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, LongContextThreshold: 200000, // Gemini 200K 阈值 LongContextMultiplier: 2.0, // 超出部分双倍计费 - ForceCacheBilling: fcb, + ForceCacheBilling: forceCacheBilling, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.gemini_v1beta.models"), + zap.Int64("user_id", authSubject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", modelName), + zap.Int64("account_id", account.ID), + ).Error("gemini.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP, forceCacheBilling) + }) + reqLog.Debug("gemini.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) return } } diff 
--git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b2b12c0d..b999180b 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -39,6 +39,7 @@ type Handlers struct { Admin *AdminHandlers Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler + SoraGateway *SoraGatewayHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/idempotency_helper.go b/backend/internal/handler/idempotency_helper.go new file mode 100644 index 00000000..bca63b6b --- /dev/null +++ b/backend/internal/handler/idempotency_helper.go @@ -0,0 +1,65 @@ +package handler + +import ( + "context" + "strconv" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + "github.com/gin-gonic/gin" +) + +func executeUserIdempotentJSON( + c *gin.Context, + scope string, + payload any, + ttl time.Duration, + execute func(context.Context) (any, error), +) { + coordinator := service.DefaultIdempotencyCoordinator() + if coordinator == nil { + data, err := execute(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, data) + return + } + + actorScope := "user:0" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + actorScope = "user:" + strconv.FormatInt(subject.UserID, 10) + } + + result, err := coordinator.Execute(c.Request.Context(), service.IdempotencyExecuteOptions{ + Scope: scope, + ActorScope: actorScope, + Method: c.Request.Method, + Route: c.FullPath(), + IdempotencyKey: c.GetHeader("Idempotency-Key"), + Payload: payload, + RequireKey: true, + TTL: ttl, + }, execute) + if err != nil { + if infraerrors.Code(err) == infraerrors.Code(service.ErrIdempotencyStoreUnavail) { + 
service.RecordIdempotencyStoreUnavailable(c.FullPath(), scope, "handler_fail_close") + logger.LegacyPrintf("handler.idempotency", "[Idempotency] store unavailable: method=%s route=%s scope=%s strategy=fail_close", c.Request.Method, c.FullPath(), scope) + } + if retryAfter := service.RetryAfterSecondsFromError(err); retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } + response.ErrorFrom(c, err) + return + } + if result != nil && result.Replayed { + c.Header("X-Idempotency-Replayed", "true") + } + response.Success(c, result.Data) +} diff --git a/backend/internal/handler/idempotency_helper_test.go b/backend/internal/handler/idempotency_helper_test.go new file mode 100644 index 00000000..e8213a2b --- /dev/null +++ b/backend/internal/handler/idempotency_helper_test.go @@ -0,0 +1,285 @@ +package handler + +import ( + "bytes" + "context" + "errors" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userStoreUnavailableRepoStub struct{} + +func (userStoreUnavailableRepoStub) CreateProcessing(context.Context, *service.IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*service.IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + 
return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (userStoreUnavailableRepoStub) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +type userMemoryIdempotencyRepoStub struct { + mu sync.Mutex + nextID int64 + data map[string]*service.IdempotencyRecord +} + +func newUserMemoryIdempotencyRepoStub() *userMemoryIdempotencyRepoStub { + return &userMemoryIdempotencyRepoStub{ + nextID: 1, + data: make(map[string]*service.IdempotencyRecord), + } +} + +func (r *userMemoryIdempotencyRepoStub) key(scope, keyHash string) string { + return scope + "|" + keyHash +} + +func (r *userMemoryIdempotencyRepoStub) clone(in *service.IdempotencyRecord) *service.IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + return &out +} + +func (r *userMemoryIdempotencyRepoStub) CreateProcessing(_ context.Context, record *service.IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + cp := r.clone(record) + cp.ID = r.nextID + r.nextID++ + r.data[k] = cp + record.ID = cp.ID + return true, nil +} + +func (r *userMemoryIdempotencyRepoStub) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return r.clone(r.data[r.key(scope, keyHash)]), nil +} + +func (r *userMemoryIdempotencyRepoStub) TryReclaim(_ context.Context, id int64, 
fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = service.IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != service.IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + return true, nil + } + return false, nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + rec.ErrorReason = nil + return nil + } + return nil +} + +func (r *userMemoryIdempotencyRepoStub) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = service.IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.ErrorReason = &errorReason + return nil + } + return nil +} + +func (r 
*userMemoryIdempotencyRepoStub) DeleteExpired(_ context.Context, _ time.Time, _ int) (int64, error) { + return 0, nil +} + +func withUserSubject(userID int64) gin.HandlerFunc { + return func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + } +} + +func TestExecuteUserIdempotentJSONFallbackWithoutCoordinator(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(nil) + + var executed int + router := gin.New() + router.Use(withUserSubject(1)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, executed) +} + +func TestExecuteUserIdempotentJSONFailCloseOnStoreUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(userStoreUnavailableRepoStub{}, service.DefaultIdempotencyConfig())) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed int + router := gin.New() + router.Use(withUserSubject(2)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed++ + return gin.H{"ok": true}, nil + }) + }) + + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "k1") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, 
http.StatusServiceUnavailable, rec.Code) + require.Equal(t, 0, executed) +} + +func TestExecuteUserIdempotentJSONConcurrentRetrySingleSideEffectAndReplay(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newUserMemoryIdempotencyRepoStub() + cfg := service.DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + service.SetDefaultIdempotencyCoordinator(service.NewIdempotencyCoordinator(repo, cfg)) + t.Cleanup(func() { + service.SetDefaultIdempotencyCoordinator(nil) + }) + + var executed atomic.Int32 + router := gin.New() + router.Use(withUserSubject(3)) + router.POST("/idempotent", func(c *gin.Context) { + executeUserIdempotentJSON(c, "user.test.scope", map[string]any{"a": 1}, time.Minute, func(ctx context.Context) (any, error) { + executed.Add(1) + time.Sleep(80 * time.Millisecond) + return gin.H{"ok": true}, nil + }) + }) + + call := func() (int, http.Header) { + req := httptest.NewRequest(http.MethodPost, "/idempotent", bytes.NewBufferString(`{"a":1}`)) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Idempotency-Key", "same-user-key") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + return rec.Code, rec.Header() + } + + var status1, status2 int + var wg sync.WaitGroup + wg.Add(2) + go func() { defer wg.Done(); status1, _ = call() }() + go func() { defer wg.Done(); status2, _ = call() }() + wg.Wait() + + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status1) + require.Contains(t, []int{http.StatusOK, http.StatusConflict}, status2) + require.Equal(t, int32(1), executed.Load()) + + status3, headers3 := call() + require.Equal(t, http.StatusOK, status3) + require.Equal(t, "true", headers3.Get("X-Idempotency-Replayed")) + require.Equal(t, int32(1), executed.Load()) +} diff --git a/backend/internal/handler/logging.go b/backend/internal/handler/logging.go new file mode 100644 index 00000000..2d5e6e22 --- /dev/null +++ b/backend/internal/handler/logging.go @@ -0,0 +1,19 @@ +package handler + +import ( 
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "go.uber.org/zap" +) + +func requestLogger(c *gin.Context, component string, fields ...zap.Field) *zap.Logger { + base := logger.L() + if c != nil && c.Request != nil { + base = logger.FromContext(c.Request.Context()) + } + + if component != "" { + fields = append([]zap.Field{zap.String("component", component)}, fields...) + } + return base.With(fields...) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c08a8b0e..50af684d 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -6,18 +6,19 @@ import ( "errors" "fmt" "io" - "log" "net/http" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" - "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "go.uber.org/zap" ) // OpenAIGatewayHandler handles OpenAI API gateway requests @@ -25,6 +26,7 @@ type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService apiKeyService *service.APIKeyService + usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int @@ -36,6 +38,7 @@ func NewOpenAIGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *OpenAIGatewayHandler { @@ -51,6 +54,7 @@ func NewOpenAIGatewayHandler( 
gatewayService: gatewayService, billingCacheService: billingCacheService, apiKeyService: apiKeyService, + usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), maxAccountSwitches: maxAccountSwitches, @@ -60,6 +64,8 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + requestStart := time.Now() + // Get apiKey and user from context (set by ApiKeyAuth middleware) apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -72,6 +78,13 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") return } + reqLog := requestLogger( + c, + "handler.openai_gateway.responses", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) // Read request body body, err := io.ReadAll(c.Request.Body) @@ -91,57 +104,57 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, "", false, body) - // Parse request body to map for potential modification - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // Extract model and stream - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - - // 验证 model 必填 - if reqModel == "" { + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } + reqModel := 
modelResult.String() - userAgent := c.GetHeader("User-Agent") - if !openai.IsCodexCLIRequest(userAgent) { - existingInstructions, _ := reqBody["instructions"].(string) - if strings.TrimSpace(existingInstructions) == "" { - if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" { - reqBody["instructions"] = instructions - // Re-serialize body - body, err = json.Marshal(reqBody) - if err != nil { - h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") - return - } - } - } + streamResult := gjson.GetBytes(body, "stream") + if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") + return } + reqStream := streamResult.Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) setOpsRequestContext(c, reqModel, reqStream, body) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, // 或带 id 且与 call_id 匹配的 item_reference。 - if service.HasFunctionCallOutput(reqBody) { - previousResponseID, _ := reqBody["previous_response_id"].(string) - if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { - if service.HasFunctionCallOutputMissingCallID(reqBody) { - log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") - return - } - callIDs := service.FunctionCallOutputCallIDs(reqBody) - if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { - log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", 
"function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") - return + // 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal + if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err == nil { + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) + if service.HasFunctionCallOutput(reqBody) { + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { + if service.HasFunctionCallOutputMissingCallID(reqBody) { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_call_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return + } + callIDs := service.FunctionCallOutputCallIDs(reqBody) + if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_item_reference"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return + } + } } } } @@ -157,34 +170,48 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) - // 0. 
Check if wait queue is full - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - log.Printf("Increment wait count failed: %v", err) - // On error, allow request to proceed - } else if !canWait { - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() - // 1. First acquire user concurrency slot - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + // 0. 先尝试直接抢占用户槽位(快速路径) + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency) if err != nil { - log.Printf("User concurrency acquire failed: %v", err) + reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) h.handleConcurrencyError(c, err, "user", streamStarted) return } - // User slot acquired: no longer waiting. 
+ + waitCounted := false + if !userAcquired { + // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。 + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + if waitErr != nil { + reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) + // 按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if waitErr == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + } + + // 用户槽位已获取:退出等待队列计数。 if waitCounted { h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) waitCounted = false @@ -197,14 +224,14 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 2. 
Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - log.Printf("Billing eligibility check failed after wait: %v", err) + reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) status, code, message := billingErrorDetails(err) h.handleStreamingAwareError(c, status, code, message, streamStarted) return } // Generate session hash (header first; fallback to prompt_cache_key) - sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) + sessionHash := h.gatewayService.GenerateSessionHash(c, body) maxAccountSwitches := h.maxAccountSwitches switchCount := 0 @@ -213,12 +240,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model - log.Printf("[OpenAI Handler] Selecting account: groupID=%v model=%s", apiKey.GroupID, reqModel) + reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) if err != nil { - log.Printf("[OpenAI Handler] SelectAccount failed: %v", err) + reqLog.Warn("openai.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) if len(failedAccountIDs) == 0 { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted) return } if lastFailoverErr != nil { @@ -229,8 +259,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } account := selection.Account - log.Printf("[OpenAI Handler] Selected account: id=%d name=%s", account.ID, account.Name) - setOpsSelectedAccount(c, account.ID) + 
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) + setOpsSelectedAccount(c, account.ID, account.Platform) // 3. Acquire account concurrency slot accountReleaseFunc := selection.ReleaseFunc @@ -239,53 +269,87 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - defer func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - } - }() - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, + // 先快速尝试一次账号槽位,命中则跳过等待计数写入。 + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + c.Request.Context(), account.ID, selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) + reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) h.handleConcurrencyError(c, err, "account", streamStarted) return } - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - 
log.Printf("Bind sticky session failed: %v", err) + if fastAcquired { + accountReleaseFunc = fastReleaseFunc + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + } else { + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } else if !canWait { + reqLog.Info("openai.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + releaseWait() + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + // Slot acquired: no longer waiting in queue. 
+ releaseWait() + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -296,11 +360,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } switchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) + reqLog.Warn("openai.upstream_failover_switching", + zap.Int64("account_id", account.ID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + ) continue } - // Error response already handled in Forward, just log - log.Printf("Account %d: Forward request failed: %v", account.ID, err) + wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) + 
reqLog.Error("openai.forward_failed", + zap.Int64("account_id", account.ID), + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Error(err), + ) return } @@ -308,27 +381,72 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) - // Async record usage - go func(result *service.OpenAIForwardResult, usedAccount *service.Account, ua, ip string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ Result: result, APIKey: apiKey, User: apiKey.User, - Account: usedAccount, + Account: account, Subscription: subscription, - UserAgent: ua, - IPAddress: ip, + UserAgent: userAgent, + IPAddress: clientIP, APIKeyService: h.apiKeyService, }); err != nil { - log.Printf("Record usage failed: %v", err) + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("openai.record_usage_failed", zap.Error(err)) } - }(result, account, userAgent, clientIP) + }) + reqLog.Debug("openai.request_completed", + zap.Int64("account_id", account.ID), + zap.Int("switch_count", switchCount), + ) return } } +func getContextInt64(c *gin.Context, key string) (int64, bool) { + if c == nil || key == "" { + return 0, false + } + v, ok := c.Get(key) + if !ok { + return 0, false + } + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case int32: + return int64(t), true + case float64: + return int64(t), true + default: + return 0, false + } +} + +func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == 
nil { + return + } + if h.usageRecordWorkerPool != nil { + h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + task(ctx) +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", @@ -397,8 +515,19 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in OpenAI SSE format - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // Send error event in OpenAI SSE format with proper JSON marshaling + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -411,6 +540,15 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status h.errorResponse(c, status, errType, message) } +// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。 +func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted) + return true +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c 
*gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go new file mode 100644 index 00000000..1ca52c2d --- /dev/null +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -0,0 +1,230 @@ +package handler + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号的消息", + errType: "server_error", + message: `upstream returned "invalid" response`, + }, + { + name: "包含反斜杠的消息", + errType: "server_error", + message: `path C:\Users\test\file.txt not found`, + }, + { + name: "包含双引号和反斜杠的消息", + errType: "upstream_error", + message: `error parsing "key\value": unexpected token`, + }, + { + name: "包含换行符的消息", + errType: "server_error", + message: "line1\nline2\ttab", + }, + { + name: "普通消息", + errType: "upstream_error", + message: "Upstream service temporarily unavailable", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + + // 验证 SSE 格式:event: error\ndata: {JSON}\n\n + assert.True(t, strings.HasPrefix(body, "event: error\n"), "应以 'event: error\\n' 开头") + assert.True(t, strings.HasSuffix(body, "\n\n"), "应以 '\\n\\n' 结尾") + + // 提取 data 部分 + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "应有 event 行和 
data 行") + dataLine := lines[1] + require.True(t, strings.HasPrefix(dataLine, "data: "), "第二行应以 'data: ' 开头") + jsonStr := strings.TrimPrefix(dataLine, "data: ") + + // 验证 JSON 合法性 + var parsed map[string]any + err := json.Unmarshal([]byte(jsonStr), &parsed) + require.NoError(t, err, "JSON 应能被成功解析,原始 JSON: %s", jsonStr) + + // 验证结构 + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "应包含 error 对象") + assert.Equal(t, tt.errType, errorObj["type"]) + assert.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "test error", false) + + // 非流式应返回 JSON 响应 + assert.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "test error", errorObj["message"]) +} + +func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.True(t, wrote) + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func 
TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + wrote := h.ensureForwardErrorResponse(c, false) + + require.False(t, wrote) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 +func TestOpenAIHandler_GjsonExtraction(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + }{ + {"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true}, + {"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false}, + {"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false}, + {"model 缺失", `{"stream":true}`, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := []byte(tt.body) + modelResult := gjson.GetBytes(body, "model") + model := "" + if modelResult.Type == gjson.String { + model = modelResult.String() + } + stream := gjson.GetBytes(body, "stream").Bool() + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + }) + } +} + +// TestOpenAIHandler_GjsonValidation 验证修复后的 JSON 合法性和类型校验 +func TestOpenAIHandler_GjsonValidation(t *testing.T) { + // 非法 JSON 被 gjson.ValidBytes 拦截 + require.False(t, gjson.ValidBytes([]byte(`{invalid json`))) + + // model 为数字 → 类型不是 gjson.String,应被拒绝 + body := []byte(`{"model":123}`) + modelResult := gjson.GetBytes(body, "model") + require.True(t, modelResult.Exists()) + require.NotEqual(t, gjson.String, modelResult.Type) + + // model 为 null → 类型不是 gjson.String,应被拒绝 + body2 := []byte(`{"model":null}`) + modelResult2 := gjson.GetBytes(body2, "model") + require.True(t, modelResult2.Exists()) + require.NotEqual(t, gjson.String, 
modelResult2.Type) + + // stream 为 string → 类型既不是 True 也不是 False,应被拒绝 + body3 := []byte(`{"model":"gpt-4","stream":"true"}`) + streamResult := gjson.GetBytes(body3, "stream") + require.True(t, streamResult.Exists()) + require.NotEqual(t, gjson.True, streamResult.Type) + require.NotEqual(t, gjson.False, streamResult.Type) + + // stream 为 int → 同上 + body4 := []byte(`{"model":"gpt-4","stream":1}`) + streamResult2 := gjson.GetBytes(body4, "stream") + require.True(t, streamResult2.Exists()) + require.NotEqual(t, gjson.True, streamResult2.Type) + require.NotEqual(t, gjson.False, streamResult2.Type) +} + +// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑 +func TestOpenAIHandler_InstructionsInjection(t *testing.T) { + // 测试 1:无 instructions → 注入 + body := []byte(`{"model":"gpt-4"}`) + existing := gjson.GetBytes(body, "instructions").String() + require.Empty(t, existing) + newBody, err := sjson.SetBytes(body, "instructions", "test instruction") + require.NoError(t, err) + require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String()) + + // 测试 2:已有 instructions → 不覆盖 + body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`) + existing2 := gjson.GetBytes(body2, "instructions").String() + require.Equal(t, "existing", existing2) + + // 测试 3:空白 instructions → 注入 + body3 := []byte(`{"model":"gpt-4","instructions":" "}`) + existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String()) + require.Empty(t, existing3) + + // 测试 4:sjson.SetBytes 返回错误时不应 panic + // 正常 JSON 不会产生 sjson 错误,验证返回值被正确处理 + validBody := []byte(`{"model":"gpt-4"}`) + result, setErr := sjson.SetBytes(validBody, "instructions", "hello") + require.NoError(t, setErr) + require.True(t, gjson.ValidBytes(result)) +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index cb62ceae..ab9a2167 100644 --- a/backend/internal/handler/ops_error_logger.go +++ 
b/backend/internal/handler/ops_error_logger.go @@ -41,9 +41,8 @@ const ( ) type opsErrorLogJob struct { - ops *service.OpsService - entry *service.OpsInsertErrorLogInput - requestBody []byte + ops *service.OpsService + entry *service.OpsInsertErrorLogInput } var ( @@ -58,6 +57,7 @@ var ( opsErrorLogEnqueued atomic.Int64 opsErrorLogDropped atomic.Int64 opsErrorLogProcessed atomic.Int64 + opsErrorLogSanitized atomic.Int64 opsErrorLogLastDropLogAt atomic.Int64 @@ -94,7 +94,7 @@ func startOpsErrorLogWorkers() { } }() ctx, cancel := context.WithTimeout(context.Background(), opsErrorLogTimeout) - _ = job.ops.RecordError(ctx, job.entry, job.requestBody) + _ = job.ops.RecordError(ctx, job.entry, nil) cancel() opsErrorLogProcessed.Add(1) }() @@ -103,7 +103,7 @@ func startOpsErrorLogWorkers() { } } -func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput, requestBody []byte) { +func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLogInput) { if ops == nil || entry == nil { return } @@ -129,7 +129,7 @@ func enqueueOpsErrorLog(ops *service.OpsService, entry *service.OpsInsertErrorLo } select { - case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry, requestBody: requestBody}: + case opsErrorLogQueue <- opsErrorLogJob{ops: ops, entry: entry}: opsErrorLogQueueLen.Add(1) opsErrorLogEnqueued.Add(1) default: @@ -205,6 +205,10 @@ func OpsErrorLogProcessedTotal() int64 { return opsErrorLogProcessed.Load() } +func OpsErrorLogSanitizedTotal() int64 { + return opsErrorLogSanitized.Load() +} + func maybeLogOpsErrorLogDrop() { now := time.Now().Unix() @@ -222,12 +226,13 @@ func maybeLogOpsErrorLogDrop() { queueCap := OpsErrorLogQueueCapacity() log.Printf( - "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d)", + "[OpsErrorLogger] queue is full; dropping logs (queued=%d cap=%d enqueued_total=%d dropped_total=%d processed_total=%d sanitized_total=%d)", queued, 
queueCap, opsErrorLogEnqueued.Load(), opsErrorLogDropped.Load(), opsErrorLogProcessed.Load(), + opsErrorLogSanitized.Load(), ) } @@ -255,18 +260,49 @@ func setOpsRequestContext(c *gin.Context, model string, stream bool, requestBody if c == nil { return } + model = strings.TrimSpace(model) c.Set(opsModelKey, model) c.Set(opsStreamKey, stream) if len(requestBody) > 0 { c.Set(opsRequestBodyKey, requestBody) } + if c.Request != nil && model != "" { + ctx := context.WithValue(c.Request.Context(), ctxkey.Model, model) + c.Request = c.Request.WithContext(ctx) + } } -func setOpsSelectedAccount(c *gin.Context, accountID int64) { +func attachOpsRequestBodyToEntry(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + v, ok := c.Get(opsRequestBodyKey) + if !ok { + return + } + raw, ok := v.([]byte) + if !ok || len(raw) == 0 { + return + } + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = service.PrepareOpsRequestBodyForQueue(raw) + opsErrorLogSanitized.Add(1) +} + +func setOpsSelectedAccount(c *gin.Context, accountID int64, platform ...string) { if c == nil || accountID <= 0 { return } c.Set(opsAccountIDKey, accountID) + if c.Request != nil { + ctx := context.WithValue(c.Request.Context(), ctxkey.AccountID, accountID) + if len(platform) > 0 { + p := strings.TrimSpace(platform[0]) + if p != "" { + ctx = context.WithValue(ctx, ctxkey.Platform, p) + } + } + c.Request = c.Request.WithContext(ctx) + } } type opsCaptureWriter struct { @@ -507,6 +543,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) if apiKey != nil { entry.APIKeyID = &apiKey.ID @@ -528,14 +565,9 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b 
- } - } // Store request headers/body only when an upstream error occurred to keep overhead minimal. entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) // Skip logging if a passthrough rule with skip_monitoring=true matched. if v, ok := c.Get(service.OpsSkipPassthroughKey); ok { @@ -544,7 +576,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { } } - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) return } @@ -632,6 +664,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) // Capture upstream error context set by gateway services (if present). // This does NOT affect the client response; it enriches Ops troubleshooting data. @@ -707,17 +740,12 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { entry.ClientIP = &clientIP } - var requestBody []byte - if v, ok := c.Get(opsRequestBodyKey); ok { - if b, ok := v.([]byte); ok && len(b) > 0 { - requestBody = b - } - } // Persist only a minimal, whitelisted set of request headers to improve retry fidelity. // Do NOT store Authorization/Cookie/etc. 
entry.RequestHeadersJSON = extractOpsRetryRequestHeaders(c) + attachOpsRequestBodyToEntry(c, entry) - enqueueOpsErrorLog(ops, entry, requestBody) + enqueueOpsErrorLog(ops, entry) } } @@ -760,6 +788,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string { return &s } +func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey) + entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey) + entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey) + entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey) + entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey) +} + +func getContextLatencyMs(c *gin.Context, key string) *int64 { + if c == nil || strings.TrimSpace(key) == "" { + return nil + } + v, ok := c.Get(key) + if !ok { + return nil + } + var ms int64 + switch t := v.(type) { + case int: + ms = int64(t) + case int32: + ms = int64(t) + case int64: + ms = t + case float64: + ms = int64(t) + default: + return nil + } + if ms < 0 { + return nil + } + return &ms +} + type parsedOpsError struct { ErrorType string Message string diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go new file mode 100644 index 00000000..a11fa1f2 --- /dev/null +++ b/backend/internal/handler/ops_error_logger_test.go @@ -0,0 +1,175 @@ +package handler + +import ( + "net/http" + "net/http/httptest" + "sync" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func resetOpsErrorLoggerStateForTest(t *testing.T) { + t.Helper() + + opsErrorLogMu.Lock() + ch := opsErrorLogQueue + opsErrorLogQueue = nil + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + + if ch != nil { + close(ch) + } + 
opsErrorLogWorkersWg.Wait() + + opsErrorLogOnce = sync.Once{} + opsErrorLogStopOnce = sync.Once{} + opsErrorLogWorkersWg = sync.WaitGroup{} + opsErrorLogMu = sync.RWMutex{} + opsErrorLogStopping = false + + opsErrorLogQueueLen.Store(0) + opsErrorLogEnqueued.Store(0) + opsErrorLogDropped.Store(0) + opsErrorLogProcessed.Store(0) + opsErrorLogSanitized.Store(0) + opsErrorLogLastDropLogAt.Store(0) + + opsErrorLogShutdownCh = make(chan struct{}) + opsErrorLogShutdownOnce = sync.Once{} + opsErrorLogDrained.Store(false) +} + +func TestAttachOpsRequestBodyToEntry_SanitizeAndTrim(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte(`{"access_token":"secret-token","messages":[{"role":"user","content":"hello"}]}`) + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.NotNil(t, entry.RequestBodyJSON) + require.NotContains(t, *entry.RequestBodyJSON, "secret-token") + require.Contains(t, *entry.RequestBodyJSON, "[REDACTED]") + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestAttachOpsRequestBodyToEntry_InvalidJSONKeepsSize(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + raw := []byte("not-json") + setOpsRequestContext(c, "claude-3", false, raw) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(c, entry) + + require.Nil(t, entry.RequestBodyJSON) + require.NotNil(t, entry.RequestBodyBytes) + require.Equal(t, len(raw), *entry.RequestBodyBytes) + require.False(t, 
entry.RequestBodyTruncated) + require.Equal(t, int64(1), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_QueueFullDrop(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + // 禁止 enqueueOpsErrorLog 触发 workers,使用测试队列验证满队列降级。 + opsErrorLogOnce.Do(func() {}) + + opsErrorLogMu.Lock() + opsErrorLogQueue = make(chan opsErrorLogJob, 1) + opsErrorLogMu.Unlock() + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + enqueueOpsErrorLog(ops, entry) + enqueueOpsErrorLog(ops, entry) + + require.Equal(t, int64(1), OpsErrorLogEnqueuedTotal()) + require.Equal(t, int64(1), OpsErrorLogDroppedTotal()) + require.Equal(t, int64(1), OpsErrorLogQueueLength()) +} + +func TestAttachOpsRequestBodyToEntry_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + gin.SetMode(gin.TestMode) + + entry := &service.OpsInsertErrorLogInput{} + attachOpsRequestBodyToEntry(nil, entry) + attachOpsRequestBodyToEntry(&gin.Context{}, nil) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + // 无请求体 key + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + require.False(t, entry.RequestBodyTruncated) + + // 错误类型 + c.Set(opsRequestBodyKey, "not-bytes") + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + // 空 bytes + c.Set(opsRequestBodyKey, []byte{}) + attachOpsRequestBodyToEntry(c, entry) + require.Nil(t, entry.RequestBodyJSON) + require.Nil(t, entry.RequestBodyBytes) + + require.Equal(t, int64(0), OpsErrorLogSanitizedTotal()) +} + +func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { + resetOpsErrorLoggerStateForTest(t) + + ops := service.NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, 
nil, nil) + entry := &service.OpsInsertErrorLogInput{ErrorPhase: "upstream", ErrorType: "upstream_error"} + + // nil 入参分支 + enqueueOpsErrorLog(nil, entry) + enqueueOpsErrorLog(ops, nil) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // shutdown 分支 + close(opsErrorLogShutdownCh) + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // stopping 分支 + resetOpsErrorLoggerStateForTest(t) + opsErrorLogMu.Lock() + opsErrorLogStopping = true + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) + + // queue nil 分支(防止启动 worker 干扰) + resetOpsErrorLoggerStateForTest(t) + opsErrorLogOnce.Do(func() {}) + opsErrorLogMu.Lock() + opsErrorLogQueue = nil + opsErrorLogMu.Unlock() + enqueueOpsErrorLog(ops, entry) + require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go new file mode 100644 index 00000000..ab3a3f14 --- /dev/null +++ b/backend/internal/handler/sora_gateway_handler.go @@ -0,0 +1,677 @@ +package handler + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// SoraGatewayHandler handles Sora chat completions requests +type SoraGatewayHandler struct { + gatewayService *service.GatewayService + soraGatewayService *service.SoraGatewayService + billingCacheService *service.BillingCacheService 
+ usageRecordWorkerPool *service.UsageRecordWorkerPool + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int + streamMode string + soraTLSEnabled bool + soraMediaSigningKey string + soraMediaRoot string +} + +// NewSoraGatewayHandler creates a new SoraGatewayHandler +func NewSoraGatewayHandler( + gatewayService *service.GatewayService, + soraGatewayService *service.SoraGatewayService, + concurrencyService *service.ConcurrencyService, + billingCacheService *service.BillingCacheService, + usageRecordWorkerPool *service.UsageRecordWorkerPool, + cfg *config.Config, +) *SoraGatewayHandler { + pingInterval := time.Duration(0) + maxAccountSwitches := 3 + streamMode := "force" + soraTLSEnabled := true + signKey := "" + mediaRoot := "/app/data/sora" + if cfg != nil { + pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } + if mode := strings.TrimSpace(cfg.Gateway.SoraStreamMode); mode != "" { + streamMode = mode + } + soraTLSEnabled = !cfg.Sora.Client.DisableTLSFingerprint + signKey = strings.TrimSpace(cfg.Gateway.SoraMediaSigningKey) + if root := strings.TrimSpace(cfg.Sora.Storage.LocalPath); root != "" { + mediaRoot = root + } + } + return &SoraGatewayHandler{ + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + billingCacheService: billingCacheService, + usageRecordWorkerPool: usageRecordWorkerPool, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, + streamMode: strings.ToLower(streamMode), + soraTLSEnabled: soraTLSEnabled, + soraMediaSigningKey: signKey, + soraMediaRoot: mediaRoot, + } +} + +// ChatCompletions handles Sora /v1/chat/completions endpoint +func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, 
"authentication_error", "Invalid API key") + return + } + + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + reqLog := requestLogger( + c, + "handler.sora_gateway.chat_completions", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + ) + + body, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + setOpsRequestContext(c, "", false, body) + + // 校验请求体 JSON 合法性 + if !gjson.ValidBytes(body) { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal + modelResult := gjson.GetBytes(body, "model") + if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") + return + } + reqModel := modelResult.String() + + msgsResult := gjson.GetBytes(body, "messages") + if !msgsResult.IsArray() || len(msgsResult.Array()) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required") + return + } + + clientStream := gjson.GetBytes(body, "stream").Bool() + reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", clientStream)) + if !clientStream { + if h.streamMode == "error" { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true") + return + } + var err error + body, 
err = sjson.SetBytes(body, "stream", true) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + } + + setOpsRequestContext(c, reqModel, clientStream, body) + + platform := "" + if forced, ok := middleware2.GetForcePlatformFromContext(c); ok { + platform = forced + } else if apiKey.Group != nil { + platform = apiKey.Group.Platform + } + if platform != service.PlatformSora { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "This endpoint only supports Sora platform") + return + } + + streamStarted := false + subscription, _ := middleware2.GetSubscriptionFromContext(c) + + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + waitCounted := false + if err != nil { + reqLog.Warn("sora.user_wait_counter_increment_failed", zap.Error(err)) + } else if !canWait { + reqLog.Info("sora.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if err == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, clientStream, &streamStarted) + if err != nil { + reqLog.Warn("sora.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + waitCounted = false + } + userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) + if userReleaseFunc != nil { + defer userReleaseFunc() + } + + if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, 
apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("sora.billing_eligibility_check_failed", zap.Error(err)) + status, code, message := billingErrorDetails(err) + h.handleStreamingAwareError(c, status, code, message, streamStarted) + return + } + + sessionHash := generateOpenAISessionHash(c, body) + + maxAccountSwitches := h.maxAccountSwitches + switchCount := 0 + failedAccountIDs := make(map[int64]struct{}) + lastFailoverStatus := 0 + var lastFailoverBody []byte + var lastFailoverHeaders http.Header + + for { + selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs, "") + if err != nil { + reqLog.Warn("sora.account_select_failed", + zap.Error(err), + zap.Int("excluded_account_count", len(failedAccountIDs)), + ) + if len(failedAccountIDs) == 0 { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) + return + } + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int("last_upstream_status", lastFailoverStatus), + } + if rayID != "" { + fields = append(fields, zap.String("last_upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("last_upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("last_upstream_content_type", contentType)) + } + reqLog.Warn("sora.failover_exhausted_no_available_accounts", fields...) 
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + account := selection.Account + setOpsSelectedAccount(c, account.ID, account.Platform) + proxyBound := account.ProxyID != nil + proxyID := int64(0) + if account.ProxyID != nil { + proxyID = *account.ProxyID + } + tlsFingerprintEnabled := h.soraTLSEnabled + + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + reqLog.Warn("sora.account_wait_counter_increment_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + } else if !canWait { + reqLog.Info("sora.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + defer func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + } + }() + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + clientStream, + &streamStarted, + ) + if err != nil { + reqLog.Warn("sora.account_slot_acquire_failed", + 
zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) + + result, err := h.soraGatewayService.Forward(c.Request.Context(), c, account, body, clientStream) + if accountReleaseFunc != nil { + accountReleaseFunc() + } + if err != nil { + var failoverErr *service.UpstreamFailoverError + if errors.As(err, &failoverErr) { + failedAccountIDs[account.ID] = struct{}{} + if switchCount >= maxAccountSwitches { + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_exhausted", fields...) 
+ h.handleFailoverExhausted(c, lastFailoverStatus, lastFailoverHeaders, lastFailoverBody, streamStarted) + return + } + lastFailoverStatus = failoverErr.StatusCode + lastFailoverHeaders = cloneHTTPHeaders(failoverErr.ResponseHeaders) + lastFailoverBody = failoverErr.ResponseBody + switchCount++ + upstreamErrCode, upstreamErrMsg := extractUpstreamErrorCodeAndMessage(lastFailoverBody) + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(lastFailoverHeaders, lastFailoverBody) + fields := []zap.Field{ + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.String("upstream_error_code", upstreamErrCode), + zap.String("upstream_error_message", upstreamErrMsg), + zap.Int("switch_count", switchCount), + zap.Int("max_switches", maxAccountSwitches), + } + if rayID != "" { + fields = append(fields, zap.String("upstream_cf_ray", rayID)) + } + if mitigated != "" { + fields = append(fields, zap.String("upstream_cf_mitigated", mitigated)) + } + if contentType != "" { + fields = append(fields, zap.String("upstream_content_type", contentType)) + } + reqLog.Warn("sora.upstream_failover_switching", fields...) 
+ continue + } + reqLog.Error("sora.forward_failed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Error(err), + ) + return + } + + userAgent := c.GetHeader("User-Agent") + clientIP := ip.GetClientIP(c) + + // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 + h.submitUsageRecordTask(func(ctx context.Context) { + if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + }); err != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.String("model", reqModel), + zap.Int64("account_id", account.ID), + ).Error("sora.record_usage_failed", zap.Error(err)) + } + }) + reqLog.Debug("sora.request_completed", + zap.Int64("account_id", account.ID), + zap.Int64("proxy_id", proxyID), + zap.Bool("proxy_bound", proxyBound), + zap.Bool("tls_fingerprint_enabled", tlsFingerprintEnabled), + zap.Int("switch_count", switchCount), + ) + return + } +} + +func generateOpenAISessionHash(c *gin.Context, body []byte) string { + if c == nil { + return "" + } + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + } + if sessionID == "" { + return "" + } + hash := sha256.Sum256([]byte(sessionID)) + return hex.EncodeToString(hash[:]) +} + +func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { + if task == nil { + return + } + if h.usageRecordWorkerPool != nil { + 
h.usageRecordWorkerPool.Submit(task) + return + } + // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + task(ctx) +} + +func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", + fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) +} + +func (h *SoraGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, responseHeaders http.Header, responseBody []byte, streamStarted bool) { + status, errType, errMsg := h.mapUpstreamError(statusCode, responseHeaders, responseBody) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +func (h *SoraGatewayHandler) mapUpstreamError(statusCode int, responseHeaders http.Header, responseBody []byte) (int, string, string) { + if isSoraCloudflareChallengeResponse(statusCode, responseHeaders, responseBody) { + baseMsg := fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). Please switch to a clean proxy/network and retry.", statusCode) + return http.StatusBadGateway, "upstream_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + + upstreamCode, upstreamMessage := extractUpstreamErrorCodeAndMessage(responseBody) + if strings.EqualFold(upstreamCode, "cf_shield_429") { + baseMsg := "Sora request blocked by Cloudflare shield (429). Please switch to a clean proxy/network and retry." 
+ return http.StatusTooManyRequests, "rate_limit_error", formatSoraCloudflareChallengeMessage(baseMsg, responseHeaders, responseBody) + } + if shouldPassthroughSoraUpstreamMessage(statusCode, upstreamMessage) { + switch statusCode { + case 401, 403, 404, 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", upstreamMessage + case 429: + return http.StatusTooManyRequests, "rate_limit_error", upstreamMessage + } + } + + switch statusCode { + case 401: + return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator" + case 403: + return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator" + case 404: + if strings.EqualFold(upstreamCode, "unsupported_country_code") { + return http.StatusBadGateway, "upstream_error", "Upstream region capability unavailable for this account, please contact administrator" + } + return http.StatusBadGateway, "upstream_error", "Upstream capability unavailable for this account, please contact administrator" + case 429: + return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later" + case 529: + return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later" + case 500, 502, 503, 504: + return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable" + default: + return http.StatusBadGateway, "upstream_error", "Upstream request failed" + } +} + +func cloneHTTPHeaders(headers http.Header) http.Header { + if headers == nil { + return nil + } + return headers.Clone() +} + +func extractSoraFailoverHeaderInsights(headers http.Header, body []byte) (rayID, mitigated, contentType string) { + if headers != nil { + mitigated = strings.TrimSpace(headers.Get("cf-mitigated")) + contentType = strings.TrimSpace(headers.Get("content-type")) + if contentType == "" { + contentType = strings.TrimSpace(headers.Get("Content-Type")) + } + } + 
rayID = soraerror.ExtractCloudflareRayID(headers, body) + return rayID, mitigated, contentType +} + +func isSoraCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func shouldPassthroughSoraUpstreamMessage(statusCode int, message string) bool { + message = strings.TrimSpace(message) + if message == "" { + return false + } + if statusCode == http.StatusForbidden || statusCode == http.StatusTooManyRequests { + lower := strings.ToLower(message) + if strings.Contains(lower, "Just a moment...`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusForbidden, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare challenge") + require.Contains(t, msg, "cf-ray: 9d01b0e9ecc35829-SEA") +} + +func TestSoraHandleFailoverExhausted_CfShield429MappedToRateLimitError(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + headers := http.Header{} + headers.Set("cf-ray", "9d03b68c086027a1-SEA") + body := []byte(`{"error":{"code":"cf_shield_429","message":"shield blocked"}}`) + + h := &SoraGatewayHandler{} + h.handleFailoverExhausted(c, http.StatusTooManyRequests, headers, body, true) + + lines := strings.Split(strings.TrimSuffix(w.Body.String(), "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, 
json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "rate_limit_error", errorObj["type"]) + msg, _ := errorObj["message"].(string) + require.Contains(t, msg, "Cloudflare shield") + require.Contains(t, msg, "cf-ray: 9d03b68c086027a1-SEA") +} + +func TestExtractSoraFailoverHeaderInsights(t *testing.T) { + headers := http.Header{} + headers.Set("cf-mitigated", "challenge") + headers.Set("content-type", "text/html") + body := []byte(``) + + rayID, mitigated, contentType := extractSoraFailoverHeaderInsights(headers, body) + require.Equal(t, "9cff2d62d83bb98d", rayID) + require.Equal(t, "challenge", mitigated) + require.Equal(t, "text/html", contentType) +} diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index 129dbfa6..b8182dad 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -392,7 +392,7 @@ func (h *UsageHandler) DashboardAPIKeysUsage(c *gin.Context) { return } - stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs) + stats, err := h.usageService.GetBatchAPIKeyUsageStats(c.Request.Context(), validAPIKeyIDs, time.Time{}, time.Time{}) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go new file mode 100644 index 00000000..df759f44 --- /dev/null +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -0,0 +1,136 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func newUsageRecordTestPool(t *testing.T) *service.UsageRecordWorkerPool { + t.Helper() + pool := service.NewUsageRecordWorkerPoolWithOptions(service.UsageRecordWorkerPoolOptions{ + WorkerCount: 1, + QueueSize: 8, + 
TaskTimeout: time.Second, + OverflowPolicy: "drop", + OverflowSamplePercent: 0, + AutoScaleEnabled: false, + }) + t.Cleanup(pool.Stop) + return pool +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &GatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &GatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} + +func 
TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { + pool := newUsageRecordTestPool(t) + h := &SoraGatewayHandler{usageRecordWorkerPool: pool} + + done := make(chan struct{}) + h.submitUsageRecordTask(func(ctx context.Context) { + close(done) + }) + + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("task not executed") + } +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load()) +} + +func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { + h := &SoraGatewayHandler{} + require.NotPanics(t, func() { + h.submitUsageRecordTask(nil) + }) +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 7b62149c..79d583fd 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -53,8 +53,8 @@ func ProvideAdminHandlers( } // ProvideSystemHandler creates admin.SystemHandler with UpdateService -func ProvideSystemHandler(updateService *service.UpdateService) *admin.SystemHandler { - return admin.NewSystemHandler(updateService) +func ProvideSystemHandler(updateService *service.UpdateService, lockService *service.SystemOperationLockService) *admin.SystemHandler { + return admin.NewSystemHandler(updateService, lockService) } // ProvideSettingHandler creates SettingHandler with version from BuildInfo @@ -74,8 +74,11 @@ func ProvideHandlers( adminHandlers *AdminHandlers, gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, + soraGatewayHandler *SoraGatewayHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, + _ *service.IdempotencyCoordinator, + _ *service.IdempotencyCleanupService, ) *Handlers { return &Handlers{ Auth: authHandler, @@ 
-88,6 +91,7 @@ func ProvideHandlers( Admin: adminHandlers, Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, + SoraGateway: soraGatewayHandler, Setting: settingHandler, Totp: totpHandler, } @@ -105,6 +109,7 @@ var ProviderSet = wire.NewSet( NewAnnouncementHandler, NewGatewayHandler, NewOpenAIGatewayHandler, + NewSoraGatewayHandler, NewTotpHandler, ProvideSettingHandler, diff --git a/backend/internal/integration/e2e_gateway_test.go b/backend/internal/integration/e2e_gateway_test.go index ec0b29f7..8ee3f22e 100644 --- a/backend/internal/integration/e2e_gateway_test.go +++ b/backend/internal/integration/e2e_gateway_test.go @@ -21,11 +21,18 @@ var ( // - "" (默认): 使用 /v1/messages, /v1beta/models(混合模式,可调度 antigravity 账户) // - "/antigravity": 使用 /antigravity/v1/messages, /antigravity/v1beta/models(非混合模式,仅 antigravity 账户) endpointPrefix = getEnv("ENDPOINT_PREFIX", "") - claudeAPIKey = "sk-8e572bc3b3de92ace4f41f4256c28600ca11805732a7b693b5c44741346bbbb3" - geminiAPIKey = "sk-5950197a2085b38bbe5a1b229cc02b8ece914963fc44cacc06d497ae8b87410f" testInterval = 1 * time.Second // 测试间隔,防止限流 ) +const ( + // 注意:E2E 测试请使用环境变量注入密钥,避免任何凭证进入仓库历史。 + // 例如: + // export CLAUDE_API_KEY="sk-..." + // export GEMINI_API_KEY="sk-..." 
+ claudeAPIKeyEnv = "CLAUDE_API_KEY" + geminiAPIKeyEnv = "GEMINI_API_KEY" +) + func getEnv(key, defaultVal string) string { if v := os.Getenv(key); v != "" { return v @@ -65,16 +72,45 @@ func TestMain(m *testing.M) { if endpointPrefix != "" { mode = "Antigravity 模式" } - fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s)\n\n", baseURL, endpointPrefix, mode) + claudeKeySet := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) != "" + geminiKeySet := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) != "" + fmt.Printf("\n🚀 E2E Gateway Tests - %s (prefix=%q, %s, %s=%v, %s=%v)\n\n", + baseURL, + endpointPrefix, + mode, + claudeAPIKeyEnv, + claudeKeySet, + geminiAPIKeyEnv, + geminiKeySet, + ) os.Exit(m.Run()) } +func requireClaudeAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Claude 相关 E2E 测试", claudeAPIKeyEnv) + } + return key +} + +func requireGeminiAPIKey(t *testing.T) string { + t.Helper() + key := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if key == "" { + t.Skipf("未设置 %s,跳过 Gemini 相关 E2E 测试", geminiAPIKeyEnv) + } + return key +} + // TestClaudeModelsList 测试 GET /v1/models func TestClaudeModelsList(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) url := baseURL + endpointPrefix + "/v1/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) @@ -106,10 +142,11 @@ func TestClaudeModelsList(t *testing.T) { // TestGeminiModelsList 测试 GET /v1beta/models func TestGeminiModelsList(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) url := baseURL + endpointPrefix + "/v1beta/models" req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 30 * time.Second} resp, err := 
client.Do(req) @@ -137,21 +174,22 @@ func TestGeminiModelsList(t *testing.T) { // TestClaudeMessages 测试 Claude /v1/messages 接口 func TestClaudeMessages(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) for i, model := range claudeModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } -func testClaudeMessage(t *testing.T, model string, stream bool) { +func testClaudeMessage(t *testing.T, claudeKey string, model string, stream bool) { url := baseURL + endpointPrefix + "/v1/messages" payload := map[string]any{ @@ -166,7 +204,7 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -213,21 +251,22 @@ func testClaudeMessage(t *testing.T, model string, stream bool) { // TestGeminiGenerateContent 测试 Gemini /v1beta/models/:model 接口 func TestGeminiGenerateContent(t *testing.T) { + geminiKey := requireGeminiAPIKey(t) for i, model := range geminiModels { if i > 0 { time.Sleep(testInterval) } t.Run(model+"_非流式", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } -func testGeminiGenerate(t *testing.T, model string, stream bool) { +func testGeminiGenerate(t *testing.T, geminiKey string, model string, stream bool) { action := "generateContent" if stream { 
action = "streamGenerateContent" @@ -254,7 +293,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+geminiAPIKey) + req.Header.Set("Authorization", "Bearer "+geminiKey) client := &http.Client{Timeout: 60 * time.Second} resp, err := client.Do(req) @@ -301,6 +340,7 @@ func testGeminiGenerate(t *testing.T, model string, stream bool) { // TestClaudeMessagesWithComplexTools 测试带复杂工具 schema 的请求 // 模拟 Claude Code 发送的请求,包含需要清理的 JSON Schema 字段 func TestClaudeMessagesWithComplexTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) // 测试模型列表(只测试几个代表性模型) models := []string{ "claude-opus-4-5-20251101", // Claude 模型 @@ -312,12 +352,12 @@ func TestClaudeMessagesWithComplexTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_复杂工具", func(t *testing.T) { - testClaudeMessageWithTools(t, model) + testClaudeMessageWithTools(t, claudeKey, model) }) } } -func testClaudeMessageWithTools(t *testing.T, model string) { +func testClaudeMessageWithTools(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 构造包含复杂 schema 的工具定义(模拟 Claude Code 的工具) @@ -473,7 +513,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -519,6 +559,7 @@ func testClaudeMessageWithTools(t *testing.T, model string) { // 验证:当历史 assistant 消息包含 tool_use 但没有 signature 时, // 系统应自动添加 dummy thought_signature 避免 Gemini 400 错误 func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ 
"claude-haiku-4-5-20251001", // gemini-3-flash } @@ -527,12 +568,12 @@ func TestClaudeMessagesWithThinkingAndTools(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_thinking模式工具调用", func(t *testing.T) { - testClaudeThinkingWithToolHistory(t, model) + testClaudeThinkingWithToolHistory(t, claudeKey, model) }) } } -func testClaudeThinkingWithToolHistory(t *testing.T, model string) { +func testClaudeThinkingWithToolHistory(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话:用户请求 → assistant 调用工具 → 工具返回 → 继续对话 @@ -600,7 +641,7 @@ func testClaudeThinkingWithToolHistory(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -649,6 +690,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + claudeKey := requireClaudeAPIKey(t) // 测试通过 Claude 端点调用 Gemini 模型 geminiViaClaude := []string{ @@ -664,11 +706,11 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Claude端点", func(t *testing.T) { - testClaudeMessage(t, model, false) + testClaudeMessage(t, claudeKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Claude端点_流式", func(t *testing.T) { - testClaudeMessage(t, model, true) + testClaudeMessage(t, claudeKey, model, true) }) } } @@ -676,6 +718,7 @@ func TestClaudeMessagesWithGeminiModel(t *testing.T) { // TestClaudeMessagesWithNoSignature 测试历史 thinking block 不带 signature 的场景 // 验证:Gemini 模型接受没有 signature 的 thinking block func TestClaudeMessagesWithNoSignature(t *testing.T) { + claudeKey := requireClaudeAPIKey(t) models := []string{ "claude-haiku-4-5-20251001", // gemini-3-flash - 支持无 
signature } @@ -684,12 +727,12 @@ func TestClaudeMessagesWithNoSignature(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_无signature", func(t *testing.T) { - testClaudeWithNoSignature(t, model) + testClaudeWithNoSignature(t, claudeKey, model) }) } } -func testClaudeWithNoSignature(t *testing.T, model string) { +func testClaudeWithNoSignature(t *testing.T, claudeKey string, model string) { url := baseURL + endpointPrefix + "/v1/messages" // 模拟历史对话包含 thinking block 但没有 signature @@ -732,7 +775,7 @@ func testClaudeWithNoSignature(t *testing.T, model string) { req, _ := http.NewRequest("POST", url, bytes.NewReader(body)) req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+claudeAPIKey) + req.Header.Set("Authorization", "Bearer "+claudeKey) req.Header.Set("anthropic-version", "2023-06-01") client := &http.Client{Timeout: 60 * time.Second} @@ -777,6 +820,7 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { if endpointPrefix != "/antigravity" { t.Skip("仅在 Antigravity 模式下运行") } + geminiKey := requireGeminiAPIKey(t) // 测试通过 Gemini 端点调用 Claude 模型 claudeViaGemini := []string{ @@ -789,11 +833,11 @@ func TestGeminiEndpointWithClaudeModel(t *testing.T) { time.Sleep(testInterval) } t.Run(model+"_通过Gemini端点", func(t *testing.T) { - testGeminiGenerate(t, model, false) + testGeminiGenerate(t, geminiKey, model, false) }) time.Sleep(testInterval) t.Run(model+"_通过Gemini端点_流式", func(t *testing.T) { - testGeminiGenerate(t, model, true) + testGeminiGenerate(t, geminiKey, model, true) }) } } diff --git a/backend/internal/integration/e2e_helpers_test.go b/backend/internal/integration/e2e_helpers_test.go new file mode 100644 index 00000000..7d266bcb --- /dev/null +++ b/backend/internal/integration/e2e_helpers_test.go @@ -0,0 +1,48 @@ +//go:build e2e + +package integration + +import ( + "os" + "strings" + "testing" +) + +// ============================================================================= +// E2E Mock 模式支持 +// 
============================================================================= +// 当 E2E_MOCK=true 时,使用本地 Mock 响应替代真实 API 调用。 +// 这允许在没有真实 API Key 的环境(如 CI)中验证基本的请求/响应流程。 + +// isMockMode 检查是否启用 Mock 模式 +func isMockMode() bool { + return strings.EqualFold(os.Getenv("E2E_MOCK"), "true") +} + +// skipIfNoRealAPI 如果未配置真实 API Key 且不在 Mock 模式,则跳过测试 +func skipIfNoRealAPI(t *testing.T) { + t.Helper() + if isMockMode() { + return // Mock 模式下不跳过 + } + claudeKey := strings.TrimSpace(os.Getenv(claudeAPIKeyEnv)) + geminiKey := strings.TrimSpace(os.Getenv(geminiAPIKeyEnv)) + if claudeKey == "" && geminiKey == "" { + t.Skip("未设置 API Key 且未启用 Mock 模式,跳过测试") + } +} + +// ============================================================================= +// API Key 脱敏(Task 6.10) +// ============================================================================= + +// safeLogKey 安全地记录 API Key(仅显示前 8 位) +func safeLogKey(t *testing.T, prefix string, key string) { + t.Helper() + key = strings.TrimSpace(key) + if len(key) <= 8 { + t.Logf("%s: ***(长度: %d)", prefix, len(key)) + return + } + t.Logf("%s: %s...(长度: %d)", prefix, key[:8], len(key)) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go new file mode 100644 index 00000000..5489d0a3 --- /dev/null +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -0,0 +1,317 @@ +//go:build e2e + +package integration + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" +) + +// E2E 用户流程测试 +// 测试完整的用户操作链路:注册 → 登录 → 创建 API Key → 调用网关 → 查询用量 + +var ( + testUserEmail = "e2e-test-" + fmt.Sprintf("%d", time.Now().UnixMilli()) + "@test.local" + testUserPassword = "E2eTest@12345" + testUserName = "e2e-test-user" +) + +// TestUserRegistrationAndLogin 测试用户注册和登录流程 +func TestUserRegistrationAndLogin(t *testing.T) { + // 步骤 1: 注册新用户 + t.Run("注册新用户", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": 
testUserPassword, + "username": testUserName, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + if err != nil { + t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) + return + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 注册可能返回 200(成功)或 400(邮箱已存在)或 403(注册已关闭) + switch resp.StatusCode { + case 200: + t.Logf("✅ 用户注册成功: %s", testUserEmail) + case 400: + t.Logf("⚠️ 用户可能已存在: %s", string(respBody)) + case 403: + t.Skipf("注册功能已关闭: %s", string(respBody)) + default: + t.Logf("⚠️ 注册返回 HTTP %d: %s(继续尝试登录)", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 2: 登录获取 JWT + var accessToken string + t.Run("用户登录获取JWT", func(t *testing.T) { + payload := map[string]string{ + "email": testUserEmail, + "password": testUserPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + t.Fatalf("登录请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("登录失败 HTTP %d: %s(可能需要先注册用户)", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析登录响应失败: %v", err) + } + + // 尝试从标准响应格式获取 token + if token, ok := result["access_token"].(string); ok && token != "" { + accessToken = token + } else if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + accessToken = token + } + } + + if accessToken == "" { + t.Skipf("未获取到 access_token,响应: %s", string(respBody)) + return + } + + // 验证 token 不为空且格式基本正确 + if len(accessToken) < 10 { + t.Fatalf("access_token 格式异常: %s", accessToken) + } + + t.Logf("✅ 登录成功,获取 JWT(长度: %d)", len(accessToken)) + }) + + if accessToken == "" { + t.Skip("未获取到 JWT,跳过后续测试") + return + } + + // 步骤 3: 使用 JWT 获取当前用户信息 + t.Run("获取当前用户信息", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/user/me", nil, 
accessToken) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Fatalf("HTTP %d: %s", resp.StatusCode, string(body)) + } + + t.Logf("✅ 成功获取用户信息") + }) +} + +// TestAPIKeyLifecycle 测试 API Key 的创建和使用 +func TestAPIKeyLifecycle(t *testing.T) { + // 先登录获取 JWT + accessToken := loginTestUser(t) + if accessToken == "" { + t.Skip("无法登录,跳过 API Key 生命周期测试") + return + } + + var apiKey string + + // 步骤 1: 创建 API Key + t.Run("创建API_Key", func(t *testing.T) { + payload := map[string]string{ + "name": "e2e-test-key-" + fmt.Sprintf("%d", time.Now().UnixMilli()), + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + if err != nil { + t.Fatalf("创建 API Key 请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != 200 { + t.Skipf("创建 API Key 失败 HTTP %d: %s", resp.StatusCode, string(respBody)) + return + } + + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + t.Fatalf("解析响应失败: %v", err) + } + + // 从响应中提取 key + if key, ok := result["key"].(string); ok { + apiKey = key + } else if data, ok := result["data"].(map[string]any); ok { + if key, ok := data["key"].(string); ok { + apiKey = key + } + } + + if apiKey == "" { + t.Skipf("未获取到 API Key,响应: %s", string(respBody)) + return + } + + // 验证 API Key 脱敏日志(只显示前 8 位) + masked := apiKey + if len(masked) > 8 { + masked = masked[:8] + "..." 
+ } + t.Logf("✅ API Key 创建成功: %s", masked) + }) + + if apiKey == "" { + t.Skip("未创建 API Key,跳过后续测试") + return + } + + // 步骤 2: 使用 API Key 调用网关(需要 Claude 或 Gemini 可用) + t.Run("使用API_Key调用网关", func(t *testing.T) { + // 尝试调用 models 列表(最轻量的 API 调用) + resp, err := doRequest(t, "GET", "/v1/models", nil, apiKey) + if err != nil { + t.Fatalf("网关请求失败: %v", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + // 可能返回 200(成功)或 402(余额不足)或 403(无可用账户) + switch { + case resp.StatusCode == 200: + t.Logf("✅ API Key 网关调用成功") + case resp.StatusCode == 402: + t.Logf("⚠️ 余额不足,但 API Key 认证通过") + case resp.StatusCode == 403: + t.Logf("⚠️ 无可用账户,但 API Key 认证通过") + default: + t.Logf("⚠️ 网关返回 HTTP %d: %s", resp.StatusCode, string(respBody)) + } + }) + + // 步骤 3: 查询用量记录 + t.Run("查询用量记录", func(t *testing.T) { + resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + if err != nil { + t.Fatalf("用量查询请求失败: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + body, _ := io.ReadAll(resp.Body) + t.Logf("⚠️ 用量查询返回 HTTP %d: %s", resp.StatusCode, string(body)) + return + } + + t.Logf("✅ 用量查询成功") + }) +} + +// ============================================================================= +// 辅助函数 +// ============================================================================= + +func doRequest(t *testing.T, method, path string, body []byte, token string) (*http.Response, error) { + t.Helper() + + url := baseURL + path + var bodyReader io.Reader + if body != nil { + bodyReader = bytes.NewReader(body) + } + + req, err := http.NewRequest(method, url, bodyReader) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + if token != "" { + req.Header.Set("Authorization", "Bearer "+token) + } + + client := &http.Client{Timeout: 30 * time.Second} + return client.Do(req) +} + +func loginTestUser(t *testing.T) string { + t.Helper() + + // 先尝试用管理员账户登录 + 
adminEmail := getEnv("ADMIN_EMAIL", "admin@sub2api.local") + adminPassword := getEnv("ADMIN_PASSWORD", "") + + if adminPassword == "" { + // 尝试用测试用户 + adminEmail = testUserEmail + adminPassword = testUserPassword + } + + payload := map[string]string{ + "email": adminEmail, + "password": adminPassword, + } + body, _ := json.Marshal(payload) + + resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return "" + } + + respBody, _ := io.ReadAll(resp.Body) + var result map[string]any + if err := json.Unmarshal(respBody, &result); err != nil { + return "" + } + + if token, ok := result["access_token"].(string); ok { + return token + } + if data, ok := result["data"].(map[string]any); ok { + if token, ok := data["access_token"].(string); ok { + return token + } + } + + return "" +} + +// redactAPIKey API Key 脱敏,只显示前 8 位 +func redactAPIKey(key string) string { + key = strings.TrimSpace(key) + if len(key) <= 8 { + return "***" + } + return key[:8] + "..." 
+} diff --git a/backend/internal/middleware/rate_limiter_test.go b/backend/internal/middleware/rate_limiter_test.go index 0c379c0f..e362274f 100644 --- a/backend/internal/middleware/rate_limiter_test.go +++ b/backend/internal/middleware/rate_limiter_test.go @@ -60,6 +60,49 @@ func TestRateLimiterFailureModes(t *testing.T) { require.Equal(t, http.StatusTooManyRequests, recorder.Code) } +func TestRateLimiterDifferentIPsIndependent(t *testing.T) { + gin.SetMode(gin.TestMode) + + callCounts := make(map[string]int64) + originalRun := rateLimitRun + rateLimitRun = func(ctx context.Context, client *redis.Client, key string, windowMillis int64) (int64, bool, error) { + callCounts[key]++ + return callCounts[key], false, nil + } + t.Cleanup(func() { + rateLimitRun = originalRun + }) + + limiter := NewRateLimiter(redis.NewClient(&redis.Options{Addr: "127.0.0.1:1"})) + + router := gin.New() + router.Use(limiter.Limit("api", 1, time.Second)) + router.GET("/test", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + // 第一个 IP 的请求应通过 + req1 := httptest.NewRequest(http.MethodGet, "/test", nil) + req1.RemoteAddr = "10.0.0.1:1234" + rec1 := httptest.NewRecorder() + router.ServeHTTP(rec1, req1) + require.Equal(t, http.StatusOK, rec1.Code, "第一个 IP 的第一次请求应通过") + + // 第二个 IP 的请求应独立通过(不受第一个 IP 的计数影响) + req2 := httptest.NewRequest(http.MethodGet, "/test", nil) + req2.RemoteAddr = "10.0.0.2:5678" + rec2 := httptest.NewRecorder() + router.ServeHTTP(rec2, req2) + require.Equal(t, http.StatusOK, rec2.Code, "第二个 IP 的第一次请求应独立通过") + + // 第一个 IP 的第二次请求应被限流 + req3 := httptest.NewRequest(http.MethodGet, "/test", nil) + req3.RemoteAddr = "10.0.0.1:1234" + rec3 := httptest.NewRecorder() + router.ServeHTTP(rec3, req3) + require.Equal(t, http.StatusTooManyRequests, rec3.Code, "第一个 IP 的第二次请求应被限流") +} + func TestRateLimiterSuccessAndLimit(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go 
index f1aa744a..1998221a 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -204,9 +204,14 @@ func shouldFallbackToNextURL(err error, statusCode int) bool { // ExchangeCode 用 authorization code 交换 token func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("code", code) params.Set("redirect_uri", RedirectURI) params.Set("grant_type", "authorization_code") @@ -243,9 +248,14 @@ func (c *Client) ExchangeCode(ctx context.Context, code, codeVerifier string) (* // RefreshToken 刷新 access_token func (c *Client) RefreshToken(ctx context.Context, refreshToken string) (*TokenResponse, error) { + clientSecret, err := getClientSecret() + if err != nil { + return nil, err + } + params := url.Values{} params.Set("client_id", ClientID) - params.Set("client_secret", ClientSecret) + params.Set("client_secret", clientSecret) params.Set("refresh_token", refreshToken) params.Set("grant_type", "refresh_token") diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index ac30093d..7e8d3a2a 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -1,9 +1,1661 @@ +//go:build unit + package antigravity import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strings" "testing" + "time" ) +// --------------------------------------------------------------------------- +// NewAPIRequestWithURL +// --------------------------------------------------------------------------- + +func TestNewAPIRequestWithURL_普通请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + 
action := "generateContent" + token := "test-token" + body := []byte(`{"prompt":"hello"}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + // 验证 URL 不含 ?alt=sse + expectedURL := "https://example.com/v1internal:generateContent" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } + + // 验证请求方法 + if req.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", req.Method) + } + + // 验证 Headers + if ct := req.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if auth := req.Header.Get("Authorization"); auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ua := req.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s, want %s", ua, UserAgent) + } +} + +func TestNewAPIRequestWithURL_流式请求(t *testing.T) { + ctx := context.Background() + baseURL := "https://example.com" + action := "streamGenerateContent" + token := "tok" + body := []byte(`{}`) + + req, err := NewAPIRequestWithURL(ctx, baseURL, action, token, body) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expectedURL := "https://example.com/v1internal:streamGenerateContent?alt=sse" + if req.URL.String() != expectedURL { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expectedURL) + } +} + +func TestNewAPIRequestWithURL_空Body(t *testing.T) { + ctx := context.Background() + req, err := NewAPIRequestWithURL(ctx, "https://example.com", "test", "tok", nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + if req.Body == nil { + t.Error("Body 应该非 nil(bytes.NewReader(nil) 会返回空 reader)") + } +} + +// --------------------------------------------------------------------------- +// NewAPIRequest +// --------------------------------------------------------------------------- + +func TestNewAPIRequest_使用默认URL(t *testing.T) { + ctx := 
context.Background() + req, err := NewAPIRequest(ctx, "generateContent", "tok", []byte(`{}`)) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + + expected := BaseURL + "/v1internal:generateContent" + if req.URL.String() != expected { + t.Errorf("URL 不匹配: got %s, want %s", req.URL.String(), expected) + } +} + +// --------------------------------------------------------------------------- +// TierInfo.UnmarshalJSON +// --------------------------------------------------------------------------- + +func TestTierInfo_UnmarshalJSON_字符串格式(t *testing.T) { + data := []byte(`"free-tier"`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "free-tier" { + t.Errorf("ID 不匹配: got %s, want free-tier", tier.ID) + } + if tier.Name != "" { + t.Errorf("Name 应为空: got %s", tier.Name) + } +} + +func TestTierInfo_UnmarshalJSON_对象格式(t *testing.T) { + data := []byte(`{"id":"g1-pro-tier","name":"Pro","description":"Pro plan"}`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if tier.ID != "g1-pro-tier" { + t.Errorf("ID 不匹配: got %s, want g1-pro-tier", tier.ID) + } + if tier.Name != "Pro" { + t.Errorf("Name 不匹配: got %s, want Pro", tier.Name) + } + if tier.Description != "Pro plan" { + t.Errorf("Description 不匹配: got %s, want Pro plan", tier.Description) + } +} + +func TestTierInfo_UnmarshalJSON_null(t *testing.T) { + data := []byte(`null`) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空数据(t *testing.T) { + data := []byte(``) + var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空数据失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空数据场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_空格包裹null(t *testing.T) { + data := []byte(` null `) + 
var tier TierInfo + if err := tier.UnmarshalJSON(data); err != nil { + t.Fatalf("反序列化空格 null 失败: %v", err) + } + if tier.ID != "" { + t.Errorf("空格 null 场景下 ID 应为空: got %s", tier.ID) + } +} + +func TestTierInfo_UnmarshalJSON_通过JSON嵌套结构(t *testing.T) { + // 模拟 LoadCodeAssistResponse 中的嵌套反序列化 + jsonData := `{"currentTier":"free-tier","paidTier":{"id":"g1-ultra-tier","name":"Ultra"}}` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化嵌套结构失败: %v", err) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-ultra-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse.GetTier +// --------------------------------------------------------------------------- + +func TestGetTier_PaidTier优先(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: "g1-pro-tier"}, + } + if got := resp.GetTier(); got != "g1-pro-tier" { + t.Errorf("应返回 paidTier: got %s", got) + } +} + +func TestGetTier_回退到CurrentTier(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + } + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("应返回 currentTier: got %s", got) + } +} + +func TestGetTier_PaidTier为空ID(t *testing.T) { + resp := &LoadCodeAssistResponse{ + CurrentTier: &TierInfo{ID: "free-tier"}, + PaidTier: &TierInfo{ID: ""}, + } + // paidTier.ID 为空时应回退到 currentTier + if got := resp.GetTier(); got != "free-tier" { + t.Errorf("paidTier.ID 为空时应回退到 currentTier: got %s", got) + } +} + +func TestGetTier_两者都为nil(t *testing.T) { + resp := &LoadCodeAssistResponse{} + if got := resp.GetTier(); got != "" { + t.Errorf("两者都为 nil 时应返回空字符串: got %s", got) + } +} + +// 
--------------------------------------------------------------------------- +// NewClient +// --------------------------------------------------------------------------- + +func TestNewClient_无代理(t *testing.T) { + client := NewClient("") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient == nil { + t.Fatal("httpClient 为 nil") + } + if client.httpClient.Timeout != 30*time.Second { + t.Errorf("Timeout 不匹配: got %v, want 30s", client.httpClient.Timeout) + } + // 无代理时 Transport 应为 nil(使用默认) + if client.httpClient.Transport != nil { + t.Error("无代理时 Transport 应为 nil") + } +} + +func TestNewClient_有代理(t *testing.T) { + client := NewClient("http://proxy.example.com:8080") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + if client.httpClient.Transport == nil { + t.Fatal("有代理时 Transport 不应为 nil") + } +} + +func TestNewClient_空格代理(t *testing.T) { + client := NewClient(" ") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 空格代理应等同于无代理 + if client.httpClient.Transport != nil { + t.Error("空格代理 Transport 应为 nil") + } +} + +func TestNewClient_无效代理URL(t *testing.T) { + // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), + // 但 ://invalid 会导致解析错误 + client := NewClient("://invalid") + if client == nil { + t.Fatal("NewClient 返回 nil") + } + // 无效 URL 解析失败时,Transport 应保持 nil + if client.httpClient.Transport != nil { + t.Error("无效代理 URL 时 Transport 应为 nil") + } +} + +// --------------------------------------------------------------------------- +// isConnectionError +// --------------------------------------------------------------------------- + +func TestIsConnectionError_nil(t *testing.T) { + if isConnectionError(nil) { + t.Error("nil 错误不应判定为连接错误") + } +} + +func TestIsConnectionError_超时错误(t *testing.T) { + // 使用 net.OpError 包装超时 + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: &timeoutError{}, + } + if !isConnectionError(err) { + t.Error("超时错误应判定为连接错误") + } +} + +// timeoutError 实现 net.Error 接口用于测试 +type timeoutError struct{} + 
+func (e *timeoutError) Error() string { return "timeout" } +func (e *timeoutError) Timeout() bool { return true } +func (e *timeoutError) Temporary() bool { return true } + +func TestIsConnectionError_netOpError(t *testing.T) { + err := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + if !isConnectionError(err) { + t.Error("net.OpError 应判定为连接错误") + } +} + +func TestIsConnectionError_urlError(t *testing.T) { + err := &url.Error{ + Op: "Get", + URL: "https://example.com", + Err: fmt.Errorf("some error"), + } + if !isConnectionError(err) { + t.Error("url.Error 应判定为连接错误") + } +} + +func TestIsConnectionError_普通错误(t *testing.T) { + err := fmt.Errorf("some random error") + if isConnectionError(err) { + t.Error("普通错误不应判定为连接错误") + } +} + +func TestIsConnectionError_包装的netOpError(t *testing.T) { + inner := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: fmt.Errorf("connection refused"), + } + err := fmt.Errorf("wrapping: %w", inner) + if !isConnectionError(err) { + t.Error("被包装的 net.OpError 应判定为连接错误") + } +} + +// --------------------------------------------------------------------------- +// shouldFallbackToNextURL +// --------------------------------------------------------------------------- + +func TestShouldFallbackToNextURL_连接错误(t *testing.T) { + err := &net.OpError{Op: "dial", Net: "tcp", Err: fmt.Errorf("refused")} + if !shouldFallbackToNextURL(err, 0) { + t.Error("连接错误应触发 URL 降级") + } +} + +func TestShouldFallbackToNextURL_状态码(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"429 Too Many Requests", http.StatusTooManyRequests, true}, + {"408 Request Timeout", http.StatusRequestTimeout, true}, + {"404 Not Found", http.StatusNotFound, true}, + {"500 Internal Server Error", http.StatusInternalServerError, true}, + {"502 Bad Gateway", http.StatusBadGateway, true}, + {"503 Service Unavailable", http.StatusServiceUnavailable, true}, + {"200 OK", http.StatusOK, false}, + {"201 Created", 
http.StatusCreated, false}, + {"400 Bad Request", http.StatusBadRequest, false}, + {"401 Unauthorized", http.StatusUnauthorized, false}, + {"403 Forbidden", http.StatusForbidden, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldFallbackToNextURL(nil, tt.statusCode) + if got != tt.want { + t.Errorf("shouldFallbackToNextURL(nil, %d) = %v, want %v", tt.statusCode, got, tt.want) + } + }) + } +} + +func TestShouldFallbackToNextURL_无错误且200(t *testing.T) { + if shouldFallbackToNextURL(nil, http.StatusOK) { + t.Error("无错误且 200 不应触发 URL 降级") + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_成功(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求方法 + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + // 验证 Content-Type + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + // 验证请求体参数 + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "verifier123" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != "authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + + w.Header().Set("Content-Type", "application/json") + 
w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + RefreshToken: "refresh-tok", + }) + })) + defer server.Close() + + // 临时替换 TokenURL(该函数直接使用常量,需要我们通过构建自定义 client 来绕过) + // 由于 ExchangeCode 硬编码了 TokenURL,我们需要直接测试 HTTP client 的行为 + // 这里通过构造一个直接调用 mock server 的测试 + client := &Client{httpClient: server.Client()} + + // 由于 ExchangeCode 使用硬编码的 TokenURL,我们无法直接注入 mock server URL + // 需要使用 httptest 的 Transport 重定向 + originalTokenURL := TokenURL + // 我们改为直接构造请求来测试逻辑 + _ = originalTokenURL + _ = client + + // 改用直接构造请求测试 mock server 响应 + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + params.Set("code", "auth-code") + params.Set("redirect_uri", RedirectURI) + params.Set("grant_type", "authorization_code") + params.Set("code_verifier", "verifier123") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "refresh-tok" { + t.Errorf("RefreshToken 不匹配: got %s", tokenResp.RefreshToken) + } +} + +func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + client := NewClient("") + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + 
t.Fatal("缺少 client_secret 时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_服务器返回错误(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant"}`)) + })) + defer server.Close() + + // 直接测试 mock server 的错误响应 + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("状态码不匹配: got %d, want 400", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_MockServer(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "old-refresh-tok" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-tok", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + ctx := context.Background() + params := url.Values{} + params.Set("client_id", ClientID) + params.Set("client_secret", "test-secret") + 
params.Set("refresh_token", "old-refresh-tok") + params.Set("grant_type", "refresh_token") + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, strings.NewReader(params.Encode())) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + t.Fatalf("解码失败: %v", err) + } + if tokenResp.AccessToken != "new-access-tok" { + t.Errorf("AccessToken 不匹配: got %s", tokenResp.AccessToken) + } +} + +func TestClient_RefreshToken_无ClientSecret(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + client := NewClient("") + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("缺少 client_secret 时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo (使用 httptest) +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_成功(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "user@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/photo.jpg", + }) + })) + defer server.Close() + + // 直接通过 mock server 测试 GetUserInfo 的行为逻辑 + ctx := 
context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL, nil) + if err != nil { + t.Fatalf("创建请求失败: %v", err) + } + req.Header.Set("Authorization", "Bearer test-access-token") + + resp, err := server.Client().Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Fatalf("状态码不匹配: got %d", resp.StatusCode) + } + + var userInfo UserInfo + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + t.Fatalf("解码失败: %v", err) + } + if userInfo.Email != "user@example.com" { + t.Errorf("Email 不匹配: got %s", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s", userInfo.Name) + } +} + +func TestClient_GetUserInfo_服务器返回错误(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + resp, err := server.Client().Get(server.URL) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusUnauthorized { + t.Errorf("状态码不匹配: got %d, want 401", resp.StatusCode) + } +} + +// --------------------------------------------------------------------------- +// TokenResponse / UserInfo JSON 序列化 +// --------------------------------------------------------------------------- + +func TestTokenResponse_JSON序列化(t *testing.T) { + jsonData := `{"access_token":"at","expires_in":3600,"token_type":"Bearer","scope":"openid","refresh_token":"rt"}` + var resp TokenResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.AccessToken != "at" { + t.Errorf("AccessToken 不匹配: got %s", resp.AccessToken) + } + if resp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d", resp.ExpiresIn) + } + if resp.RefreshToken != "rt" { + 
t.Errorf("RefreshToken 不匹配: got %s", resp.RefreshToken) + } +} + +func TestUserInfo_JSON序列化(t *testing.T) { + jsonData := `{"email":"a@b.com","name":"Alice"}` + var info UserInfo + if err := json.Unmarshal([]byte(jsonData), &info); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if info.Email != "a@b.com" { + t.Errorf("Email 不匹配: got %s", info.Email) + } + if info.Name != "Alice" { + t.Errorf("Name 不匹配: got %s", info.Name) + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssistResponse JSON 序列化 +// --------------------------------------------------------------------------- + +func TestLoadCodeAssistResponse_完整JSON(t *testing.T) { + jsonData := `{ + "cloudaicompanionProject": "proj-123", + "currentTier": "free-tier", + "paidTier": {"id": "g1-pro-tier", "name": "Pro"}, + "ineligibleTiers": [{"tier": {"id": "g1-ultra-tier"}, "reasonCode": "INELIGIBLE_ACCOUNT"}] + }` + var resp LoadCodeAssistResponse + if err := json.Unmarshal([]byte(jsonData), &resp); err != nil { + t.Fatalf("反序列化失败: %v", err) + } + if resp.CloudAICompanionProject != "proj-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s", resp.GetTier()) + } + if len(resp.IneligibleTiers) != 1 { + t.Fatalf("IneligibleTiers 数量不匹配: got %d", len(resp.IneligibleTiers)) + } + if resp.IneligibleTiers[0].ReasonCode != "INELIGIBLE_ACCOUNT" { + t.Errorf("ReasonCode 不匹配: got %s", resp.IneligibleTiers[0].ReasonCode) + } +} + +// =========================================================================== +// 以下为新增测试:真正调用 Client 方法,通过 RoundTripper 拦截 HTTP 请求 +// =========================================================================== + +// redirectRoundTripper 将请求中特定前缀的 URL 重定向到 httptest server +type redirectRoundTripper struct { + // 原始 URL 前缀 -> 替换目标 URL 的映射 + redirects map[string]string + transport http.RoundTripper +} + +func (rt *redirectRoundTripper) 
RoundTrip(req *http.Request) (*http.Response, error) { + originalURL := req.URL.String() + for prefix, target := range rt.redirects { + if strings.HasPrefix(originalURL, prefix) { + newURL := target + strings.TrimPrefix(originalURL, prefix) + parsed, err := url.Parse(newURL) + if err != nil { + return nil, err + } + req.URL = parsed + break + } + } + if rt.transport == nil { + return http.DefaultTransport.RoundTrip(req) + } + return rt.transport.RoundTrip(req) +} + +// newTestClientWithRedirect 创建一个 Client,将指定 URL 前缀的请求重定向到 mock server +func newTestClientWithRedirect(redirects map[string]string) *Client { + return &Client{ + httpClient: &http.Client{ + Timeout: 10 * time.Second, + Transport: &redirectRoundTripper{ + redirects: redirects, + }, + }, + } +} + +// --------------------------------------------------------------------------- +// Client.ExchangeCode - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_ExchangeCode_Success_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if ct := r.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + if r.FormValue("code") != "test-auth-code" { + t.Errorf("code 不匹配: got %s", r.FormValue("code")) + } + if r.FormValue("code_verifier") != "test-verifier" { + t.Errorf("code_verifier 不匹配: got %s", r.FormValue("code_verifier")) + } + if r.FormValue("grant_type") != 
"authorization_code" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("redirect_uri") != RedirectURI { + t.Errorf("redirect_uri 不匹配: got %s", r.FormValue("redirect_uri")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "new-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + Scope: "openid email", + RefreshToken: "new-refresh-token", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.ExchangeCode(context.Background(), "test-auth-code", "test-verifier") + if err != nil { + t.Fatalf("ExchangeCode 失败: %v", err) + } + if tokenResp.AccessToken != "new-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want new-access-token", tokenResp.AccessToken) + } + if tokenResp.RefreshToken != "new-refresh-token" { + t.Errorf("RefreshToken 不匹配: got %s, want new-refresh-token", tokenResp.RefreshToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } + if tokenResp.TokenType != "Bearer" { + t.Errorf("TokenType 不匹配: got %s, want Bearer", tokenResp.TokenType) + } + if tokenResp.Scope != "openid email" { + t.Errorf("Scope 不匹配: got %s, want openid email", tokenResp.Scope) + } +} + +func TestClient_ExchangeCode_ServerError_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"code expired"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "expired-code", "verifier") + if err == nil { + t.Fatal("服务器返回 400 时应返回错误") + } + if 
!strings.Contains(err.Error(), "token 交换失败") { + t.Errorf("错误信息应包含 'token 交换失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "400") { + t.Errorf("错误信息应包含状态码 400: got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_InvalidJSON_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{invalid json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.ExchangeCode(context.Background(), "code", "verifier") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_ExchangeCode_ContextCanceled_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) // 模拟慢响应 + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // 立即取消 + + _, err := client.ExchangeCode(ctx, "code", "verifier") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.RefreshToken - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_RefreshToken_Success_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + 
t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if err := r.ParseForm(); err != nil { + t.Fatalf("解析表单失败: %v", err) + } + if r.FormValue("grant_type") != "refresh_token" { + t.Errorf("grant_type 不匹配: got %s", r.FormValue("grant_type")) + } + if r.FormValue("refresh_token") != "my-refresh-token" { + t.Errorf("refresh_token 不匹配: got %s", r.FormValue("refresh_token")) + } + if r.FormValue("client_id") != ClientID { + t.Errorf("client_id 不匹配: got %s", r.FormValue("client_id")) + } + if r.FormValue("client_secret") != "test-secret" { + t.Errorf("client_secret 不匹配: got %s", r.FormValue("client_secret")) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(TokenResponse{ + AccessToken: "refreshed-access-token", + ExpiresIn: 3600, + TokenType: "Bearer", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + tokenResp, err := client.RefreshToken(context.Background(), "my-refresh-token") + if err != nil { + t.Fatalf("RefreshToken 失败: %v", err) + } + if tokenResp.AccessToken != "refreshed-access-token" { + t.Errorf("AccessToken 不匹配: got %s, want refreshed-access-token", tokenResp.AccessToken) + } + if tokenResp.ExpiresIn != 3600 { + t.Errorf("ExpiresIn 不匹配: got %d, want 3600", tokenResp.ExpiresIn) + } +} + +func TestClient_RefreshToken_ServerError_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_grant","error_description":"token revoked"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "revoked-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), 
"token 刷新失败") { + t.Errorf("错误信息应包含 'token 刷新失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_InvalidJSON_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`not-json`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + _, err := client.RefreshToken(context.Background(), "refresh-tok") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "token 解析失败") { + t.Errorf("错误信息应包含 'token 解析失败': got %s", err.Error()) + } +} + +func TestClient_RefreshToken_ContextCanceled_RealCall(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "test-secret") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + TokenURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.RefreshToken(ctx, "refresh-tok") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.GetUserInfo - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_GetUserInfo_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("请求方法不匹配: got %s, want GET", r.Method) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer user-access-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + + w.Header().Set("Content-Type", "application/json") + 
w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(UserInfo{ + Email: "test@example.com", + Name: "Test User", + GivenName: "Test", + FamilyName: "User", + Picture: "https://example.com/avatar.jpg", + }) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + userInfo, err := client.GetUserInfo(context.Background(), "user-access-token") + if err != nil { + t.Fatalf("GetUserInfo 失败: %v", err) + } + if userInfo.Email != "test@example.com" { + t.Errorf("Email 不匹配: got %s, want test@example.com", userInfo.Email) + } + if userInfo.Name != "Test User" { + t.Errorf("Name 不匹配: got %s, want Test User", userInfo.Name) + } + if userInfo.GivenName != "Test" { + t.Errorf("GivenName 不匹配: got %s, want Test", userInfo.GivenName) + } + if userInfo.FamilyName != "User" { + t.Errorf("FamilyName 不匹配: got %s, want User", userInfo.FamilyName) + } + if userInfo.Picture != "https://example.com/avatar.jpg" { + t.Errorf("Picture 不匹配: got %s", userInfo.Picture) + } +} + +func TestClient_GetUserInfo_Unauthorized_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"invalid_token"}`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 401 时应返回错误") + } + if !strings.Contains(err.Error(), "获取用户信息失败") { + t.Errorf("错误信息应包含 '获取用户信息失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "401") { + t.Errorf("错误信息应包含状态码 401: got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = 
w.Write([]byte(`{broken`)) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + _, err := client.GetUserInfo(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "用户信息解析失败") { + t.Errorf("错误信息应包含 '用户信息解析失败': got %s", err.Error()) + } +} + +func TestClient_GetUserInfo_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := newTestClientWithRedirect(map[string]string{ + UserInfoURL: server.URL, + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := client.GetUserInfo(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.LoadCodeAssist - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +// withMockBaseURLs 临时替换 BaseURLs,测试结束后恢复 +func withMockBaseURLs(t *testing.T, urls []string) { + t.Helper() + origBaseURLs := BaseURLs + origBaseURL := BaseURL + BaseURLs = urls + if len(urls) > 0 { + BaseURL = urls[0] + } + t.Cleanup(func() { + BaseURLs = origBaseURLs + BaseURL = origBaseURL + }) +} + +func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:loadCodeAssist") { + t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", 
ct) + } + if ua := r.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody LoadCodeAssistRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Metadata.IDEType != "ANTIGRAVITY" { + t.Errorf("IDEType 不匹配: got %s, want ANTIGRAVITY", reqBody.Metadata.IDEType) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "test-project-123", + "currentTier": {"id": "free-tier", "name": "Free"}, + "paidTier": {"id": "g1-pro-tier", "name": "Pro", "description": "Pro plan"} + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") + if err != nil { + t.Fatalf("LoadCodeAssist 失败: %v", err) + } + if resp.CloudAICompanionProject != "test-project-123" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if resp.GetTier() != "g1-pro-tier" { + t.Errorf("GetTier 不匹配: got %s, want g1-pro-tier", resp.GetTier()) + } + if resp.CurrentTier == nil || resp.CurrentTier.ID != "free-tier" { + t.Errorf("CurrentTier 不匹配: got %+v", resp.CurrentTier) + } + if resp.PaidTier == nil || resp.PaidTier.ID != "g1-pro-tier" { + t.Errorf("PaidTier 不匹配: got %+v", resp.PaidTier) + } + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["cloudaicompanionProject"] != "test-project-123" { + t.Errorf("rawResp cloudaicompanionProject 不匹配: got %v", rawResp["cloudaicompanionProject"]) + } +} + +func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, 
[]string{server.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "loadCodeAssist 失败") { + t.Errorf("错误信息应包含 'loadCodeAssist 失败': got %s", err.Error()) + } + if !strings.Contains(err.Error(), "403") { + t.Errorf("错误信息应包含状态码 403: got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{not valid json!!!`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { + // 第一个 server 返回 500,第二个 server 返回成功 + callCount := 0 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"internal"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "cloudaicompanionProject": "fallback-project", + "currentTier": {"id": "free-tier", "name": "Free"} + }`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err) + } + if 
resp.CloudAICompanionProject != "fallback-project" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = w.Write([]byte(`{"error":"unavailable"}`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"bad_gateway"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + _, _, err := client.LoadCodeAssist(context.Background(), "token") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.LoadCodeAssist(ctx, "token") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +// --------------------------------------------------------------------------- +// Client.FetchAvailableModels - 真正调用方法的测试 +// --------------------------------------------------------------------------- + +func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("请求方法不匹配: got %s, want POST", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/v1internal:fetchAvailableModels") { + 
t.Errorf("URL 路径不匹配: got %s", r.URL.Path) + } + auth := r.Header.Get("Authorization") + if auth != "Bearer test-token" { + t.Errorf("Authorization 不匹配: got %s", auth) + } + if ct := r.Header.Get("Content-Type"); ct != "application/json" { + t.Errorf("Content-Type 不匹配: got %s", ct) + } + if ua := r.Header.Get("User-Agent"); ua != UserAgent { + t.Errorf("User-Agent 不匹配: got %s", ua) + } + + // 验证请求体 + var reqBody FetchAvailableModelsRequest + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + t.Fatalf("解析请求体失败: %v", err) + } + if reqBody.Project != "project-abc" { + t.Errorf("Project 不匹配: got %s, want project-abc", reqBody.Project) + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{ + "models": { + "gemini-2.0-flash": { + "quotaInfo": { + "remainingFraction": 0.85, + "resetTime": "2025-01-01T00:00:00Z" + } + }, + "gemini-2.5-pro": { + "quotaInfo": { + "remainingFraction": 0.5 + } + } + } + }`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 2 { + t.Errorf("Models 数量不匹配: got %d, want 2", len(resp.Models)) + } + + flashModel, ok := resp.Models["gemini-2.0-flash"] + if !ok { + t.Fatal("缺少 gemini-2.0-flash 模型") + } + if flashModel.QuotaInfo == nil { + t.Fatal("gemini-2.0-flash QuotaInfo 不应为 nil") + } + if flashModel.QuotaInfo.RemainingFraction != 0.85 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.85", flashModel.QuotaInfo.RemainingFraction) + } + if flashModel.QuotaInfo.ResetTime != "2025-01-01T00:00:00Z" { + t.Errorf("ResetTime 不匹配: got %s", flashModel.QuotaInfo.ResetTime) + } + + proModel, ok := resp.Models["gemini-2.5-pro"] + if !ok { + t.Fatal("缺少 gemini-2.5-pro 模型") + } + if 
proModel.QuotaInfo == nil { + t.Fatal("gemini-2.5-pro QuotaInfo 不应为 nil") + } + if proModel.QuotaInfo.RemainingFraction != 0.5 { + t.Errorf("RemainingFraction 不匹配: got %f, want 0.5", proModel.QuotaInfo.RemainingFraction) + } + + // 验证原始 JSON map + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } + if rawResp["models"] == nil { + t.Error("rawResp models 不应为 nil") + } +} + +func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") + if err == nil { + t.Fatal("服务器返回 403 时应返回错误") + } + if !strings.Contains(err.Error(), "fetchAvailableModels 失败") { + t.Errorf("错误信息应包含 'fetchAvailableModels 失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`<<>>`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("无效 JSON 响应应返回错误") + } + if !strings.Contains(err.Error(), "响应解析失败") { + t.Errorf("错误信息应包含 '响应解析失败': got %s", err.Error()) + } +} + +func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { + callCount := 0 + // 第一个 server 返回 429,第二个 server 返回成功 + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.WriteHeader(http.StatusTooManyRequests) + _, _ = w.Write([]byte(`{"error":"rate_limited"}`)) + })) + 
defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {"model-a": {}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) + } + if _, ok := resp.Models["model-a"]; !ok { + t.Error("应返回 fallback server 的模型") + } + if callCount != 2 { + t.Errorf("应该调用了 2 个 server,实际调用 %d 次", callCount) + } +} + +func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`internal error`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err == nil { + t.Fatal("所有 URL 都失败时应返回错误") + } +} + +func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(5 * time.Second) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, _, err := client.FetchAvailableModels(ctx, "token", "proj") + if err == nil { + t.Fatal("context 取消时应返回错误") + } +} + +func 
TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models": {}}`)) + })) + defer server.Close() + + withMockBaseURLs(t, []string{server.URL}) + + client := NewClient("") + resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 失败: %v", err) + } + if resp.Models == nil { + t.Fatal("Models 不应为 nil") + } + if len(resp.Models) != 0 { + t.Errorf("Models 应为空: got %d", len(resp.Models)) + } + if rawResp == nil { + t.Fatal("rawResp 不应为 nil") + } +} + +// --------------------------------------------------------------------------- +// LoadCodeAssist 和 FetchAvailableModels 的 408 fallback 测试 +// --------------------------------------------------------------------------- + +func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { + server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusRequestTimeout) + _, _ = w.Write([]byte(`timeout`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"cloudaicompanionProject":"p2","currentTier":"free-tier"}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.LoadCodeAssist(context.Background(), "token") + if err != nil { + t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) + } + if resp.CloudAICompanionProject != "p2" { + t.Errorf("CloudAICompanionProject 不匹配: got %s", resp.CloudAICompanionProject) + } +} + +func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { + 
server1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`not found`)) + })) + defer server1.Close() + + server2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":{"m1":{"quotaInfo":{"remainingFraction":1.0}}}}`)) + })) + defer server2.Close() + + withMockBaseURLs(t, []string{server1.URL, server2.URL}) + + client := NewClient("") + resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") + if err != nil { + t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) + } + if _, ok := resp.Models["m1"]; !ok { + t.Error("应返回 fallback server 的模型 m1") + } +} + func TestExtractProjectIDFromOnboardResponse(t *testing.T) { t.Parallel() diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 89b1bcde..cdda9be6 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -6,11 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) const ( @@ -21,7 +24,11 @@ const ( // Antigravity OAuth 客户端凭证 ClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" - ClientSecret = "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" + ClientSecret = "" + + // AntigravityOAuthClientSecretEnv 是 Antigravity OAuth client_secret 的环境变量名。 + // 出于安全原因,该值不得硬编码入库。 + AntigravityOAuthClientSecretEnv = "ANTIGRAVITY_OAUTH_CLIENT_SECRET" // 固定的 redirect_uri(用户需手动复制 code) RedirectURI = "http://localhost:8085/callback" @@ -57,6 +64,17 @@ func init() { // GetUserAgent 返回当前配置的 User-Agent func GetUserAgent() string { return fmt.Sprintf("antigravity/%s windows/amd64", defaultUserAgentVersion) + +func 
getClientSecret() (string, error) { + if v := strings.TrimSpace(ClientSecret); v != "" { + return v, nil + } + if v, ok := os.LookupEnv(AntigravityOAuthClientSecretEnv); ok { + if vv := strings.TrimSpace(v); vv != "" { + return vv, nil + } + } + return "", infraerrors.Newf(http.StatusBadRequest, "ANTIGRAVITY_OAUTH_CLIENT_SECRET_MISSING", "missing antigravity oauth client_secret; set %s", AntigravityOAuthClientSecretEnv) } // BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go new file mode 100644 index 00000000..67731c06 --- /dev/null +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -0,0 +1,704 @@ +//go:build unit + +package antigravity + +import ( + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "net/url" + "strings" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// getClientSecret +// --------------------------------------------------------------------------- + +func TestGetClientSecret_环境变量设置(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value") + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "my-secret-value" { + t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret) + } +} + +func TestGetClientSecret_环境变量为空(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + _, err := getClientSecret() + if err == nil { + t.Fatal("环境变量为空时应返回错误") + } + if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) { + t.Errorf("错误信息应包含环境变量名: got %s", err.Error()) + } +} + +func TestGetClientSecret_环境变量未设置(t *testing.T) { + // t.Setenv 会在测试结束时恢复,但我们需要确保它不存在 + // 注意:如果 ClientSecret 常量非空,这个测试会直接返回常量值 + // 当前代码中 ClientSecret = "",所以会走环境变量逻辑 + + // 明确设置再取消,确保环境变量不存在 + t.Setenv(AntigravityOAuthClientSecretEnv, "") + + _, err := getClientSecret() + if err == nil { + 
t.Fatal("环境变量未设置时应返回错误") + } +} + +func TestGetClientSecret_环境变量含空格(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, " ") + + _, err := getClientSecret() + if err == nil { + t.Fatal("环境变量仅含空格时应返回错误") + } +} + +func TestGetClientSecret_环境变量有前后空格(t *testing.T) { + t.Setenv(AntigravityOAuthClientSecretEnv, " valid-secret ") + + secret, err := getClientSecret() + if err != nil { + t.Fatalf("获取 client_secret 失败: %v", err) + } + if secret != "valid-secret" { + t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret") + } +} + +// --------------------------------------------------------------------------- +// ForwardBaseURLs +// --------------------------------------------------------------------------- + +func TestForwardBaseURLs_Daily优先(t *testing.T) { + urls := ForwardBaseURLs() + if len(urls) == 0 { + t.Fatal("ForwardBaseURLs 返回空列表") + } + + // daily URL 应排在第一位 + if urls[0] != antigravityDailyBaseURL { + t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL) + } + + // 应包含所有 URL + if len(urls) != len(BaseURLs) { + t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } + + // 验证 prod URL 也在列表中 + found := false + for _, u := range urls { + if u == antigravityProdBaseURL { + found = true + break + } + } + if !found { + t.Error("ForwardBaseURLs 中缺少 prod URL") + } +} + +func TestForwardBaseURLs_不修改原切片(t *testing.T) { + originalFirst := BaseURLs[0] + _ = ForwardBaseURLs() + // 确保原始 BaseURLs 未被修改 + if BaseURLs[0] != originalFirst { + t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst) + } +} + +// --------------------------------------------------------------------------- +// URLAvailability +// --------------------------------------------------------------------------- + +func TestNewURLAvailability(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if ua == nil { + t.Fatal("NewURLAvailability 返回 nil") + } + if ua.ttl != 5*time.Minute { + t.Errorf("TTL 不匹配: got %v, want 5m", 
ua.ttl) + } + if ua.unavailable == nil { + t.Error("unavailable map 不应为 nil") + } +} + +func TestURLAvailability_MarkUnavailable(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后 IsAvailable 应返回 false") + } +} + +func TestURLAvailability_MarkSuccess(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + testURL := "https://example.com" + + // 先标记为不可用 + ua.MarkUnavailable(testURL) + if ua.IsAvailable(testURL) { + t.Error("标记为不可用后应不可用") + } + + // 标记成功后应恢复可用 + ua.MarkSuccess(testURL) + if !ua.IsAvailable(testURL) { + t.Error("MarkSuccess 后应恢复可用") + } + + // 验证 lastSuccess 被设置 + ua.mu.RLock() + if ua.lastSuccess != testURL { + t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL) + } + ua.mu.RUnlock() +} + +func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) { + // 使用极短的 TTL + ua := NewURLAvailability(1 * time.Millisecond) + testURL := "https://example.com" + + ua.MarkUnavailable(testURL) + // 等待 TTL 过期 + time.Sleep(5 * time.Millisecond) + + if !ua.IsAvailable(testURL) { + t.Error("TTL 过期后 URL 应恢复可用") + } +} + +func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) { + ua := NewURLAvailability(5 * time.Minute) + if !ua.IsAvailable("https://never-marked.com") { + t.Error("未标记的 URL 应默认可用") + } +} + +func TestURLAvailability_GetAvailableURLs(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + // 默认所有 URL 都可用 + urls := ua.GetAvailableURLs() + if len(urls) != len(BaseURLs) { + t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs)) + } +} + +func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + + if len(BaseURLs) < 2 { + t.Skip("BaseURLs 少于 2 个,跳过此测试") + } + + ua.MarkUnavailable(BaseURLs[0]) + urls := ua.GetAvailableURLs() + + // 标记的 URL 不应出现在可用列表中 + for _, u := range urls { + if u == BaseURLs[0] { + t.Errorf("被标记不可用的 URL 
不应出现在可用列表中: %s", BaseURLs[0]) + } + } +} + +func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com", "https://c.com"} + + ua.MarkSuccess("https://c.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + if len(urls) != 3 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls)) + } + // c.com 应排在第一位 + if urls[0] != "https://c.com" { + t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0]) + } + // 其余按原始顺序 + if urls[1] != "https://a.com" { + t.Errorf("第二个应为 a.com: got %s", urls[1]) + } + if urls[2] != "https://b.com" { + t.Errorf("第三个应为 b.com: got %s", urls[2]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://b.com") + ua.MarkUnavailable("https://b.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // b.com 被标记不可用,不应出现 + if len(urls) != 1 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls)) + } + if urls[0] != "https://a.com" { + t.Errorf("仅 a.com 应可用: got %s", urls[0]) + } +} + +func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) { + ua := NewURLAvailability(10 * time.Minute) + customURLs := []string{"https://a.com", "https://b.com"} + + ua.MarkSuccess("https://not-in-list.com") + + urls := ua.GetAvailableURLsWithBase(customURLs) + // lastSuccess 不在自定义列表中,不应被添加 + if len(urls) != 2 { + t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls)) + } +} + +// 
--------------------------------------------------------------------------- +// SessionStore +// --------------------------------------------------------------------------- + +func TestNewSessionStore(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + if store == nil { + t.Fatal("NewSessionStore 返回 nil") + } + if store.sessions == nil { + t.Error("sessions map 不应为 nil") + } +} + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + CodeVerifier: "test-verifier", + ProxyURL: "http://proxy.example.com", + CreatedAt: time.Now(), + } + + store.Set("session-1", session) + + got, ok := store.Get("session-1") + if !ok { + t.Fatal("Get 应返回 true") + } + if got.State != "test-state" { + t.Errorf("State 不匹配: got %s", got.State) + } + if got.CodeVerifier != "test-verifier" { + t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier) + } + if got.ProxyURL != "http://proxy.example.com" { + t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL) + } +} + +func TestSessionStore_Get_不存在(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("nonexistent") + if ok { + t.Error("不存在的 session 应返回 false") + } +} + +func TestSessionStore_Get_过期(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "expired-state", + CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期 + } + + store.Set("expired-session", session) + + _, ok := store.Get("expired-session") + if ok { + t.Error("过期的 session 应返回 false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + CreatedAt: time.Now(), + } + + store.Set("del-session", session) + store.Delete("del-session") + + _, ok := store.Get("del-session") + if ok { + t.Error("删除后 Get 应返回 false") + } +} + +func TestSessionStore_Delete_不存在(t *testing.T) { + store 
:= NewSessionStore() + defer store.Stop() + + // 删除不存在的 session 不应 panic + store.Delete("nonexistent") +} + +func TestSessionStore_Stop(t *testing.T) { + store := NewSessionStore() + store.Stop() + + // 多次 Stop 不应 panic + store.Stop() +} + +func TestSessionStore_多个Session(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + for i := 0; i < 10; i++ { + session := &OAuthSession{ + State: "state-" + string(rune('0'+i)), + CreatedAt: time.Now(), + } + store.Set("session-"+string(rune('0'+i)), session) + } + + // 验证都能取到 + for i := 0; i < 10; i++ { + _, ok := store.Get("session-" + string(rune('0'+i))) + if !ok { + t.Errorf("session-%d 应存在", i) + } + } +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes_长度正确(t *testing.T) { + sizes := []int{0, 1, 16, 32, 64, 128} + for _, size := range sizes { + b, err := GenerateRandomBytes(size) + if err != nil { + t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err) + } + if len(b) != size { + t.Errorf("长度不匹配: got %d, want %d", len(b), size) + } + } +} + +func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) { + b1, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第一次调用失败: %v", err) + } + b2, err := GenerateRandomBytes(32) + if err != nil { + t.Fatalf("第二次调用失败: %v", err) + } + // 两次生成的随机字节应该不同(概率上几乎不可能相同) + if string(b1) == string(b2) { + t.Error("两次生成的随机字节相同,概率极低,可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState +// --------------------------------------------------------------------------- + +func TestGenerateState_返回值格式(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState 失败: %v", err) + } + if state == "" { + t.Error("GenerateState 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(state, "+/=") { + 
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state) + } + // 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充) + if len(state) != 43 { + t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state)) + } +} + +func TestGenerateState_唯一性(t *testing.T) { + s1, _ := GenerateState() + s2, _ := GenerateState() + if s1 == s2 { + t.Error("两次 GenerateState 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID +// --------------------------------------------------------------------------- + +func TestGenerateSessionID_返回值格式(t *testing.T) { + id, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID 失败: %v", err) + } + if id == "" { + t.Error("GenerateSessionID 返回空字符串") + } + // 16 字节的 hex 编码长度应为 32 + if len(id) != 32 { + t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id)) + } + // 验证是合法的 hex 字符串 + if _, err := hex.DecodeString(id); err != nil { + t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err) + } +} + +func TestGenerateSessionID_唯一性(t *testing.T) { + id1, _ := GenerateSessionID() + id2, _ := GenerateSessionID() + if id1 == id2 { + t.Error("两次 GenerateSessionID 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier_返回值格式(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier 失败: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier 返回空字符串") + } + // base64url 编码不应包含 +, /, = + if strings.ContainsAny(verifier, "+/=") { + t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier) + } + // 32 字节的 base64url 编码长度应为 43 + if len(verifier) != 43 { + t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier)) + } +} + +func TestGenerateCodeVerifier_唯一性(t *testing.T) { + v1, _ := GenerateCodeVerifier() + v2, 
_ := GenerateCodeVerifier() + if v1 == v2 { + t.Error("两次 GenerateCodeVerifier 结果相同") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) { + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + + challenge := GenerateCodeChallenge(verifier) + + // 手动计算预期值 + hash := sha256.Sum256([]byte(verifier)) + expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=") + + if challenge != expected { + t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected) + } +} + +func TestGenerateCodeChallenge_不含填充字符(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier") + if strings.Contains(challenge, "=") { + t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) { + challenge := GenerateCodeChallenge("another-verifier") + if strings.ContainsAny(challenge, "+/") { + t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge) + } +} + +func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) { + c1 := GenerateCodeChallenge("same-verifier") + c2 := GenerateCodeChallenge("same-verifier") + if c1 != c2 { + t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2) + } +} + +func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) { + c1 := GenerateCodeChallenge("verifier-1") + c2 := GenerateCodeChallenge("verifier-2") + if c1 == c2 { + t.Error("不同输入应产生不同输出") + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL_参数验证(t *testing.T) { + state := "test-state-123" + codeChallenge := "test-challenge-abc" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + // 验证以 AuthorizeURL 开头 + if !strings.HasPrefix(authURL, 
AuthorizeURL+"?") { + t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL) + } + + // 解析 URL 并验证参数 + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + + expectedParams := map[string]string{ + "client_id": ClientID, + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", + "include_granted_scopes": "true", + } + + for key, want := range expectedParams { + got := params.Get(key) + if got != want { + t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want) + } + } +} + +func TestBuildAuthorizationURL_参数数量(t *testing.T) { + authURL := BuildAuthorizationURL("s", "c") + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + params := parsed.Query() + // 应包含 10 个参数 + expectedCount := 10 + if len(params) != expectedCount { + t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount) + } +} + +func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) { + state := "state+with/special=chars" + codeChallenge := "challenge+value" + + authURL := BuildAuthorizationURL(state, codeChallenge) + + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("解析 URL 失败: %v", err) + } + + // 解析后应正确还原特殊字符 + if got := parsed.Query().Get("state"); got != state { + t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state) + } +} + +// --------------------------------------------------------------------------- +// 常量值验证 +// --------------------------------------------------------------------------- + +func TestConstants_值正确(t *testing.T) { + if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" { + t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL) + } + if TokenURL != "https://oauth2.googleapis.com/token" { + t.Errorf("TokenURL 不匹配: got %s", TokenURL) + } + if UserInfoURL != 
"https://www.googleapis.com/oauth2/v2/userinfo" { + t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL) + } + if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" { + t.Errorf("ClientID 不匹配: got %s", ClientID) + } + if ClientSecret != "" { + t.Error("ClientSecret 应为空字符串") + } + if RedirectURI != "http://localhost:8085/callback" { + t.Errorf("RedirectURI 不匹配: got %s", RedirectURI) + } + if UserAgent != "antigravity/1.15.8 windows/amd64" { + t.Errorf("UserAgent 不匹配: got %s", UserAgent) + } + if SessionTTL != 30*time.Minute { + t.Errorf("SessionTTL 不匹配: got %v", SessionTTL) + } + if URLAvailabilityTTL != 5*time.Minute { + t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL) + } +} + +func TestScopes_包含必要范围(t *testing.T) { + expectedScopes := []string{ + "https://www.googleapis.com/auth/cloud-platform", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + "https://www.googleapis.com/auth/cclog", + "https://www.googleapis.com/auth/experimentsandconfigs", + } + + for _, scope := range expectedScopes { + if !strings.Contains(Scopes, scope) { + t.Errorf("Scopes 缺少 %s", scope) + } + } +} diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index 463033f1..f12effb6 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -1,10 +1,13 @@ package antigravity import ( + "crypto/rand" "encoding/json" "fmt" "log" "strings" + "sync/atomic" + "time" ) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) @@ -341,12 +344,30 @@ func buildGroundingText(grounding *GeminiGroundingMetadata) string { return builder.String() } -// generateRandomID 生成随机 ID +// fallbackCounter 降级伪随机 ID 的全局计数器,混入 seed 避免高并发下 UnixNano 相同导致碰撞。 +var fallbackCounter uint64 + +// generateRandomID 生成密码学安全的随机 ID func generateRandomID() string { const chars = 
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - result := make([]byte, 12) - for i := range result { - result[i] = chars[i%len(chars)] + id := make([]byte, 12) + randBytes := make([]byte, 12) + if _, err := rand.Read(randBytes); err != nil { + // 避免在请求路径里 panic:极端情况下熵源不可用时降级为伪随机。 + // 这里主要用于生成响应/工具调用的临时 ID,安全要求不高但需尽量避免碰撞。 + cnt := atomic.AddUint64(&fallbackCounter, 1) + seed := uint64(time.Now().UnixNano()) ^ cnt + seed ^= uint64(len(err.Error())) << 32 + for i := range id { + seed ^= seed << 13 + seed ^= seed >> 7 + seed ^= seed << 17 + id[i] = chars[int(seed)%len(chars)] + } + return string(id) } - return string(result) + for i, b := range randBytes { + id[i] = chars[int(b)%len(chars)] + } + return string(id) } diff --git a/backend/internal/pkg/antigravity/response_transformer_test.go b/backend/internal/pkg/antigravity/response_transformer_test.go new file mode 100644 index 00000000..da402b17 --- /dev/null +++ b/backend/internal/pkg/antigravity/response_transformer_test.go @@ -0,0 +1,109 @@ +//go:build unit + +package antigravity + +import ( + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 7: 验证 generateRandomID 和降级碰撞防护 --- + +func TestGenerateRandomID_Uniqueness(t *testing.T) { + seen := make(map[string]struct{}, 100) + for i := 0; i < 100; i++ { + id := generateRandomID() + require.Len(t, id, 12, "ID 长度应为 12") + _, dup := seen[id] + require.False(t, dup, "第 %d 次调用生成了重复 ID: %s", i, id) + seen[id] = struct{}{} + } +} + +func TestFallbackCounter_Increments(t *testing.T) { + // 验证 fallbackCounter 的原子递增行为确保降级分支不会生成相同 seed + before := atomic.LoadUint64(&fallbackCounter) + cnt1 := atomic.AddUint64(&fallbackCounter, 1) + cnt2 := atomic.AddUint64(&fallbackCounter, 1) + require.Equal(t, before+1, cnt1, "第一次递增应为 before+1") + require.Equal(t, before+2, cnt2, "第二次递增应为 before+2") + require.NotEqual(t, cnt1, cnt2, "连续两次递增的计数器值应不同") +} + +func 
TestFallbackCounter_ConcurrentIncrements(t *testing.T) { + // 验证并发递增的原子性 — 每次递增都应产生唯一值 + const goroutines = 50 + results := make([]uint64, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = atomic.AddUint64(&fallbackCounter, 1) + }(i) + } + wg.Wait() + + // 所有结果应唯一 + seen := make(map[uint64]bool, goroutines) + for _, v := range results { + assert.False(t, seen[v], "并发递增产生了重复值: %d", v) + seen[v] = true + } +} + +func TestGenerateRandomID_Charset(t *testing.T) { + const validChars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validSet := make(map[byte]struct{}, len(validChars)) + for i := 0; i < len(validChars); i++ { + validSet[validChars[i]] = struct{}{} + } + + for i := 0; i < 50; i++ { + id := generateRandomID() + for j := 0; j < len(id); j++ { + _, ok := validSet[id[j]] + require.True(t, ok, "ID 包含非法字符: %c (ID=%s)", id[j], id) + } + } +} + +func TestGenerateRandomID_Length(t *testing.T) { + for i := 0; i < 100; i++ { + id := generateRandomID() + assert.Len(t, id, 12, "每次生成的 ID 长度应为 12") + } +} + +func TestGenerateRandomID_ConcurrentUniqueness(t *testing.T) { + // 验证并发调用不会产生重复 ID + const goroutines = 100 + results := make([]string, goroutines) + var wg sync.WaitGroup + wg.Add(goroutines) + + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + results[idx] = generateRandomID() + }(i) + } + wg.Wait() + + seen := make(map[string]bool, goroutines) + for _, id := range results { + assert.False(t, seen[id], "并发调用产生了重复 ID: %s", id) + seen[id] = true + } +} + +func BenchmarkGenerateRandomID(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = generateRandomID() + } +} diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 0c4d82f7..b13d66cb 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -8,9 +8,21 @@ const ( // ForcePlatform 强制平台(用于 
/antigravity 路由),由 middleware.ForcePlatform 设置 ForcePlatform Key = "ctx_force_platform" + // RequestID 为服务端生成/透传的请求 ID。 + RequestID Key = "ctx_request_id" + // ClientRequestID 客户端请求的唯一标识,用于追踪请求全生命周期(用于 Ops 监控与排障)。 ClientRequestID Key = "ctx_client_request_id" + // Model 请求模型标识(用于统一请求链路日志字段)。 + Model Key = "ctx_model" + + // Platform 当前请求最终命中的平台(用于统一请求链路日志字段)。 + Platform Key = "ctx_platform" + + // AccountID 当前请求最终命中的账号 ID(用于统一请求链路日志字段)。 + AccountID Key = "ctx_account_id" + // RetryCount 表示当前请求在网关层的重试次数(用于 Ops 记录与排障)。 RetryCount Key = "ctx_retry_count" @@ -32,4 +44,12 @@ const ( // SingleAccountRetry 标识当前请求处于单账号 503 退避重试模式。 // 在此模式下,Service 层的模型限流预检查将等待限流过期而非直接切换账号。 SingleAccountRetry Key = "ctx_single_account_retry" + + // PrefetchedStickyAccountID 标识上游(通常 handler)预取到的 sticky session 账号 ID。 + // Service 层可复用该值,避免同请求链路重复读取 Redis。 + PrefetchedStickyAccountID Key = "ctx_prefetched_sticky_account_id" + + // PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。 + // Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。 + PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id" ) diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index 424e8ddb..c300b17d 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -21,6 +21,7 @@ func DefaultModels() []Model { {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3.1-pro-preview", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index d4d52116..f85e3b97 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -38,8 +38,13 @@ const ( // GeminiCLIOAuthClientID/Secret are the public OAuth 
client credentials used by Google Gemini CLI. // They enable the "login without creating your own OAuth client" experience, but Google may // restrict which scopes are allowed for this client. - GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - GeminiCLIOAuthClientSecret = "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" + GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" + // GeminiCLIOAuthClientSecret is intentionally not embedded in this repository. + // If you rely on the built-in Gemini CLI OAuth client, you MUST provide its client_secret via config/env. + GeminiCLIOAuthClientSecret = "" + + // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. + GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" SessionTTL = 30 * time.Minute diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 08e69886..1fc4d983 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -16,6 +16,7 @@ var DefaultModels = []Model{ {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, + {ID: "gemini-3.1-pro-preview", Type: "model", DisplayName: "Gemini 3.1 Pro Preview", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
diff --git a/backend/internal/pkg/geminicli/oauth.go b/backend/internal/pkg/geminicli/oauth.go index c71e8aad..b10b5750 100644 --- a/backend/internal/pkg/geminicli/oauth.go +++ b/backend/internal/pkg/geminicli/oauth.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "net/http" "net/url" + "os" "strings" "sync" "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) type OAuthConfig struct { @@ -164,15 +168,24 @@ func EffectiveOAuthConfig(cfg OAuthConfig, oauthType string) (OAuthConfig, error } // Fall back to built-in Gemini CLI OAuth client when not configured. + // SECURITY: This repo does not embed the built-in client secret; it must be provided via env. if effective.ClientID == "" && effective.ClientSecret == "" { + secret := strings.TrimSpace(GeminiCLIOAuthClientSecret) + if secret == "" { + if v, ok := os.LookupEnv(GeminiCLIOAuthClientSecretEnv); ok { + secret = strings.TrimSpace(v) + } + } + if secret == "" { + return OAuthConfig{}, infraerrors.Newf(http.StatusBadRequest, "GEMINI_CLI_OAUTH_CLIENT_SECRET_MISSING", "built-in Gemini CLI OAuth client_secret is not configured; set %s or provide a custom OAuth client", GeminiCLIOAuthClientSecretEnv) + } effective.ClientID = GeminiCLIOAuthClientID - effective.ClientSecret = GeminiCLIOAuthClientSecret + effective.ClientSecret = secret } else if effective.ClientID == "" || effective.ClientSecret == "" { - return OAuthConfig{}, fmt.Errorf("OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") + return OAuthConfig{}, infraerrors.New(http.StatusBadRequest, "GEMINI_OAUTH_CLIENT_NOT_CONFIGURED", "OAuth client not configured: please set both client_id and client_secret (or leave both empty to use the built-in Gemini CLI client)") } - isBuiltinClient := effective.ClientID == GeminiCLIOAuthClientID && - effective.ClientSecret == GeminiCLIOAuthClientSecret + isBuiltinClient := effective.ClientID == 
GeminiCLIOAuthClientID if effective.Scopes == "" { // Use different default scopes based on OAuth type diff --git a/backend/internal/pkg/geminicli/oauth_test.go b/backend/internal/pkg/geminicli/oauth_test.go index 0770730a..14bc3c6b 100644 --- a/backend/internal/pkg/geminicli/oauth_test.go +++ b/backend/internal/pkg/geminicli/oauth_test.go @@ -1,11 +1,439 @@ package geminicli import ( + "encoding/hex" "strings" + "sync" "testing" + "time" ) +// --------------------------------------------------------------------------- +// SessionStore 测试 +// --------------------------------------------------------------------------- + +func TestSessionStore_SetAndGet(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "test-state", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("sid-1", session) + + got, ok := store.Get("sid-1") + if !ok { + t.Fatal("期望 Get 返回 ok=true,实际返回 false") + } + if got.State != "test-state" { + t.Errorf("期望 State=%q,实际=%q", "test-state", got.State) + } +} + +func TestSessionStore_GetNotFound(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + _, ok := store.Get("不存在的ID") + if ok { + t.Error("期望不存在的 sessionID 返回 ok=false") + } +} + +func TestSessionStore_GetExpired(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + // 创建一个已过期的 session(CreatedAt 设置为 SessionTTL+1 分钟之前) + session := &OAuthSession{ + State: "expired-state", + OAuthType: "code_assist", + CreatedAt: time.Now().Add(-(SessionTTL + 1*time.Minute)), + } + store.Set("expired-sid", session) + + _, ok := store.Get("expired-sid") + if ok { + t.Error("期望过期的 session 返回 ok=false") + } +} + +func TestSessionStore_Delete(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + session := &OAuthSession{ + State: "to-delete", + OAuthType: "code_assist", + CreatedAt: time.Now(), + } + store.Set("del-sid", session) + + // 先确认存在 + if _, ok := store.Get("del-sid"); !ok { + t.Fatal("删除前 
session 应该存在") + } + + store.Delete("del-sid") + + if _, ok := store.Get("del-sid"); ok { + t.Error("删除后 session 不应该存在") + } +} + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + // 多次调用 Stop 不应 panic + store.Stop() + store.Stop() + store.Stop() +} + +func TestSessionStore_ConcurrentAccess(t *testing.T) { + store := NewSessionStore() + defer store.Stop() + + const goroutines = 50 + var wg sync.WaitGroup + wg.Add(goroutines * 3) + + // 并发写入 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Set(sid, &OAuthSession{ + State: sid, + OAuthType: "code_assist", + CreatedAt: time.Now(), + }) + }(i) + } + + // 并发读取 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Get(sid) // 可能找到也可能没找到,关键是不 panic + }(i) + } + + // 并发删除 + for i := 0; i < goroutines; i++ { + go func(idx int) { + defer wg.Done() + sid := "concurrent-" + string(rune('A'+idx%26)) + store.Delete(sid) + }(i) + } + + wg.Wait() +} + +// --------------------------------------------------------------------------- +// GenerateRandomBytes 测试 +// --------------------------------------------------------------------------- + +func TestGenerateRandomBytes(t *testing.T) { + tests := []int{0, 1, 16, 32, 64} + for _, n := range tests { + b, err := GenerateRandomBytes(n) + if err != nil { + t.Errorf("GenerateRandomBytes(%d) 出错: %v", n, err) + continue + } + if len(b) != n { + t.Errorf("GenerateRandomBytes(%d) 返回长度=%d,期望=%d", n, len(b), n) + } + } +} + +func TestGenerateRandomBytes_Uniqueness(t *testing.T) { + // 两次调用应该返回不同的结果(极小概率相同,32字节足够) + a, _ := GenerateRandomBytes(32) + b, _ := GenerateRandomBytes(32) + if string(a) == string(b) { + t.Error("两次 GenerateRandomBytes(32) 返回了相同结果,随机性可能有问题") + } +} + +// --------------------------------------------------------------------------- +// GenerateState 测试 +// 
--------------------------------------------------------------------------- + +func TestGenerateState(t *testing.T) { + state, err := GenerateState() + if err != nil { + t.Fatalf("GenerateState() 出错: %v", err) + } + if state == "" { + t.Error("GenerateState() 返回空字符串") + } + // base64url 编码不应包含 padding '=' + if strings.Contains(state, "=") { + t.Errorf("GenerateState() 结果包含 '=' padding: %s", state) + } + // base64url 不应包含 '+' 或 '/' + if strings.ContainsAny(state, "+/") { + t.Errorf("GenerateState() 结果包含非 base64url 字符: %s", state) + } +} + +// --------------------------------------------------------------------------- +// GenerateSessionID 测试 +// --------------------------------------------------------------------------- + +func TestGenerateSessionID(t *testing.T) { + sid, err := GenerateSessionID() + if err != nil { + t.Fatalf("GenerateSessionID() 出错: %v", err) + } + // 16 字节 -> 32 个 hex 字符 + if len(sid) != 32 { + t.Errorf("GenerateSessionID() 长度=%d,期望=32", len(sid)) + } + // 必须是合法的 hex 字符串 + if _, err := hex.DecodeString(sid); err != nil { + t.Errorf("GenerateSessionID() 不是合法的 hex 字符串: %s, err=%v", sid, err) + } +} + +func TestGenerateSessionID_Uniqueness(t *testing.T) { + a, _ := GenerateSessionID() + b, _ := GenerateSessionID() + if a == b { + t.Error("两次 GenerateSessionID() 返回了相同结果") + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeVerifier 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeVerifier(t *testing.T) { + verifier, err := GenerateCodeVerifier() + if err != nil { + t.Fatalf("GenerateCodeVerifier() 出错: %v", err) + } + if verifier == "" { + t.Error("GenerateCodeVerifier() 返回空字符串") + } + // RFC 7636 要求 code_verifier 至少 43 个字符 + if len(verifier) < 43 { + t.Errorf("GenerateCodeVerifier() 长度=%d,RFC 7636 要求至少 43 字符", len(verifier)) + } + // base64url 编码不应包含 padding 和非 URL 安全字符 + if strings.Contains(verifier, "=") { + 
t.Errorf("GenerateCodeVerifier() 包含 '=' padding: %s", verifier) + } + if strings.ContainsAny(verifier, "+/") { + t.Errorf("GenerateCodeVerifier() 包含非 base64url 字符: %s", verifier) + } +} + +// --------------------------------------------------------------------------- +// GenerateCodeChallenge 测试 +// --------------------------------------------------------------------------- + +func TestGenerateCodeChallenge(t *testing.T) { + // 使用已知输入验证输出 + // RFC 7636 附录 B 示例: verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + // 预期 challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk" + expected := "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM" + + challenge := GenerateCodeChallenge(verifier) + if challenge != expected { + t.Errorf("GenerateCodeChallenge(%q) = %q,期望 %q", verifier, challenge, expected) + } +} + +func TestGenerateCodeChallenge_NoPadding(t *testing.T) { + challenge := GenerateCodeChallenge("test-verifier-string") + if strings.Contains(challenge, "=") { + t.Errorf("GenerateCodeChallenge() 结果包含 '=' padding: %s", challenge) + } +} + +// --------------------------------------------------------------------------- +// base64URLEncode 测试 +// --------------------------------------------------------------------------- + +func TestBase64URLEncode(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + {"空字节", []byte{}}, + {"单字节", []byte{0xff}}, + {"多字节", []byte{0x01, 0x02, 0x03, 0x04, 0x05}}, + {"全零", []byte{0x00, 0x00, 0x00}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := base64URLEncode(tt.input) + // 不应包含 '=' padding + if strings.Contains(result, "=") { + t.Errorf("base64URLEncode(%v) 包含 '=' padding: %s", tt.input, result) + } + // 不应包含标准 base64 的 '+' 或 '/' + if strings.ContainsAny(result, "+/") { + t.Errorf("base64URLEncode(%v) 包含非 URL 安全字符: %s", tt.input, result) + } + }) + } +} + +// 
--------------------------------------------------------------------------- +// hasRestrictedScope 测试 +// --------------------------------------------------------------------------- + +func TestHasRestrictedScope(t *testing.T) { + tests := []struct { + scope string + expected bool + }{ + // 受限 scope + {"https://www.googleapis.com/auth/generative-language", true}, + {"https://www.googleapis.com/auth/generative-language.retriever", true}, + {"https://www.googleapis.com/auth/generative-language.tuning", true}, + {"https://www.googleapis.com/auth/drive", true}, + {"https://www.googleapis.com/auth/drive.readonly", true}, + {"https://www.googleapis.com/auth/drive.file", true}, + // 非受限 scope + {"https://www.googleapis.com/auth/cloud-platform", false}, + {"https://www.googleapis.com/auth/userinfo.email", false}, + {"https://www.googleapis.com/auth/userinfo.profile", false}, + // 边界情况 + {"", false}, + {"random-scope", false}, + } + for _, tt := range tests { + t.Run(tt.scope, func(t *testing.T) { + got := hasRestrictedScope(tt.scope) + if got != tt.expected { + t.Errorf("hasRestrictedScope(%q) = %v,期望 %v", tt.scope, got, tt.expected) + } + }) + } +} + +// --------------------------------------------------------------------------- +// BuildAuthorizationURL 测试 +// --------------------------------------------------------------------------- + +func TestBuildAuthorizationURL(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + + // 检查返回的 URL 包含期望的参数 + checks := []string{ + "response_type=code", + "client_id=" + GeminiCLIOAuthClientID, + "redirect_uri=", + "state=test-state", + "code_challenge=test-challenge", + "code_challenge_method=S256", + "access_type=offline", + "prompt=consent", + "include_granted_scopes=true", + } + for _, 
check := range checks { + if !strings.Contains(authURL, check) { + t.Errorf("BuildAuthorizationURL() URL 缺少参数 %q\nURL: %s", check, authURL) + } + } + + // 不应包含 project_id(因为传的是空字符串) + if strings.Contains(authURL, "project_id=") { + t.Errorf("BuildAuthorizationURL() 空 projectID 时不应包含 project_id 参数") + } + + // URL 应该以正确的授权端点开头 + if !strings.HasPrefix(authURL, AuthorizeURL+"?") { + t.Errorf("BuildAuthorizationURL() URL 应以 %s? 开头,实际: %s", AuthorizeURL, authURL) + } +} + +func TestBuildAuthorizationURL_EmptyRedirectURI(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "", // 空 redirectURI + "", + "code_assist", + ) + if err == nil { + t.Error("BuildAuthorizationURL() 空 redirectURI 应该报错") + } + if !strings.Contains(err.Error(), "redirect_uri") { + t.Errorf("错误消息应包含 'redirect_uri',实际: %v", err) + } +} + +func TestBuildAuthorizationURL_WithProjectID(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-secret") + + authURL, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "my-project-123", + "code_assist", + ) + if err != nil { + t.Fatalf("BuildAuthorizationURL() 出错: %v", err) + } + if !strings.Contains(authURL, "project_id=my-project-123") { + t.Errorf("BuildAuthorizationURL() 带 projectID 时应包含 project_id 参数\nURL: %s", authURL) + } +} + +func TestBuildAuthorizationURL_OAuthConfigError(t *testing.T) { + // 不设置环境变量,也不提供 client 凭据,EffectiveOAuthConfig 应该报错 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + _, err := BuildAuthorizationURL( + OAuthConfig{}, + "test-state", + "test-challenge", + "https://example.com/callback", + "", + "code_assist", + ) + if err == nil { + t.Error("当 EffectiveOAuthConfig 失败时,BuildAuthorizationURL 应该返回错误") + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 原有测试 +// 
--------------------------------------------------------------------------- + func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { + // 内置的 Gemini CLI client secret 不嵌入在此仓库中。 + // 测试通过环境变量设置一个假的 secret 来模拟运维配置。 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + tests := []struct { name string input OAuthConfig @@ -15,7 +443,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr bool }{ { - name: "Google One with built-in client (empty config)", + name: "Google One 使用内置客户端(空配置)", input: OAuthConfig{}, oauthType: "google_one", wantClientID: GeminiCLIOAuthClientID, @@ -23,18 +451,18 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One always uses built-in client (even if custom credentials passed)", + name: "Google One 使用自定义客户端(传入自定义凭据时使用自定义)", input: OAuthConfig{ ClientID: "custom-client-id", ClientSecret: "custom-client-secret", }, oauthType: "google_one", wantClientID: "custom-client-id", - wantScopes: DefaultCodeAssistScopes, // Uses code assist scopes even with custom client + wantScopes: DefaultCodeAssistScopes, wantErr: false, }, { - name: "Google One with built-in client and custom scopes (should filter restricted scopes)", + name: "Google One 内置客户端 + 自定义 scopes(应过滤受限 scopes)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -44,7 +472,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Google One with built-in client and only restricted scopes (should fallback to default)", + name: "Google One 内置客户端 + 仅受限 scopes(应回退到默认)", input: OAuthConfig{ Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", }, @@ -54,7 +482,7 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { wantErr: false, }, { - name: "Code Assist with built-in client", 
+ name: "Code Assist 使用内置客户端", input: OAuthConfig{}, oauthType: "code_assist", wantClientID: GeminiCLIOAuthClientID, @@ -84,7 +512,9 @@ func TestEffectiveOAuthConfig_GoogleOne(t *testing.T) { } func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { - // Test that Google One with built-in client filters out restricted scopes + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 测试 Google One + 内置客户端过滤受限 scopes cfg, err := EffectiveOAuthConfig(OAuthConfig{ Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly https://www.googleapis.com/auth/userinfo.profile", }, "google_one") @@ -93,21 +523,240 @@ func TestEffectiveOAuthConfig_ScopeFiltering(t *testing.T) { t.Fatalf("EffectiveOAuthConfig() error = %v", err) } - // Should only contain cloud-platform, userinfo.email, and userinfo.profile - // Should NOT contain generative-language or drive scopes + // 应仅包含 cloud-platform、userinfo.email 和 userinfo.profile + // 不应包含 generative-language 或 drive scopes if strings.Contains(cfg.Scopes, "generative-language") { - t.Errorf("Scopes should not contain generative-language when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 generative-language,实际: %v", cfg.Scopes) } if strings.Contains(cfg.Scopes, "drive") { - t.Errorf("Scopes should not contain drive when using built-in client, got: %v", cfg.Scopes) + t.Errorf("使用内置客户端时 Scopes 不应包含 drive,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "cloud-platform") { - t.Errorf("Scopes should contain cloud-platform, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 cloud-platform,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, "userinfo.email") { - t.Errorf("Scopes should contain userinfo.email, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.email,实际: %v", cfg.Scopes) } if !strings.Contains(cfg.Scopes, 
"userinfo.profile") { - t.Errorf("Scopes should contain userinfo.profile, got: %v", cfg.Scopes) + t.Errorf("Scopes 应包含 userinfo.profile,实际: %v", cfg.Scopes) + } +} + +// --------------------------------------------------------------------------- +// EffectiveOAuthConfig 测试 - 新增分支覆盖 +// --------------------------------------------------------------------------- + +func TestEffectiveOAuthConfig_OnlyClientID_NoSecret(t *testing.T) { + // 只提供 clientID 不提供 secret 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "some-client-id", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientID 不提供 ClientSecret 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_OnlyClientSecret_NoID(t *testing.T) { + // 只提供 secret 不提供 clientID 应报错 + _, err := EffectiveOAuthConfig(OAuthConfig{ + ClientSecret: "some-client-secret", + }, "code_assist") + if err == nil { + t.Error("只提供 ClientSecret 不提供 ClientID 应该报错") + } + if !strings.Contains(err.Error(), "client_id") || !strings.Contains(err.Error(), "client_secret") { + t.Errorf("错误消息应提及 client_id 和 client_secret,实际: %v", err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_BuiltinClient(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // ai_studio 类型,使用内置客户端,scopes 为空 -> 应使用 DefaultCodeAssistScopes(因为内置客户端不能请求 generative-language scope) + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("ai_studio + 内置客户端应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_DefaultScopes_CustomClient(t *testing.T) { + // ai_studio 类型,使用自定义客户端,scopes 为空 -> 应使用 DefaultAIStudioScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: 
"custom-id", + ClientSecret: "custom-secret", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultAIStudioScopes { + t.Errorf("ai_studio + 自定义客户端应使用 DefaultAIStudioScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_AIStudio_ScopeNormalization(t *testing.T) { + // ai_studio 类型,旧的 generative-language scope 应被归一化为 generative-language.retriever + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language https://www.googleapis.com/auth/cloud-platform", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if strings.Contains(cfg.Scopes, "auth/generative-language ") || strings.HasSuffix(cfg.Scopes, "auth/generative-language") { + // 确保不包含未归一化的旧 scope(仅 generative-language 而非 generative-language.retriever) + parts := strings.Fields(cfg.Scopes) + for _, p := range parts { + if p == "https://www.googleapis.com/auth/generative-language" { + t.Errorf("ai_studio 应将 generative-language 归一化为 generative-language.retriever,实际 scopes: %q", cfg.Scopes) + } + } + } + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("ai_studio 归一化后应包含 generative-language.retriever,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CommaSeparatedScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 逗号分隔的 scopes 应被归一化为空格分隔 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform,https://www.googleapis.com/auth/userinfo.email", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 应该用空格分隔,而非逗号 + if strings.Contains(cfg.Scopes, ",") { + t.Errorf("逗号分隔的 scopes 应被归一化为空格分隔,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, 
"cloud-platform") { + t.Errorf("归一化后应包含 cloud-platform,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "userinfo.email") { + t.Errorf("归一化后应包含 userinfo.email,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_MixedCommaAndSpaceScopes(t *testing.T) { + // 混合逗号和空格分隔的 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/cloud-platform, https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + parts := strings.Fields(cfg.Scopes) + if len(parts) != 3 { + t.Errorf("归一化后应有 3 个 scope,实际: %d,scopes: %q", len(parts), cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_WhitespaceTriming(t *testing.T) { + // 输入中的前后空白应被清理 + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: " custom-id ", + ClientSecret: " custom-secret ", + Scopes: " https://www.googleapis.com/auth/cloud-platform ", + }, "code_assist") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.ClientID != "custom-id" { + t.Errorf("ClientID 应去除前后空白,实际: %q", cfg.ClientID) + } + if cfg.ClientSecret != "custom-secret" { + t.Errorf("ClientSecret 应去除前后空白,实际: %q", cfg.ClientSecret) + } + if cfg.Scopes != "https://www.googleapis.com/auth/cloud-platform" { + t.Errorf("Scopes 应去除前后空白,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_NoEnvSecret(t *testing.T) { + // 不设置环境变量且不提供凭据,应该报错 + t.Setenv(GeminiCLIOAuthClientSecretEnv, "") + + _, err := EffectiveOAuthConfig(OAuthConfig{}, "code_assist") + if err == nil { + t.Error("没有内置 secret 且未提供凭据时应该报错") + } + if !strings.Contains(err.Error(), GeminiCLIOAuthClientSecretEnv) { + t.Errorf("错误消息应提及环境变量 %s,实际: %v", GeminiCLIOAuthClientSecretEnv, err) + } +} + +func TestEffectiveOAuthConfig_AIStudio_BuiltinClient_CustomScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, 
"test-built-in-secret") + + // ai_studio + 内置客户端 + 自定义 scopes -> 应过滤受限 scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + Scopes: "https://www.googleapis.com/auth/cloud-platform https://www.googleapis.com/auth/generative-language.retriever", + }, "ai_studio") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + // 内置客户端应过滤 generative-language.retriever + if strings.Contains(cfg.Scopes, "generative-language") { + t.Errorf("ai_studio + 内置客户端应过滤受限 scopes,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "cloud-platform") { + t.Errorf("应保留 cloud-platform scope,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_UnknownOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 未知的 oauthType 应回退到默认的 code_assist scopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "unknown_type") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("未知 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_EmptyOAuthType_DefaultScopes(t *testing.T) { + t.Setenv(GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") + + // 空的 oauthType 应走 default 分支,使用 DefaultCodeAssistScopes + cfg, err := EffectiveOAuthConfig(OAuthConfig{}, "") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = %v", err) + } + if cfg.Scopes != DefaultCodeAssistScopes { + t.Errorf("空 oauthType 应使用 DefaultCodeAssistScopes,实际: %q", cfg.Scopes) + } +} + +func TestEffectiveOAuthConfig_CustomClient_NoScopeFiltering(t *testing.T) { + // 自定义客户端 + google_one + 包含受限 scopes -> 不应被过滤(因为不是内置客户端) + cfg, err := EffectiveOAuthConfig(OAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + Scopes: "https://www.googleapis.com/auth/generative-language.retriever https://www.googleapis.com/auth/drive.readonly", + }, "google_one") + if err != nil { + t.Fatalf("EffectiveOAuthConfig() error = 
%v", err) + } + // 自定义客户端不应过滤任何 scope + if !strings.Contains(cfg.Scopes, "generative-language.retriever") { + t.Errorf("自定义客户端不应过滤 generative-language.retriever,实际: %q", cfg.Scopes) + } + if !strings.Contains(cfg.Scopes, "drive.readonly") { + t.Errorf("自定义客户端不应过滤 drive.readonly,实际: %q", cfg.Scopes) } } diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 97109c0c..3f05ac41 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -44,6 +44,16 @@ func GetClientIP(c *gin.Context) string { return normalizeIP(c.ClientIP()) } +// GetTrustedClientIP 从 Gin 的可信代理解析链提取客户端 IP。 +// 该方法依赖 gin.Engine.SetTrustedProxies 配置,不会优先直接信任原始转发头值。 +// 适用于 ACL / 风控等安全敏感场景。 +func GetTrustedClientIP(c *gin.Context) string { + if c == nil { + return "" + } + return normalizeIP(c.ClientIP()) +} + // normalizeIP 规范化 IP 地址,去除端口号和空格。 func normalizeIP(ip string) string { ip = strings.TrimSpace(ip) @@ -54,29 +64,34 @@ func normalizeIP(ip string) string { return ip } -// isPrivateIP 检查 IP 是否为私有地址。 -func isPrivateIP(ipStr string) bool { - ip := net.ParseIP(ipStr) - if ip == nil { - return false - } +// privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 +var privateNets []*net.IPNet - // 私有 IP 范围 - privateBlocks := []string{ +func init() { + for _, cidr := range []string{ "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "127.0.0.0/8", "::1/128", "fc00::/7", - } - - for _, block := range privateBlocks { - _, cidr, err := net.ParseCIDR(block) + } { + _, block, err := net.ParseCIDR(cidr) if err != nil { - continue + panic("invalid CIDR: " + cidr) } - if cidr.Contains(ip) { + privateNets = append(privateNets, block) + } +} + +// isPrivateIP 检查 IP 是否为私有地址。 +func isPrivateIP(ipStr string) bool { + ip := net.ParseIP(ipStr) + if ip == nil { + return false + } + for _, block := range privateNets { + if block.Contains(ip) { return true } } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go new file mode 100644 index 
00000000..3839403c --- /dev/null +++ b/backend/internal/pkg/ip/ip_test.go @@ -0,0 +1,75 @@ +//go:build unit + +package ip + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // 私有 IPv4 + {"10.x 私有地址", "10.0.0.1", true}, + {"10.x 私有地址段末", "10.255.255.255", true}, + {"172.16.x 私有地址", "172.16.0.1", true}, + {"172.31.x 私有地址", "172.31.255.255", true}, + {"192.168.x 私有地址", "192.168.1.1", true}, + {"127.0.0.1 本地回环", "127.0.0.1", true}, + {"127.x 回环段", "127.255.255.255", true}, + + // 公网 IPv4 + {"8.8.8.8 公网 DNS", "8.8.8.8", false}, + {"1.1.1.1 公网", "1.1.1.1", false}, + {"172.15.255.255 非私有", "172.15.255.255", false}, + {"172.32.0.0 非私有", "172.32.0.0", false}, + {"11.0.0.1 公网", "11.0.0.1", false}, + + // IPv6 + {"::1 IPv6 回环", "::1", true}, + {"fc00:: IPv6 私有", "fc00::1", true}, + {"fd00:: IPv6 私有", "fd00::1", true}, + {"2001:db8::1 IPv6 公网", "2001:db8::1", false}, + + // 无效输入 + {"空字符串", "", false}, + {"非法字符串", "not-an-ip", false}, + {"不完整 IP", "192.168", false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := isPrivateIP(tc.ip) + require.Equal(t, tc.expected, got, "isPrivateIP(%q)", tc.ip) + }) + } +} + +func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + require.NoError(t, r.SetTrustedProxies(nil)) + + r.GET("/t", func(c *gin.Context) { + c.String(200, GetTrustedClientIP(c)) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + r.ServeHTTP(w, req) + + require.Equal(t, 200, w.Code) + require.Equal(t, "9.9.9.9", w.Body.String()) +} diff --git a/backend/internal/pkg/logger/config_adapter.go 
b/backend/internal/pkg/logger/config_adapter.go new file mode 100644 index 00000000..c34e448b --- /dev/null +++ b/backend/internal/pkg/logger/config_adapter.go @@ -0,0 +1,31 @@ +package logger + +import "github.com/Wei-Shaw/sub2api/internal/config" + +func OptionsFromConfig(cfg config.LogConfig) InitOptions { + return InitOptions{ + Level: cfg.Level, + Format: cfg.Format, + ServiceName: cfg.ServiceName, + Environment: cfg.Environment, + Caller: cfg.Caller, + StacktraceLevel: cfg.StacktraceLevel, + Output: OutputOptions{ + ToStdout: cfg.Output.ToStdout, + ToFile: cfg.Output.ToFile, + FilePath: cfg.Output.FilePath, + }, + Rotation: RotationOptions{ + MaxSizeMB: cfg.Rotation.MaxSizeMB, + MaxBackups: cfg.Rotation.MaxBackups, + MaxAgeDays: cfg.Rotation.MaxAgeDays, + Compress: cfg.Rotation.Compress, + LocalTime: cfg.Rotation.LocalTime, + }, + Sampling: SamplingOptions{ + Enabled: cfg.Sampling.Enabled, + Initial: cfg.Sampling.Initial, + Thereafter: cfg.Sampling.Thereafter, + }, + } +} diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go new file mode 100644 index 00000000..80d92517 --- /dev/null +++ b/backend/internal/pkg/logger/logger.go @@ -0,0 +1,519 @@ +package logger + +import ( + "context" + "fmt" + "io" + "log" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "gopkg.in/natefinch/lumberjack.v2" +) + +type Level = zapcore.Level + +const ( + LevelDebug = zapcore.DebugLevel + LevelInfo = zapcore.InfoLevel + LevelWarn = zapcore.WarnLevel + LevelError = zapcore.ErrorLevel + LevelFatal = zapcore.FatalLevel +) + +type Sink interface { + WriteLogEvent(event *LogEvent) +} + +type LogEvent struct { + Time time.Time + Level string + Component string + Message string + LoggerName string + Fields map[string]any +} + +var ( + mu sync.RWMutex + global *zap.Logger + sugar *zap.SugaredLogger + atomicLevel zap.AtomicLevel + initOptions InitOptions + currentSink Sink + 
stdLogUndo func() + bootstrapOnce sync.Once +) + +func InitBootstrap() { + bootstrapOnce.Do(func() { + if err := Init(bootstrapOptions()); err != nil { + _, _ = fmt.Fprintf(os.Stderr, "logger bootstrap init failed: %v\n", err) + } + }) +} + +func Init(options InitOptions) error { + mu.Lock() + defer mu.Unlock() + return initLocked(options) +} + +func initLocked(options InitOptions) error { + normalized := options.normalized() + zl, al, err := buildLogger(normalized) + if err != nil { + return err + } + + prev := global + global = zl + sugar = zl.Sugar() + atomicLevel = al + initOptions = normalized + + bridgeSlogLocked() + bridgeStdLogLocked() + + if prev != nil { + _ = prev.Sync() + } + return nil +} + +func Reconfigure(mutator func(*InitOptions) error) error { + mu.Lock() + defer mu.Unlock() + next := initOptions + if mutator != nil { + if err := mutator(&next); err != nil { + return err + } + } + return initLocked(next) +} + +func SetLevel(level string) error { + lv, ok := parseLevel(level) + if !ok { + return fmt.Errorf("invalid log level: %s", level) + } + + mu.Lock() + defer mu.Unlock() + atomicLevel.SetLevel(lv) + initOptions.Level = strings.ToLower(strings.TrimSpace(level)) + return nil +} + +func CurrentLevel() string { + mu.RLock() + defer mu.RUnlock() + if global == nil { + return "info" + } + return atomicLevel.Level().String() +} + +func SetSink(sink Sink) { + mu.Lock() + defer mu.Unlock() + currentSink = sink +} + +// WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 +// 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 +func WriteSinkEvent(level, component, message string, fields map[string]any) { + mu.RLock() + sink := currentSink + mu.RUnlock() + if sink == nil { + return + } + + level = strings.ToLower(strings.TrimSpace(level)) + if level == "" { + level = "info" + } + component = strings.TrimSpace(component) + message = strings.TrimSpace(message) + if message == "" { + return + } + + eventFields := make(map[string]any, len(fields)+1) + for k, v := range fields { 
+ eventFields[k] = v + } + if component != "" { + if _, ok := eventFields["component"]; !ok { + eventFields["component"] = component + } + } + + sink.WriteLogEvent(&LogEvent{ + Time: time.Now(), + Level: level, + Component: component, + Message: message, + LoggerName: component, + Fields: eventFields, + }) +} + +func L() *zap.Logger { + mu.RLock() + defer mu.RUnlock() + if global != nil { + return global + } + return zap.NewNop() +} + +func S() *zap.SugaredLogger { + mu.RLock() + defer mu.RUnlock() + if sugar != nil { + return sugar + } + return zap.NewNop().Sugar() +} + +func With(fields ...zap.Field) *zap.Logger { + return L().With(fields...) +} + +func Sync() { + mu.RLock() + l := global + mu.RUnlock() + if l != nil { + _ = l.Sync() + } +} + +func bridgeStdLogLocked() { + if stdLogUndo != nil { + stdLogUndo() + stdLogUndo = nil + } + + prevFlags := log.Flags() + prevPrefix := log.Prefix() + prevWriter := log.Writer() + + log.SetFlags(0) + log.SetPrefix("") + log.SetOutput(newStdLogBridge(global.Named("stdlog"))) + + stdLogUndo = func() { + log.SetOutput(prevWriter) + log.SetFlags(prevFlags) + log.SetPrefix(prevPrefix) + } +} + +func bridgeSlogLocked() { + slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog")))) +} + +func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { + level, _ := parseLevel(options.Level) + atomic := zap.NewAtomicLevelAt(level) + + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + NameKey: "logger", + CallerKey: "caller", + MessageKey: "msg", + StacktraceKey: "stacktrace", + LineEnding: zapcore.DefaultLineEnding, + EncodeLevel: zapcore.CapitalLevelEncoder, + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeDuration: zapcore.MillisDurationEncoder, + EncodeCaller: zapcore.ShortCallerEncoder, + } + + var enc zapcore.Encoder + if options.Format == "console" { + enc = zapcore.NewConsoleEncoder(encoderCfg) + } else { + enc = zapcore.NewJSONEncoder(encoderCfg) + } + + sinkCore := 
newSinkCore() + cores := make([]zapcore.Core, 0, 3) + + if options.Output.ToStdout { + infoPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl < zapcore.WarnLevel + }) + errPriority := zap.LevelEnablerFunc(func(lvl zapcore.Level) bool { + return lvl >= atomic.Level() && lvl >= zapcore.WarnLevel + }) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), infoPriority)) + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stderr), errPriority)) + } + + if options.Output.ToFile { + fileCore, filePath, fileErr := buildFileCore(enc, atomic, options) + if fileErr != nil { + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"日志文件输出初始化失败,降级为仅标准输出\" path=%s err=%v\n", + time.Now().Format(time.RFC3339Nano), + filePath, + fileErr, + ) + } else { + cores = append(cores, fileCore) + } + } + + if len(cores) == 0 { + cores = append(cores, zapcore.NewCore(enc, zapcore.Lock(os.Stdout), atomic)) + } + + core := zapcore.NewTee(cores...) 
+ if options.Sampling.Enabled { + core = zapcore.NewSamplerWithOptions(core, samplingTick(), options.Sampling.Initial, options.Sampling.Thereafter) + } + core = sinkCore.Wrap(core) + + stacktraceLevel, _ := parseStacktraceLevel(options.StacktraceLevel) + zapOpts := make([]zap.Option, 0, 5) + if options.Caller { + zapOpts = append(zapOpts, zap.AddCaller()) + } + if stacktraceLevel <= zapcore.FatalLevel { + zapOpts = append(zapOpts, zap.AddStacktrace(stacktraceLevel)) + } + + logger := zap.New(core, zapOpts...).With( + zap.String("service", options.ServiceName), + zap.String("env", options.Environment), + ) + return logger, atomic, nil +} + +func buildFileCore(enc zapcore.Encoder, atomic zap.AtomicLevel, options InitOptions) (zapcore.Core, string, error) { + filePath := options.Output.FilePath + if strings.TrimSpace(filePath) == "" { + filePath = resolveLogFilePath("") + } + + dir := filepath.Dir(filePath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return nil, filePath, err + } + lj := &lumberjack.Logger{ + Filename: filePath, + MaxSize: options.Rotation.MaxSizeMB, + MaxBackups: options.Rotation.MaxBackups, + MaxAge: options.Rotation.MaxAgeDays, + Compress: options.Rotation.Compress, + LocalTime: options.Rotation.LocalTime, + } + return zapcore.NewCore(enc, zapcore.AddSync(lj), atomic), filePath, nil +} + +type sinkCore struct { + core zapcore.Core + fields []zapcore.Field +} + +func newSinkCore() *sinkCore { + return &sinkCore{} +} + +func (s *sinkCore) Wrap(core zapcore.Core) zapcore.Core { + cp := *s + cp.core = core + return &cp +} + +func (s *sinkCore) Enabled(level zapcore.Level) bool { + return s.core.Enabled(level) +} + +func (s *sinkCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := append([]zapcore.Field{}, s.fields...) + nextFields = append(nextFields, fields...) 
+ return &sinkCore{ + core: s.core.With(fields), + fields: nextFields, + } +} + +func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + // Delegate to inner core (tee) so each sub-core's level enabler is respected. + // Then add ourselves for sink forwarding only. + ce = s.core.Check(entry, ce) + if ce != nil { + ce = ce.AddCore(entry, s) + } + return ce +} + +func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + // Only handle sink forwarding — the inner cores write via their own + // Write methods (added to CheckedEntry by s.core.Check above). + mu.RLock() + sink := currentSink + mu.RUnlock() + if sink == nil { + return nil + } + + enc := zapcore.NewMapObjectEncoder() + for _, f := range s.fields { + f.AddTo(enc) + } + for _, f := range fields { + f.AddTo(enc) + } + + event := &LogEvent{ + Time: entry.Time, + Level: strings.ToLower(entry.Level.String()), + Component: entry.LoggerName, + Message: entry.Message, + LoggerName: entry.LoggerName, + Fields: enc.Fields, + } + sink.WriteLogEvent(event) + return nil +} + +func (s *sinkCore) Sync() error { + return s.core.Sync() +} + +type stdLogBridge struct { + logger *zap.Logger +} + +func newStdLogBridge(l *zap.Logger) io.Writer { + if l == nil { + l = zap.NewNop() + } + return &stdLogBridge{logger: l} +} + +func (b *stdLogBridge) Write(p []byte) (int, error) { + msg := normalizeStdLogMessage(string(p)) + if msg == "" { + return len(p), nil + } + + level := inferStdLogLevel(msg) + entry := b.logger.WithOptions(zap.AddCallerSkip(4)) + + switch level { + case LevelDebug: + entry.Debug(msg, zap.Bool("legacy_stdlog", true)) + case LevelWarn: + entry.Warn(msg, zap.Bool("legacy_stdlog", true)) + case LevelError, LevelFatal: + entry.Error(msg, zap.Bool("legacy_stdlog", true)) + default: + entry.Info(msg, zap.Bool("legacy_stdlog", true)) + } + return len(p), nil +} + +func normalizeStdLogMessage(raw string) string { + msg := 
strings.TrimSpace(strings.ReplaceAll(raw, "\n", " ")) + if msg == "" { + return "" + } + return strings.Join(strings.Fields(msg), " ") +} + +func inferStdLogLevel(msg string) Level { + lower := strings.ToLower(strings.TrimSpace(msg)) + if lower == "" { + return LevelInfo + } + + if strings.HasPrefix(lower, "[debug]") || strings.HasPrefix(lower, "debug:") { + return LevelDebug + } + if strings.HasPrefix(lower, "[warn]") || strings.HasPrefix(lower, "[warning]") || strings.HasPrefix(lower, "warn:") || strings.HasPrefix(lower, "warning:") { + return LevelWarn + } + if strings.HasPrefix(lower, "[error]") || strings.HasPrefix(lower, "error:") || strings.HasPrefix(lower, "fatal:") || strings.HasPrefix(lower, "panic:") { + return LevelError + } + + if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { + return LevelError + } + if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { + return LevelWarn + } + return LevelInfo +} + +// LegacyPrintf 用于平滑迁移历史的 printf 风格日志到结构化 logger。 +func LegacyPrintf(component, format string, args ...any) { + msg := normalizeStdLogMessage(fmt.Sprintf(format, args...)) + if msg == "" { + return + } + + mu.RLock() + initialized := global != nil + mu.RUnlock() + if !initialized { + // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 + log.Print(msg) + return + } + + l := L() + if component != "" { + l = l.With(zap.String("component", component)) + } + l = l.WithOptions(zap.AddCallerSkip(1)) + + switch inferStdLogLevel(msg) { + case LevelDebug: + l.Debug(msg, zap.Bool("legacy_printf", true)) + case LevelWarn: + l.Warn(msg, zap.Bool("legacy_printf", true)) + case LevelError, LevelFatal: + l.Error(msg, zap.Bool("legacy_printf", true)) + default: + l.Info(msg, zap.Bool("legacy_printf", true)) + } +} + +type contextKey string + +const 
loggerContextKey contextKey = "ctx_logger" + +func IntoContext(ctx context.Context, l *zap.Logger) context.Context { + if ctx == nil { + ctx = context.Background() + } + if l == nil { + l = L() + } + return context.WithValue(ctx, loggerContextKey, l) +} + +func FromContext(ctx context.Context) *zap.Logger { + if ctx == nil { + return L() + } + if l, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok && l != nil { + return l + } + return L() +} diff --git a/backend/internal/pkg/logger/logger_test.go b/backend/internal/pkg/logger/logger_test.go new file mode 100644 index 00000000..74aae061 --- /dev/null +++ b/backend/internal/pkg/logger/logger_test.go @@ -0,0 +1,192 @@ +package logger + +import ( + "encoding/json" + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestInit_DualOutput(t *testing.T) { + tmpDir := t.TempDir() + logPath := filepath.Join(tmpDir, "logs", "sub2api.log") + + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stderrR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: logPath, + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 2, + MaxAgeDays: 1, + }, + Sampling: SamplingOptions{Enabled: false}, + }) + if err != nil { + t.Fatalf("Init() error: %v", err) + } + + L().Info("dual-output-info") + L().Warn("dual-output-warn") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + 
stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "dual-output-info") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "dual-output-warn") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + + fileBytes, err := os.ReadFile(logPath) + if err != nil { + t.Fatalf("read log file: %v", err) + } + fileText := string(fileBytes) + if !strings.Contains(fileText, "dual-output-info") || !strings.Contains(fileText, "dual-output-warn") { + t.Fatalf("file missing logs: %s", fileText) + } +} + +func TestInit_FileOutputFailureDowngrade(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + _, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + err = Init(InitOptions{ + Level: "info", + Format: "json", + Output: OutputOptions{ + ToStdout: true, + ToFile: true, + FilePath: filepath.Join(os.DevNull, "logs", "sub2api.log"), + }, + Rotation: RotationOptions{ + MaxSizeMB: 10, + MaxBackups: 1, + MaxAgeDays: 1, + }, + }) + if err != nil { + t.Fatalf("Init() should downgrade instead of failing, got: %v", err) + } + + _ = stderrW.Close() + stderrBytes, _ := io.ReadAll(stderrR) + if !strings.Contains(string(stderrBytes), "日志文件输出初始化失败") { + t.Fatalf("stderr should contain fallback warning, got: %s", string(stderrBytes)) + } +} + +func TestInit_CallerShouldPointToCallsite(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + _, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + 
t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Caller: true, + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + L().Info("caller-check") + Sync() + _ = stdoutW.Close() + logBytes, _ := io.ReadAll(stdoutR) + + var line string + for _, item := range strings.Split(string(logBytes), "\n") { + if strings.Contains(item, "caller-check") { + line = item + break + } + } + if line == "" { + t.Fatalf("log output missing caller-check: %s", string(logBytes)) + } + + var payload map[string]any + if err := json.Unmarshal([]byte(line), &payload); err != nil { + t.Fatalf("parse log json failed: %v, line=%s", err, line) + } + caller, _ := payload["caller"].(string) + if !strings.Contains(caller, "logger_test.go:") { + t.Fatalf("caller should point to this test file, got: %s", caller) + } +} diff --git a/backend/internal/pkg/logger/options.go b/backend/internal/pkg/logger/options.go new file mode 100644 index 00000000..efcd701c --- /dev/null +++ b/backend/internal/pkg/logger/options.go @@ -0,0 +1,161 @@ +package logger + +import ( + "os" + "path/filepath" + "strings" + "time" +) + +const ( + // DefaultContainerLogPath 为容器内默认日志文件路径。 + DefaultContainerLogPath = "/app/data/logs/sub2api.log" + defaultLogFilename = "sub2api.log" +) + +type InitOptions struct { + Level string + Format string + ServiceName string + Environment string + Caller bool + StacktraceLevel string + Output OutputOptions + Rotation RotationOptions + Sampling SamplingOptions +} + +type OutputOptions struct { + ToStdout bool + ToFile bool + FilePath string +} + +type RotationOptions struct { + MaxSizeMB int + MaxBackups int + MaxAgeDays int + Compress bool + LocalTime bool +} + +type 
SamplingOptions struct { + Enabled bool + Initial int + Thereafter int +} + +func (o InitOptions) normalized() InitOptions { + out := o + out.Level = strings.ToLower(strings.TrimSpace(out.Level)) + if out.Level == "" { + out.Level = "info" + } + out.Format = strings.ToLower(strings.TrimSpace(out.Format)) + if out.Format == "" { + out.Format = "console" + } + out.ServiceName = strings.TrimSpace(out.ServiceName) + if out.ServiceName == "" { + out.ServiceName = "sub2api" + } + out.Environment = strings.TrimSpace(out.Environment) + if out.Environment == "" { + out.Environment = "production" + } + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(out.StacktraceLevel)) + if out.StacktraceLevel == "" { + out.StacktraceLevel = "error" + } + if !out.Output.ToStdout && !out.Output.ToFile { + out.Output.ToStdout = true + } + out.Output.FilePath = resolveLogFilePath(out.Output.FilePath) + if out.Rotation.MaxSizeMB <= 0 { + out.Rotation.MaxSizeMB = 100 + } + if out.Rotation.MaxBackups < 0 { + out.Rotation.MaxBackups = 10 + } + if out.Rotation.MaxAgeDays < 0 { + out.Rotation.MaxAgeDays = 7 + } + if out.Sampling.Enabled { + if out.Sampling.Initial <= 0 { + out.Sampling.Initial = 100 + } + if out.Sampling.Thereafter <= 0 { + out.Sampling.Thereafter = 100 + } + } + return out +} + +func resolveLogFilePath(explicit string) string { + explicit = strings.TrimSpace(explicit) + if explicit != "" { + return explicit + } + dataDir := strings.TrimSpace(os.Getenv("DATA_DIR")) + if dataDir != "" { + return filepath.Join(dataDir, "logs", defaultLogFilename) + } + return DefaultContainerLogPath +} + +func bootstrapOptions() InitOptions { + return InitOptions{ + Level: "info", + Format: "console", + ServiceName: "sub2api", + Environment: "bootstrap", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 100, + MaxBackups: 10, + MaxAgeDays: 7, + Compress: true, + LocalTime: true, + }, + Sampling: SamplingOptions{ + Enabled: false, + 
Initial: 100, + Thereafter: 100, + }, + } +} + +func parseLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "debug": + return LevelDebug, true + case "info": + return LevelInfo, true + case "warn": + return LevelWarn, true + case "error": + return LevelError, true + default: + return LevelInfo, false + } +} + +func parseStacktraceLevel(level string) (Level, bool) { + switch strings.ToLower(strings.TrimSpace(level)) { + case "none": + return LevelFatal + 1, true + case "error": + return LevelError, true + case "fatal": + return LevelFatal, true + default: + return LevelError, false + } +} + +func samplingTick() time.Duration { + return time.Second +} diff --git a/backend/internal/pkg/logger/options_test.go b/backend/internal/pkg/logger/options_test.go new file mode 100644 index 00000000..10d50d72 --- /dev/null +++ b/backend/internal/pkg/logger/options_test.go @@ -0,0 +1,102 @@ +package logger + +import ( + "os" + "path/filepath" + "testing" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestResolveLogFilePath_Default(t *testing.T) { + t.Setenv("DATA_DIR", "") + got := resolveLogFilePath("") + if got != DefaultContainerLogPath { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, DefaultContainerLogPath) + } +} + +func TestResolveLogFilePath_WithDataDir(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/sub2api-data") + got := resolveLogFilePath("") + want := filepath.Join("/tmp/sub2api-data", "logs", "sub2api.log") + if got != want { + t.Fatalf("resolveLogFilePath() = %q, want %q", got, want) + } +} + +func TestResolveLogFilePath_ExplicitPath(t *testing.T) { + t.Setenv("DATA_DIR", "/tmp/ignore") + got := resolveLogFilePath("/var/log/custom.log") + if got != "/var/log/custom.log" { + t.Fatalf("resolveLogFilePath() = %q, want explicit path", got) + } +} + +func TestNormalizedOptions_InvalidFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := InitOptions{ + Level: "TRACE", + Format: "TEXT", + ServiceName: 
"", + Environment: "", + StacktraceLevel: "panic", + Output: OutputOptions{ + ToStdout: false, + ToFile: false, + }, + Rotation: RotationOptions{ + MaxSizeMB: 0, + MaxBackups: -1, + MaxAgeDays: -1, + }, + Sampling: SamplingOptions{ + Enabled: true, + Initial: 0, + Thereafter: 0, + }, + } + out := opts.normalized() + if out.Level != "trace" { + // normalized 仅做 trim/lower,不做校验;校验在 config 层。 + t.Fatalf("normalized level should preserve value for upstream validation, got %q", out.Level) + } + if !out.Output.ToStdout { + t.Fatalf("normalized output should fallback to stdout") + } + if out.Output.FilePath != DefaultContainerLogPath { + t.Fatalf("normalized file path = %q", out.Output.FilePath) + } + if out.Rotation.MaxSizeMB != 100 { + t.Fatalf("normalized max_size_mb = %d", out.Rotation.MaxSizeMB) + } + if out.Rotation.MaxBackups != 10 { + t.Fatalf("normalized max_backups = %d", out.Rotation.MaxBackups) + } + if out.Rotation.MaxAgeDays != 7 { + t.Fatalf("normalized max_age_days = %d", out.Rotation.MaxAgeDays) + } + if out.Sampling.Initial != 100 || out.Sampling.Thereafter != 100 { + t.Fatalf("normalized sampling defaults invalid: %+v", out.Sampling) + } +} + +func TestBuildFileCore_InvalidPathFallback(t *testing.T) { + t.Setenv("DATA_DIR", "") + opts := bootstrapOptions() + opts.Output.ToFile = true + opts.Output.FilePath = filepath.Join(os.DevNull, "logs", "sub2api.log") + encoderCfg := zapcore.EncoderConfig{ + TimeKey: "time", + LevelKey: "level", + MessageKey: "msg", + EncodeTime: zapcore.ISO8601TimeEncoder, + EncodeLevel: zapcore.CapitalLevelEncoder, + } + encoder := zapcore.NewJSONEncoder(encoderCfg) + _, _, err := buildFileCore(encoder, zap.NewAtomicLevel(), opts) + if err == nil { + t.Fatalf("buildFileCore() expected error for invalid path") + } +} diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go new file mode 100644 index 00000000..562b8341 --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler.go @@ 
-0,0 +1,132 @@ +package logger + +import ( + "context" + "log/slog" + "strings" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type slogZapHandler struct { + logger *zap.Logger + attrs []slog.Attr + groups []string +} + +func newSlogZapHandler(logger *zap.Logger) slog.Handler { + if logger == nil { + logger = zap.NewNop() + } + return &slogZapHandler{ + logger: logger, + attrs: make([]slog.Attr, 0, 8), + groups: make([]string, 0, 4), + } +} + +func (h *slogZapHandler) Enabled(_ context.Context, level slog.Level) bool { + switch { + case level >= slog.LevelError: + return h.logger.Core().Enabled(LevelError) + case level >= slog.LevelWarn: + return h.logger.Core().Enabled(LevelWarn) + case level <= slog.LevelDebug: + return h.logger.Core().Enabled(LevelDebug) + default: + return h.logger.Core().Enabled(LevelInfo) + } +} + +func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { + fields := make([]zap.Field, 0, len(h.attrs)+record.NumAttrs()+3) + fields = append(fields, slogAttrsToZapFields(h.groups, h.attrs)...) + record.Attrs(func(attr slog.Attr) bool { + fields = append(fields, slogAttrToZapField(h.groups, attr)) + return true + }) + + entry := h.logger.With(fields...) + switch { + case record.Level >= slog.LevelError: + entry.Error(record.Message) + case record.Level >= slog.LevelWarn: + entry.Warn(record.Message) + case record.Level <= slog.LevelDebug: + entry.Debug(record.Message) + default: + entry.Info(record.Message) + } + return nil +} + +func (h *slogZapHandler) WithAttrs(attrs []slog.Attr) slog.Handler { + next := *h + next.attrs = append(append([]slog.Attr{}, h.attrs...), attrs...) 
+ return &next +} + +func (h *slogZapHandler) WithGroup(name string) slog.Handler { + name = strings.TrimSpace(name) + if name == "" { + return h + } + next := *h + next.groups = append(append([]string{}, h.groups...), name) + return &next +} + +func slogAttrsToZapFields(groups []string, attrs []slog.Attr) []zap.Field { + fields := make([]zap.Field, 0, len(attrs)) + for _, attr := range attrs { + fields = append(fields, slogAttrToZapField(groups, attr)) + } + return fields +} + +func slogAttrToZapField(groups []string, attr slog.Attr) zap.Field { + if len(groups) > 0 { + attr.Key = strings.Join(append(append([]string{}, groups...), attr.Key), ".") + } + value := attr.Value.Resolve() + switch value.Kind() { + case slog.KindBool: + return zap.Bool(attr.Key, value.Bool()) + case slog.KindInt64: + return zap.Int64(attr.Key, value.Int64()) + case slog.KindUint64: + return zap.Uint64(attr.Key, value.Uint64()) + case slog.KindFloat64: + return zap.Float64(attr.Key, value.Float64()) + case slog.KindDuration: + return zap.Duration(attr.Key, value.Duration()) + case slog.KindTime: + return zap.Time(attr.Key, value.Time()) + case slog.KindString: + return zap.String(attr.Key, value.String()) + case slog.KindGroup: + groupFields := make([]zap.Field, 0, len(value.Group())) + for _, nested := range value.Group() { + groupFields = append(groupFields, slogAttrToZapField(nil, nested)) + } + return zap.Object(attr.Key, zapObjectFields(groupFields)) + case slog.KindAny: + if t, ok := value.Any().(time.Time); ok { + return zap.Time(attr.Key, t) + } + return zap.Any(attr.Key, value.Any()) + default: + return zap.String(attr.Key, value.String()) + } +} + +type zapObjectFields []zap.Field + +func (z zapObjectFields) MarshalLogObject(enc zapcore.ObjectEncoder) error { + for _, field := range z { + field.AddTo(enc) + } + return nil +} diff --git a/backend/internal/pkg/logger/slog_handler_test.go b/backend/internal/pkg/logger/slog_handler_test.go new file mode 100644 index 
00000000..d2b4208d --- /dev/null +++ b/backend/internal/pkg/logger/slog_handler_test.go @@ -0,0 +1,88 @@ +package logger + +import ( + "context" + "log/slog" + "testing" + "time" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +type captureState struct { + writes []capturedWrite +} + +type capturedWrite struct { + fields []zapcore.Field +} + +type captureCore struct { + state *captureState + withFields []zapcore.Field +} + +func newCaptureCore() *captureCore { + return &captureCore{state: &captureState{}} +} + +func (c *captureCore) Enabled(zapcore.Level) bool { + return true +} + +func (c *captureCore) With(fields []zapcore.Field) zapcore.Core { + nextFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + nextFields = append(nextFields, c.withFields...) + nextFields = append(nextFields, fields...) + return &captureCore{ + state: c.state, + withFields: nextFields, + } +} + +func (c *captureCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore.CheckedEntry { + return ce.AddCore(entry, c) +} + +func (c *captureCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { + allFields := make([]zapcore.Field, 0, len(c.withFields)+len(fields)) + allFields = append(allFields, c.withFields...) + allFields = append(allFields, fields...) 
+ c.state.writes = append(c.state.writes, capturedWrite{ + fields: allFields, + }) + return nil +} + +func (c *captureCore) Sync() error { + return nil +} + +func TestSlogZapHandler_Handle_DoesNotAppendTimeField(t *testing.T) { + core := newCaptureCore() + handler := newSlogZapHandler(zap.New(core)) + + record := slog.NewRecord(time.Date(2026, 1, 1, 12, 0, 0, 0, time.UTC), slog.LevelInfo, "hello", 0) + record.AddAttrs(slog.String("component", "http.access")) + + if err := handler.Handle(context.Background(), record); err != nil { + t.Fatalf("handle slog record: %v", err) + } + if len(core.state.writes) != 1 { + t.Fatalf("write calls = %d, want 1", len(core.state.writes)) + } + + var hasComponent bool + for _, field := range core.state.writes[0].fields { + if field.Key == "time" { + t.Fatalf("unexpected duplicate time field in slog adapter output") + } + if field.Key == "component" { + hasComponent = true + } + } + if !hasComponent { + t.Fatalf("component field should be preserved") + } +} diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go new file mode 100644 index 00000000..a3f76fd7 --- /dev/null +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -0,0 +1,165 @@ +package logger + +import ( + "io" + "log" + "os" + "strings" + "testing" +) + +func TestInferStdLogLevel(t *testing.T) { + cases := []struct { + msg string + want Level + }{ + {msg: "Warning: queue full", want: LevelWarn}, + {msg: "Forward request failed: timeout", want: LevelError}, + {msg: "[ERROR] upstream unavailable", want: LevelError}, + {msg: "service started", want: LevelInfo}, + {msg: "debug: cache miss", want: LevelDebug}, + } + + for _, tc := range cases { + got := inferStdLogLevel(tc.msg) + if got != tc.want { + t.Fatalf("inferStdLogLevel(%q)=%v want=%v", tc.msg, got, tc.want) + } + } +} + +func TestNormalizeStdLogMessage(t *testing.T) { + raw := " [TokenRefresh] cycle complete \n total=1 failed=0 \n" + got := 
normalizeStdLogMessage(raw) + want := "[TokenRefresh] cycle complete total=1 failed=0" + if got != want { + t.Fatalf("normalizeStdLogMessage()=%q want=%q", got, want) + } +} + +func TestStdLogBridgeRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + log.Printf("service started") + log.Printf("Warning: queue full") + log.Printf("Forward request failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "service started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "Forward request failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_stdlog\":true") { + t.Fatalf("stderr missing legacy_stdlog marker: %s", stderrText) + } +} + +func TestLegacyPrintfRoutesLevels(t *testing.T) { + origStdout := os.Stdout + origStderr := os.Stderr + stdoutR, stdoutW, err := os.Pipe() + if err != nil { + t.Fatalf("create stdout pipe: %v", 
err) + } + stderrR, stderrW, err := os.Pipe() + if err != nil { + t.Fatalf("create stderr pipe: %v", err) + } + os.Stdout = stdoutW + os.Stderr = stderrW + t.Cleanup(func() { + os.Stdout = origStdout + os.Stderr = origStderr + _ = stdoutR.Close() + _ = stdoutW.Close() + _ = stderrR.Close() + _ = stderrW.Close() + }) + + if err := Init(InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: SamplingOptions{Enabled: false}, + }); err != nil { + t.Fatalf("Init() error: %v", err) + } + + LegacyPrintf("service.test", "request started") + LegacyPrintf("service.test", "Warning: queue full") + LegacyPrintf("service.test", "forward failed: timeout") + Sync() + + _ = stdoutW.Close() + _ = stderrW.Close() + stdoutBytes, _ := io.ReadAll(stdoutR) + stderrBytes, _ := io.ReadAll(stderrR) + stdoutText := string(stdoutBytes) + stderrText := string(stderrBytes) + + if !strings.Contains(stdoutText, "request started") { + t.Fatalf("stdout missing info log: %s", stdoutText) + } + if !strings.Contains(stderrText, "Warning: queue full") { + t.Fatalf("stderr missing warn log: %s", stderrText) + } + if !strings.Contains(stderrText, "forward failed: timeout") { + t.Fatalf("stderr missing error log: %s", stderrText) + } + if !strings.Contains(stderrText, "\"legacy_printf\":true") { + t.Fatalf("stderr missing legacy_printf marker: %s", stderrText) + } + if !strings.Contains(stderrText, "\"component\":\"service.test\"") { + t.Fatalf("stderr missing component field: %s", stderrText) + } +} diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index 33caffd7..cfc91bee 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -50,6 +50,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -65,7 +66,9 @@ func 
NewSessionStore() *SessionStore { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // Set stores a session diff --git a/backend/internal/pkg/oauth/oauth_test.go b/backend/internal/pkg/oauth/oauth_test.go new file mode 100644 index 00000000..9e59f0f0 --- /dev/null +++ b/backend/internal/pkg/oauth/oauth_test.go @@ -0,0 +1,43 @@ +package oauth + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/openai/constants.go b/backend/internal/pkg/openai/constants.go index fd24b11d..4bbc68e7 100644 --- a/backend/internal/pkg/openai/constants.go +++ b/backend/internal/pkg/openai/constants.go @@ -15,8 +15,8 @@ type Model struct { // DefaultModels OpenAI models list var DefaultModels = []Model{ - {ID: "gpt-5.3", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3"}, {ID: "gpt-5.3-codex", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex"}, + {ID: "gpt-5.3-codex-spark", Object: "model", Created: 1735689600, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.3 Codex Spark"}, {ID: "gpt-5.2", Object: "model", Created: 1733875200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2"}, {ID: "gpt-5.2-codex", Object: "model", Created: 1733011200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.2 Codex"}, {ID: "gpt-5.1-codex-max", Object: 
"model", Created: 1730419200, OwnedBy: "openai", Type: "model", DisplayName: "GPT-5.1 Codex Max"}, diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index df972a13..e3b931be 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -17,6 +17,8 @@ import ( const ( // OAuth Client ID for OpenAI (Codex CLI official) ClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + // OAuth Client ID for Sora mobile flow (aligned with sora2api) + SoraClientID = "app_LlGpXReQgckcGGUo2JrYvtJK" // OAuth endpoints AuthorizeURL = "https://auth.openai.com/oauth/authorize" @@ -47,6 +49,7 @@ type OAuthSession struct { type SessionStore struct { mu sync.RWMutex sessions map[string]*OAuthSession + stopOnce sync.Once stopCh chan struct{} } @@ -92,7 +95,9 @@ func (s *SessionStore) Delete(sessionID string) { // Stop stops the cleanup goroutine func (s *SessionStore) Stop() { - close(s.stopCh) + s.stopOnce.Do(func() { + close(s.stopCh) + }) } // cleanup removes expired sessions periodically diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go new file mode 100644 index 00000000..f1d616a6 --- /dev/null +++ b/backend/internal/pkg/openai/oauth_test.go @@ -0,0 +1,43 @@ +package openai + +import ( + "sync" + "testing" + "time" +) + +func TestSessionStore_Stop_Idempotent(t *testing.T) { + store := NewSessionStore() + + store.Stop() + store.Stop() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} + +func TestSessionStore_Stop_Concurrent(t *testing.T) { + store := NewSessionStore() + + var wg sync.WaitGroup + for range 50 { + wg.Add(1) + go func() { + defer wg.Done() + store.Stop() + }() + } + + wg.Wait() + + select { + case <-store.stopCh: + // ok + case <-time.After(time.Second): + t.Fatal("stopCh 未关闭") + } +} diff --git a/backend/internal/pkg/openai/request.go b/backend/internal/pkg/openai/request.go index 5b049ddc..c24d1273 
100644 --- a/backend/internal/pkg/openai/request.go +++ b/backend/internal/pkg/openai/request.go @@ -1,5 +1,7 @@ package openai +import "strings" + // CodexCLIUserAgentPrefixes matches Codex CLI User-Agent patterns // Examples: "codex_vscode/1.0.0", "codex_cli_rs/0.1.2" var CodexCLIUserAgentPrefixes = []string{ @@ -7,10 +9,67 @@ var CodexCLIUserAgentPrefixes = []string{ "codex_cli_rs/", } +// CodexOfficialClientUserAgentPrefixes matches Codex 官方客户端家族 User-Agent 前缀。 +// 该列表仅用于 OpenAI OAuth `codex_cli_only` 访问限制判定。 +var CodexOfficialClientUserAgentPrefixes = []string{ + "codex_cli_rs/", + "codex_vscode/", + "codex_app/", + "codex_chatgpt_desktop/", + "codex_atlas/", + "codex_exec/", + "codex_sdk_ts/", + "codex ", +} + +// CodexOfficialClientOriginatorPrefixes matches Codex 官方客户端家族 originator 前缀。 +// 说明:OpenAI 官方 Codex 客户端并不只使用固定的 codex_app 标识。 +// 例如 codex_cli_rs、codex_vscode、codex_chatgpt_desktop、codex_atlas、codex_exec、codex_sdk_ts 等。 +var CodexOfficialClientOriginatorPrefixes = []string{ + "codex_", + "codex ", +} + // IsCodexCLIRequest checks if the User-Agent indicates a Codex CLI request func IsCodexCLIRequest(userAgent string) bool { - for _, prefix := range CodexCLIUserAgentPrefixes { - if len(userAgent) >= len(prefix) && userAgent[:len(prefix)] == prefix { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexCLIUserAgentPrefixes) +} + +// IsCodexOfficialClientRequest checks if the User-Agent indicates a Codex 官方客户端请求。 +// 与 IsCodexCLIRequest 解耦,避免影响历史兼容逻辑。 +func IsCodexOfficialClientRequest(userAgent string) bool { + ua := normalizeCodexClientHeader(userAgent) + if ua == "" { + return false + } + return matchCodexClientHeaderPrefixes(ua, CodexOfficialClientUserAgentPrefixes) +} + +// IsCodexOfficialClientOriginator checks if originator indicates a Codex 官方客户端请求。 +func IsCodexOfficialClientOriginator(originator string) bool { + v := normalizeCodexClientHeader(originator) + if v == "" 
{ + return false + } + return matchCodexClientHeaderPrefixes(v, CodexOfficialClientOriginatorPrefixes) +} + +func normalizeCodexClientHeader(value string) string { + return strings.ToLower(strings.TrimSpace(value)) +} + +func matchCodexClientHeaderPrefixes(value string, prefixes []string) bool { + for _, prefix := range prefixes { + normalizedPrefix := normalizeCodexClientHeader(prefix) + if normalizedPrefix == "" { + continue + } + // 优先前缀匹配;若 UA/Originator 被网关拼接为复合字符串时,退化为包含匹配。 + if strings.HasPrefix(value, normalizedPrefix) || strings.Contains(value, normalizedPrefix) { return true } } diff --git a/backend/internal/pkg/openai/request_test.go b/backend/internal/pkg/openai/request_test.go new file mode 100644 index 00000000..508bf561 --- /dev/null +++ b/backend/internal/pkg/openai/request_test.go @@ -0,0 +1,87 @@ +package openai + +import "testing" + +func TestIsCodexCLIRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.1.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.2.3", want: true}, + {name: "大小写混合", ua: "Codex_CLI_Rs/0.1.0", want: true}, + {name: "复合 UA 包含 codex", ua: "Mozilla/5.0 codex_cli_rs/0.1.0", want: true}, + {name: "空白包裹", ua: " codex_vscode/1.2.3 ", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexCLIRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexCLIRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} + +func TestIsCodexOfficialClientRequest(t *testing.T) { + tests := []struct { + name string + ua string + want bool + }{ + {name: "codex_cli_rs 前缀", ua: "codex_cli_rs/0.98.0", want: true}, + {name: "codex_vscode 前缀", ua: "codex_vscode/1.0.0", want: true}, + {name: "codex_app 前缀", ua: "codex_app/0.1.0", want: true}, + {name: "codex_chatgpt_desktop 前缀", ua: "codex_chatgpt_desktop/1.0.0", 
want: true}, + {name: "codex_atlas 前缀", ua: "codex_atlas/1.0.0", want: true}, + {name: "codex_exec 前缀", ua: "codex_exec/0.1.0", want: true}, + {name: "codex_sdk_ts 前缀", ua: "codex_sdk_ts/0.1.0", want: true}, + {name: "Codex 桌面 UA", ua: "Codex Desktop/1.2.3", want: true}, + {name: "复合 UA 包含 codex_app", ua: "Mozilla/5.0 codex_app/0.1.0", want: true}, + {name: "大小写混合", ua: "Codex_VSCode/1.2.3", want: true}, + {name: "非 codex", ua: "curl/8.0.1", want: false}, + {name: "空字符串", ua: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientRequest(tt.ua) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientRequest(%q) = %v, want %v", tt.ua, got, tt.want) + } + }) + } +} + +func TestIsCodexOfficialClientOriginator(t *testing.T) { + tests := []struct { + name string + originator string + want bool + }{ + {name: "codex_cli_rs", originator: "codex_cli_rs", want: true}, + {name: "codex_vscode", originator: "codex_vscode", want: true}, + {name: "codex_app", originator: "codex_app", want: true}, + {name: "codex_chatgpt_desktop", originator: "codex_chatgpt_desktop", want: true}, + {name: "codex_atlas", originator: "codex_atlas", want: true}, + {name: "codex_exec", originator: "codex_exec", want: true}, + {name: "codex_sdk_ts", originator: "codex_sdk_ts", want: true}, + {name: "Codex 前缀", originator: "Codex Desktop", want: true}, + {name: "空白包裹", originator: " codex_vscode ", want: true}, + {name: "非 codex", originator: "my_client", want: false}, + {name: "空字符串", originator: "", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := IsCodexOfficialClientOriginator(tt.originator) + if got != tt.want { + t.Fatalf("IsCodexOfficialClientOriginator(%q) = %v, want %v", tt.originator, got, tt.want) + } + }) + } +} diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index c5b41d6e..0519c2cc 100644 --- 
a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -7,6 +7,7 @@ import ( "net/http" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/gin-gonic/gin" ) @@ -78,7 +79,7 @@ func ErrorFrom(c *gin.Context, err error) bool { // Log internal errors with full details for debugging if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, err.Error()) + log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) } ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index ef31ca3c..3c12f5f4 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -14,6 +14,44 @@ import ( "github.com/stretchr/testify/require" ) +// ---------- 辅助函数 ---------- + +// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体 +func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response { + t.Helper() + var got Response + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got)) + return got +} + +// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData) +func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) { + t.Helper() + // 先用 raw json 解析,因为 Data 是 any 类型 + var raw struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` + } + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) + + var pd PaginatedData + require.NoError(t, json.Unmarshal(raw.Data, &pd)) + + return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd +} + +// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 
ParsePagination +func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil) + return w, c +} + +// ---------- 现有测试 ---------- + func TestErrorWithDetails(t *testing.T) { gin.SetMode(gin.TestMode) @@ -169,3 +207,582 @@ func TestErrorFrom(t *testing.T) { }) } } + +// ---------- 新增测试 ---------- + +func TestSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + wantBody Response + }{ + { + name: "返回字符串数据", + data: "hello", + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success", Data: "hello"}, + }, + { + name: "返回nil数据", + data: nil, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + { + name: "返回map数据", + data: map[string]string{"key": "value"}, + wantCode: http.StatusOK, + wantBody: Response{Code: 0, Message: "success"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Success(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) + + // 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + + if tt.data == nil { + require.Nil(t, got.Data) + } else { + require.NotNil(t, got.Data) + } + }) + } +} + +func TestCreated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + data any + wantCode int + }{ + { + name: "创建成功_返回数据", + data: map[string]int{"id": 42}, + wantCode: http.StatusCreated, + }, + { + name: "创建成功_nil数据", + data: nil, + wantCode: http.StatusCreated, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Created(c, tt.data) + + require.Equal(t, tt.wantCode, w.Code) 
+ + got := parseResponseBody(t, w) + require.Equal(t, 0, got.Code) + require.Equal(t, "success", got.Message) + }) + } +} + +func TestError(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + statusCode int + message string + }{ + { + name: "400错误", + statusCode: http.StatusBadRequest, + message: "bad request", + }, + { + name: "500错误", + statusCode: http.StatusInternalServerError, + message: "internal error", + }, + { + name: "自定义状态码", + statusCode: 418, + message: "I'm a teapot", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Error(c, tt.statusCode, tt.message) + + require.Equal(t, tt.statusCode, w.Code) + + got := parseResponseBody(t, w) + require.Equal(t, tt.statusCode, got.Code) + require.Equal(t, tt.message, got.Message) + require.Empty(t, got.Reason) + require.Nil(t, got.Metadata) + require.Nil(t, got.Data) + }) + } +} + +func TestBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + BadRequest(c, "参数无效") + + require.Equal(t, http.StatusBadRequest, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusBadRequest, got.Code) + require.Equal(t, "参数无效", got.Message) +} + +func TestUnauthorized(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Unauthorized(c, "未登录") + + require.Equal(t, http.StatusUnauthorized, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusUnauthorized, got.Code) + require.Equal(t, "未登录", got.Message) +} + +func TestForbidden(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Forbidden(c, "无权限") + + require.Equal(t, http.StatusForbidden, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusForbidden, got.Code) + require.Equal(t, "无权限", got.Message) +} + +func 
TestNotFound(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + NotFound(c, "资源不存在") + + require.Equal(t, http.StatusNotFound, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusNotFound, got.Code) + require.Equal(t, "资源不存在", got.Message) +} + +func TestInternalError(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + InternalError(c, "服务器内部错误") + + require.Equal(t, http.StatusInternalServerError, w.Code) + got := parseResponseBody(t, w) + require.Equal(t, http.StatusInternalServerError, got.Code) + require.Equal(t, "服务器内部错误", got.Message) +} + +func TestPaginated(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + total int64 + page int + pageSize int + wantPages int + wantTotal int64 + wantPage int + wantPageSize int + }{ + { + name: "标准分页_多页", + items: []string{"a", "b"}, + total: 25, + page: 1, + pageSize: 10, + wantPages: 3, + wantTotal: 25, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "总数刚好整除", + items: []string{"a"}, + total: 20, + page: 2, + pageSize: 10, + wantPages: 2, + wantTotal: 20, + wantPage: 2, + wantPageSize: 10, + }, + { + name: "总数为0_pages至少为1", + items: []string{}, + total: 0, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 0, + wantPage: 1, + wantPageSize: 10, + }, + { + name: "单页数据", + items: []int{1, 2, 3}, + total: 3, + page: 1, + pageSize: 20, + wantPages: 1, + wantTotal: 3, + wantPage: 1, + wantPageSize: 20, + }, + { + name: "总数为1", + items: []string{"only"}, + total: 1, + page: 1, + pageSize: 10, + wantPages: 1, + wantTotal: 1, + wantPage: 1, + wantPageSize: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + Paginated(c, tt.items, tt.total, tt.page, tt.pageSize) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := 
parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestPaginatedWithResult(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + items any + pagination *PaginationResult + wantTotal int64 + wantPage int + wantPageSize int + wantPages int + }{ + { + name: "正常分页结果", + items: []string{"a", "b"}, + pagination: &PaginationResult{ + Total: 50, + Page: 3, + PageSize: 10, + Pages: 5, + }, + wantTotal: 50, + wantPage: 3, + wantPageSize: 10, + wantPages: 5, + }, + { + name: "pagination为nil_使用默认值", + items: []string{}, + pagination: nil, + wantTotal: 0, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + { + name: "单页结果", + items: []int{1}, + pagination: &PaginationResult{ + Total: 1, + Page: 1, + PageSize: 20, + Pages: 1, + }, + wantTotal: 1, + wantPage: 1, + wantPageSize: 20, + wantPages: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + PaginatedWithResult(c, tt.items, tt.pagination) + + require.Equal(t, http.StatusOK, w.Code) + + resp, pd := parsePaginatedBody(t, w) + require.Equal(t, 0, resp.Code) + require.Equal(t, "success", resp.Message) + require.Equal(t, tt.wantTotal, pd.Total) + require.Equal(t, tt.wantPage, pd.Page) + require.Equal(t, tt.wantPageSize, pd.PageSize) + require.Equal(t, tt.wantPages, pd.Pages) + }) + } +} + +func TestParsePagination(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + query string + wantPage int + wantPageSize int + }{ + { + name: "无参数_使用默认值", + query: "", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "仅指定page", + query: "page=3", + wantPage: 3, + wantPageSize: 20, + }, + { + name: "仅指定page_size", + query: "page_size=50", 
+ wantPage: 1, + wantPageSize: 50, + }, + { + name: "同时指定page和page_size", + query: "page=2&page_size=30", + wantPage: 2, + wantPageSize: 30, + }, + { + name: "使用limit代替page_size", + query: "limit=15", + wantPage: 1, + wantPageSize: 15, + }, + { + name: "page_size优先于limit", + query: "page_size=25&limit=50", + wantPage: 1, + wantPageSize: 25, + }, + { + name: "page为0_使用默认值", + query: "page=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size超过1000_使用默认值", + query: "page_size=1001", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size恰好1000_有效", + query: "page_size=1000", + wantPage: 1, + wantPageSize: 1000, + }, + { + name: "page为非数字_使用默认值", + query: "page=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为非数字_使用默认值", + query: "page_size=xyz", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为非数字_使用默认值", + query: "limit=abc", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "page_size为0_使用默认值", + query: "page_size=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit为0_使用默认值", + query: "limit=0", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "大页码", + query: "page=999&page_size=100", + wantPage: 999, + wantPageSize: 100, + }, + { + name: "page_size为1_最小有效值", + query: "page_size=1", + wantPage: 1, + wantPageSize: 1, + }, + { + name: "混合数字和字母的page", + query: "page=12a", + wantPage: 1, + wantPageSize: 20, + }, + { + name: "limit超过1000_使用默认值", + query: "limit=2000", + wantPage: 1, + wantPageSize: 20, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, c := newContextWithQuery(tt.query) + + page, pageSize := ParsePagination(c) + + require.Equal(t, tt.wantPage, page, "page 不符合预期") + require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期") + }) + } +} + +func Test_parseInt(t *testing.T) { + tests := []struct { + name string + input string + wantVal int + wantErr bool + }{ + { + name: "正常数字", + input: "123", + wantVal: 123, + wantErr: false, + }, + { + name: "零", + input: 
"0", + wantVal: 0, + wantErr: false, + }, + { + name: "单个数字", + input: "5", + wantVal: 5, + wantErr: false, + }, + { + name: "大数字", + input: "99999", + wantVal: 99999, + wantErr: false, + }, + { + name: "包含字母_返回0", + input: "abc", + wantVal: 0, + wantErr: false, + }, + { + name: "数字开头接字母_返回0", + input: "12a", + wantVal: 0, + wantErr: false, + }, + { + name: "包含负号_返回0", + input: "-1", + wantVal: 0, + wantErr: false, + }, + { + name: "包含小数点_返回0", + input: "1.5", + wantVal: 0, + wantErr: false, + }, + { + name: "包含空格_返回0", + input: "1 2", + wantVal: 0, + wantErr: false, + }, + { + name: "空字符串", + input: "", + wantVal: 0, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := parseInt(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + require.Equal(t, tt.wantVal, val) + }) + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 42510986..992f8b0a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -286,7 +286,7 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st return nil, fmt.Errorf("apply TLS preset: %w", err) } - if err := tlsConn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) _ = conn.Close() return nil, fmt.Errorf("TLS handshake failed: %w", err) diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go index dff7570f..6d3db174 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer_test.go +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -1,3 +1,5 @@ +//go:build unit + // Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. // // Unit tests for TLS fingerprint dialer. 
@@ -9,26 +11,161 @@ package tlsfingerprint import ( + "context" + "encoding/json" + "io" + "net/http" "net/url" + "os" + "strings" "testing" + "time" ) -// FingerprintResponse represents the response from tls.peet.ws/api/all. -type FingerprintResponse struct { - IP string `json:"ip"` - TLS TLSInfo `json:"tls"` - HTTP2 any `json:"http2"` +// TestDialerBasicConnection tests that the dialer can establish TLS connections. +func TestDialerBasicConnection(t *testing.T) { + skipNetworkTest(t) + + // Create a dialer with default profile + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + // Create HTTP client with custom TLS dialer + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Make a request to a known HTTPS endpoint + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } } -// TLSInfo contains TLS fingerprint details. -type TLSInfo struct { - JA3 string `json:"ja3"` - JA3Hash string `json:"ja3_hash"` - JA4 string `json:"ja4"` - PeetPrint string `json:"peetprint"` - PeetPrintHash string `json:"peetprint_hash"` - ClientRandom string `json:"client_random"` - SessionID string `json:"session_id"` +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. +// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... 
(i=IP) +func TestJA3Fingerprint(t *testing.T) { + skipNetworkTest(t) + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5910h1 (d=domain) or t13i5910h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := 
"_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +func skipNetworkTest(t *testing.T) { + if testing.Short() { + t.Skip("跳过网络测试(short 模式)") + } + if os.Getenv("TLSFINGERPRINT_NETWORK_TESTS") != "1" { + t.Skip("跳过网络测试(需要设置 TLSFINGERPRINT_NETWORK_TESTS=1)") + } } // TestDialerWithProfile tests that different profiles produce different fingerprints. 
@@ -158,3 +295,137 @@ func mustParseURL(rawURL string) *url.URL { } return u } + +// TestProfileExpectation defines expected fingerprint values for a profile. +type TestProfileExpectation struct { + Profile *Profile + ExpectedJA3 string // Expected JA3 hash (empty = don't check) + ExpectedJA4 string // Expected full JA4 (empty = don't check) + JA4CipherHash string // Expected JA4 cipher hash - the stable middle part (empty = don't check) +} + +// TestAllProfiles tests multiple TLS fingerprint profiles against tls.peet.ws. +// Run with: go test -v -run TestAllProfiles ./internal/pkg/tlsfingerprint/... +func TestAllProfiles(t *testing.T) { + skipNetworkTest(t) + + // Define all profiles to test with their expected fingerprints + // These profiles are from config.yaml gateway.tls_fingerprint.profiles + profiles := []TestProfileExpectation{ + { + // Linux x64 Node.js v22.17.1 + // Expected JA3 Hash: 1a28e69016765d92e3b381168d68922c + // Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 + Profile: &Profile{ + Name: "linux_x64_node_v22171", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part + }, + { + // MacOS arm64 Node.js v22.18.0 + // Expected JA3 Hash: 70cb5ca646080902703ffda87036a5ea + // Expected JA4: t13d5912h1_a33745022dd6_dbd39dd1d406 + Profile: &Profile{ + Name: "macos_arm64_node_v22180", + EnableGREASE: false, + CipherSuites: []uint16{4866, 4867, 4865, 49199, 49195, 49200, 49196, 158, 49191, 103, 49192, 107, 163, 159, 52393, 52392, 52394, 49327, 49325, 
49315, 49311, 49245, 49249, 49239, 49235, 162, 49326, 49324, 49314, 49310, 49244, 49248, 49238, 49234, 49188, 106, 49187, 64, 49162, 49172, 57, 56, 49161, 49171, 51, 50, 157, 49313, 49309, 49233, 156, 49312, 49308, 49232, 61, 60, 53, 47, 255}, + Curves: []uint16{29, 23, 30, 25, 24, 256, 257, 258, 259, 260}, + PointFormats: []uint8{0, 1, 2}, + }, + JA4CipherHash: "a33745022dd6", // stable part (same cipher suites) + }, + } + + for _, tc := range profiles { + tc := tc // capture range variable + t.Run(tc.Profile.Name, func(t *testing.T) { + fp := fetchFingerprint(t, tc.Profile) + if fp == nil { + return // fetchFingerprint already called t.Fatal + } + + t.Logf("Profile: %s", tc.Profile.Name) + t.Logf(" JA3: %s", fp.JA3) + t.Logf(" JA3 Hash: %s", fp.JA3Hash) + t.Logf(" JA4: %s", fp.JA4) + t.Logf(" PeetPrint: %s", fp.PeetPrint) + t.Logf(" PeetPrintHash: %s", fp.PeetPrintHash) + + // Verify expectations + if tc.ExpectedJA3 != "" { + if fp.JA3Hash == tc.ExpectedJA3 { + t.Logf(" ✓ JA3 hash matches: %s", tc.ExpectedJA3) + } else { + t.Errorf(" ✗ JA3 hash mismatch: got %s, expected %s", fp.JA3Hash, tc.ExpectedJA3) + } + } + + if tc.ExpectedJA4 != "" { + if fp.JA4 == tc.ExpectedJA4 { + t.Logf(" ✓ JA4 matches: %s", tc.ExpectedJA4) + } else { + t.Errorf(" ✗ JA4 mismatch: got %s, expected %s", fp.JA4, tc.ExpectedJA4) + } + } + + // Check JA4 cipher hash (stable middle part) + // JA4 format: prefix_cipherHash_extHash + if tc.JA4CipherHash != "" { + if strings.Contains(fp.JA4, "_"+tc.JA4CipherHash+"_") { + t.Logf(" ✓ JA4 cipher hash matches: %s", tc.JA4CipherHash) + } else { + t.Errorf(" ✗ JA4 cipher hash mismatch: got %s, expected cipher hash %s", fp.JA4, tc.JA4CipherHash) + } + } + }) + } +} + +// fetchFingerprint makes a request to tls.peet.ws and returns the TLS fingerprint info. 
+func fetchFingerprint(t *testing.T, profile *Profile) *TLSInfo { + t.Helper() + + dialer := NewDialer(profile, nil) + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + return nil + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + return nil + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + return nil + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + return nil + } + + return &fpResp.TLS +} diff --git a/backend/internal/pkg/tlsfingerprint/test_types_test.go b/backend/internal/pkg/tlsfingerprint/test_types_test.go new file mode 100644 index 00000000..2bbf2d22 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/test_types_test.go @@ -0,0 +1,20 @@ +package tlsfingerprint + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +// 共享测试类型,供 unit 和 integration 测试文件使用。 +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. 
+type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index e3e70213..3f77a57e 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -15,7 +15,6 @@ import ( "database/sql" "encoding/json" "errors" - "log" "strconv" "time" @@ -25,6 +24,7 @@ import ( dbgroup "github.com/Wei-Shaw/sub2api/ent/group" dbpredicate "github.com/Wei-Shaw/sub2api/ent/predicate" dbproxy "github.com/Wei-Shaw/sub2api/ent/proxy" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -127,7 +127,7 @@ func (r *accountRepository) Create(ctx context.Context, account *service.Account account.CreatedAt = created.CreatedAt account.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account create failed: account=%d err=%v", account.ID, err) } return nil } @@ -388,7 +388,7 @@ func (r *accountRepository) Update(ctx context.Context, account *service.Account } account.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &account.ID, nil, buildSchedulerGroupPayload(account.GroupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) + 
logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account update failed: account=%d err=%v", account.ID, err) } if account.Status == service.StatusError || account.Status == service.StatusDisabled || !account.Schedulable { r.syncSchedulerAccountSnapshot(ctx, account.ID) @@ -429,7 +429,7 @@ func (r *accountRepository) Delete(ctx context.Context, id int64) error { } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, buildSchedulerGroupPayload(groupIDs)); err != nil { - log.Printf("[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue account delete failed: account=%d err=%v", id, err) } return nil } @@ -533,7 +533,7 @@ func (r *accountRepository) UpdateLastUsed(ctx context.Context, id int64) error }, } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, &id, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue last used failed: account=%d err=%v", id, err) } return nil } @@ -568,7 +568,7 @@ func (r *accountRepository) BatchUpdateLastUsed(ctx context.Context, updates map } payload := map[string]any{"last_used": lastUsedPayload} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountLastUsed, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue batch last used failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue batch last used failed: err=%v", err) } return nil } @@ -583,7 +583,7 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue set error failed: 
account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue set error failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -603,11 +603,11 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } account, err := r.GetByID(ctx, accountID) if err != nil { - log.Printf("[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot read failed: id=%d err=%v", accountID, err) return } if err := r.schedulerCache.SetAccount(ctx, account); err != nil { - log.Printf("[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[Scheduler] sync account snapshot write failed: id=%d err=%v", accountID, err) } } @@ -631,7 +631,7 @@ func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID i } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue add to group failed: account=%d group=%d err=%v", accountID, groupID, err) } return nil } @@ -648,7 +648,7 @@ func (r *accountRepository) RemoveFromGroup(ctx context.Context, accountID, grou } payload := buildSchedulerGroupPayload([]int64{groupID}) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, groupID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue remove from group failed: account=%d group=%d err=%v", accountID, 
groupID, err) } return nil } @@ -721,7 +721,7 @@ func (r *accountRepository) BindGroups(ctx context.Context, accountID int64, gro } payload := buildSchedulerGroupPayload(mergeGroupIDs(existingGroupIDs, groupIDs)) if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountGroupsChanged, &accountID, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bind groups failed: account=%d err=%v", accountID, err) } return nil } @@ -829,7 +829,7 @@ func (r *accountRepository) SetRateLimited(ctx context.Context, id int64, resetA return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue rate limit failed: account=%d err=%v", id, err) } return nil } @@ -876,7 +876,7 @@ func (r *accountRepository) SetModelRateLimit(ctx context.Context, id int64, sco return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -890,7 +890,7 @@ func (r *accountRepository) SetOverloaded(ctx context.Context, id int64, until t return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue overload failed: account=%d err=%v", id, err) } return nil } @@ -909,7 +909,7 @@ 
func (r *accountRepository) SetTempUnschedulable(ctx context.Context, id int64, return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue temp unschedulable failed: account=%d err=%v", id, err) } r.syncSchedulerAccountSnapshot(ctx, id) return nil @@ -928,7 +928,7 @@ func (r *accountRepository) ClearTempUnschedulable(ctx context.Context, id int64 return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear temp unschedulable failed: account=%d err=%v", id, err) } return nil } @@ -944,7 +944,7 @@ func (r *accountRepository) ClearRateLimit(ctx context.Context, id int64) error return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear rate limit failed: account=%d err=%v", id, err) } return nil } @@ -968,7 +968,7 @@ func (r *accountRepository) ClearAntigravityQuotaScopes(ctx context.Context, id return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear quota scopes failed: account=%d err=%v", id, err) } return nil } @@ -992,7 +992,7 @@ func (r *accountRepository) ClearModelRateLimits(ctx 
context.Context, id int64) return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue clear model rate limit failed: account=%d err=%v", id, err) } return nil } @@ -1014,7 +1014,7 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s // 触发调度器缓存更新(仅当窗口时间有变化时) if start != nil || end != nil { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) } } return nil @@ -1029,7 +1029,7 @@ func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedu return err } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue schedulable change failed: account=%d err=%v", id, err) } if !schedulable { r.syncSchedulerAccountSnapshot(ctx, id) @@ -1057,7 +1057,7 @@ func (r *accountRepository) AutoPauseExpiredAccounts(ctx context.Context, now ti } if rows > 0 { if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventFullRebuild, nil, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue auto pause rebuild failed: err=%v", err) } } return rows, nil @@ -1093,7 +1093,7 @@ func (r *accountRepository) UpdateExtra(ctx context.Context, id 
int64, updates m return service.ErrAccountNotFound } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue extra update failed: account=%d err=%v", id, err) } return nil } @@ -1187,7 +1187,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates if rows > 0 { payload := map[string]any{"account_ids": ids} if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountBulkChanged, nil, nil, payload); err != nil { - log.Printf("[SchedulerOutbox] enqueue bulk update failed: err=%v", err) + logger.LegacyPrintf("repository.account", "[SchedulerOutbox] enqueue bulk update failed: err=%v", err) } shouldSync := false if updates.Status != nil && (*updates.Status == service.StatusError || *updates.Status == service.StatusDisabled) { @@ -1560,3 +1560,64 @@ func joinClauses(clauses []string, sep string) string { func itoa(v int) string { return strconv.Itoa(v) } + +// FindByExtraField 根据 extra 字段中的键值对查找账号。 +// 该方法限定 platform='sora',避免误查询其他平台的账号。 +// 使用 PostgreSQL JSONB @> 操作符进行高效查询(需要 GIN 索引支持)。 +// +// 应用场景:查找通过 linked_openai_account_id 关联的 Sora 账号。 +// +// FindByExtraField finds accounts by key-value pairs in the extra field. +// Limited to platform='sora' to avoid querying accounts from other platforms. +// Uses PostgreSQL JSONB @> operator for efficient queries (requires GIN index). +// +// Use case: Finding Sora accounts linked via linked_openai_account_id. +func (r *accountRepository) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + accounts, err := r.client.Account.Query(). 
+ Where( + dbaccount.PlatformEQ("sora"), // 限定平台为 sora + dbaccount.DeletedAtIsNil(), + func(s *entsql.Selector) { + path := sqljson.Path(key) + switch v := value.(type) { + case string: + preds := []*entsql.Predicate{sqljson.ValueEQ(dbaccount.FieldExtra, v, path)} + if parsed, err := strconv.ParseInt(v, 10, 64); err == nil { + preds = append(preds, sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path)) + } + if len(preds) == 1 { + s.Where(preds[0]) + } else { + s.Where(entsql.Or(preds...)) + } + case int: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.Itoa(v), path), + )) + case int64: + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, v, path), + sqljson.ValueEQ(dbaccount.FieldExtra, strconv.FormatInt(v, 10), path), + )) + case json.Number: + if parsed, err := v.Int64(); err == nil { + s.Where(entsql.Or( + sqljson.ValueEQ(dbaccount.FieldExtra, parsed, path), + sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path), + )) + } else { + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, v.String(), path)) + } + default: + s.Where(sqljson.ValueEQ(dbaccount.FieldExtra, value, path)) + } + }, + ). + All(ctx) + if err != nil { + return nil, translatePersistenceError(err, service.ErrAccountNotFound, nil) + } + + return r.accountsToService(ctx, accounts) +} diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 22dfa700..cdccd4fc 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -34,6 +34,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro SetName(key.Name). SetStatus(key.Status). SetNillableGroupID(key.GroupID). + SetNillableLastUsedAt(key.LastUsedAt). SetQuota(key.Quota). SetQuotaUsed(key.QuotaUsed). 
SetNillableExpiresAt(key.ExpiresAt) @@ -48,6 +49,7 @@ func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) erro created, err := builder.Save(ctx) if err == nil { key.ID = created.ID + key.LastUsedAt = created.LastUsedAt key.CreatedAt = created.CreatedAt key.UpdatedAt = created.UpdatedAt } @@ -140,6 +142,10 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, + group.FieldSoraImagePrice360, + group.FieldSoraImagePrice540, + group.FieldSoraVideoPricePerRequest, + group.FieldSoraVideoPricePerRequestHd, group.FieldClaudeCodeOnly, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, @@ -375,36 +381,34 @@ func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) return keys, nil } -// IncrementQuotaUsed atomically increments the quota_used field and returns the new value +// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值 func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { - // Use raw SQL for atomic increment to avoid race conditions - // First get current value - m, err := r.activeQuery(). - Where(apikey.IDEQ(id)). - Select(apikey.FieldQuotaUsed). - Only(ctx) + updated, err := r.client.APIKey.UpdateOneID(id). + Where(apikey.DeletedAtIsNil()). + AddQuotaUsed(amount). + Save(ctx) if err != nil { if dbent.IsNotFound(err) { return 0, service.ErrAPIKeyNotFound } return 0, err } + return updated.QuotaUsed, nil +} - newValue := m.QuotaUsed + amount - - // Update with new value +func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { affected, err := r.client.APIKey.Update(). Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()). - SetQuotaUsed(newValue). + SetLastUsedAt(usedAt). + SetUpdatedAt(usedAt). 
Save(ctx) if err != nil { - return 0, err + return err } if affected == 0 { - return 0, service.ErrAPIKeyNotFound + return service.ErrAPIKeyNotFound } - - return newValue, nil + return nil } func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { @@ -419,6 +423,7 @@ func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey { Status: m.Status, IPWhitelist: m.IPWhitelist, IPBlacklist: m.IPBlacklist, + LastUsedAt: m.LastUsedAt, CreatedAt: m.CreatedAt, UpdatedAt: m.UpdatedAt, GroupID: m.GroupID, @@ -477,6 +482,10 @@ func groupEntityToService(g *dbent.Group) *service.Group { ImagePrice1K: g.ImagePrice1k, ImagePrice2K: g.ImagePrice2k, ImagePrice4K: g.ImagePrice4k, + SoraImagePrice360: g.SoraImagePrice360, + SoraImagePrice540: g.SoraImagePrice540, + SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/api_key_repo_integration_test.go b/backend/internal/repository/api_key_repo_integration_test.go index 879a0576..303d7126 100644 --- a/backend/internal/repository/api_key_repo_integration_test.go +++ b/backend/internal/repository/api_key_repo_integration_test.go @@ -4,11 +4,14 @@ package repository import ( "context" + "sync" "testing" + "time" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -383,3 +386,87 @@ func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, group s.Require().NoError(s.repo.Create(s.ctx, k), "create api key") return k } + +// --- IncrementQuotaUsed --- + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() { + user := s.mustCreateUser("incr-basic@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil) + + newQuota, 
err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5) + s.Require().NoError(err, "IncrementQuotaUsed") + s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5") + + newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5) + s.Require().NoError(err, "IncrementQuotaUsed second") + s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() { + _, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound") +} + +func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() { + user := s.mustCreateUser("incr-deleted@test.com") + key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil) + + s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete") + + _, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0) + s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound") +} + +// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。 +// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。 +func TestIncrementQuotaUsed_Concurrent(t *testing.T) { + client := testEntClient(t) + repo := NewAPIKeyRepository(client).(*apiKeyRepository) + ctx := context.Background() + + // 创建测试用户和 API Key + u, err := client.User.Create(). + SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com"). + SetPasswordHash("hash"). + SetStatus(service.StatusActive). + SetRole(service.RoleUser). 
+ Save(ctx) + require.NoError(t, err, "create user") + + k := &service.APIKey{ + UserID: u.ID, + Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano), + Name: "Concurrent", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, k), "create api key") + t.Cleanup(func() { + _ = client.APIKey.DeleteOneID(k.ID).Exec(ctx) + _ = client.User.DeleteOneID(u.ID).Exec(ctx) + }) + + // 10 个 goroutine 各递增 1.0,总计应为 10.0 + const goroutines = 10 + const increment = 1.0 + var wg sync.WaitGroup + errs := make([]error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + _, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment) + }(i) + } + wg.Wait() + + for i, e := range errs { + require.NoError(t, e, "goroutine %d failed", i) + } + + // 验证最终结果 + got, err := repo.GetByID(ctx, k.ID) + require.NoError(t, err, "GetByID") + require.Equal(t, float64(goroutines)*increment, got.QuotaUsed, + "并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed) +} diff --git a/backend/internal/repository/api_key_repo_last_used_unit_test.go b/backend/internal/repository/api_key_repo_last_used_unit_test.go new file mode 100644 index 00000000..7c6e2850 --- /dev/null +++ b/backend/internal/repository/api_key_repo_last_used_unit_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) { + t.Helper() + + db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + 
drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + return &apiKeyRepository{client: client}, client +} + +func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User { + t.Helper() + u, err := client.User.Create(). + SetEmail(email). + SetPasswordHash("test-password-hash"). + SetRole(service.RoleUser). + SetStatus(service.StatusActive). + Save(ctx) + require.NoError(t, err) + return userEntityToService(u) +} + +func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com") + + lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second) + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-create-last-used", + Name: "CreateWithLastUsed", + Status: service.StatusActive, + LastUsedAt: &lastUsed, + } + + require.NoError(t, repo.Create(ctx, key)) + require.NotNil(t, key.LastUsedAt) + require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second) + + got, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, got.LastUsedAt) + require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used", + Name: "UpdateLastUsed", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + before, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.Nil(t, before.LastUsedAt) + + target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second) + require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target)) + + 
after, err := repo.GetByID(ctx, key.ID) + require.NoError(t, err) + require.NotNil(t, after.LastUsedAt) + require.WithinDuration(t, target, *after.LastUsedAt, time.Second) + require.WithinDuration(t, target, after.UpdatedAt, time.Second) +} + +func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-deleted", + Name: "UpdateLastUsedDeleted", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + require.NoError(t, repo.Delete(ctx, key.ID)) + + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.ErrorIs(t, err, service.ErrAPIKeyNotFound) +} + +func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com") + + key := &service.APIKey{ + UserID: user.ID, + Key: "sk-update-last-used-db-error", + Name: "UpdateLastUsedDBError", + Status: service.StatusActive, + } + require.NoError(t, repo.Create(ctx, key)) + + require.NoError(t, client.Close()) + err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC()) + require.Error(t, err) +} + +func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) { + repo, client := newAPIKeyRepoSQLite(t) + ctx := context.Background() + user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com") + + first := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "first", + Status: service.StatusActive, + } + second := &service.APIKey{ + UserID: user.ID, + Key: "sk-duplicate", + Name: "second", + Status: service.StatusActive, + } + + require.NoError(t, repo.Create(ctx, first)) + err := repo.Create(ctx, second) + require.ErrorIs(t, err, service.ErrAPIKeyExists) +} diff --git 
a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index ac5803a1..e753e1b8 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "math/rand/v2" "strconv" "time" @@ -16,8 +17,19 @@ const ( billingBalanceKeyPrefix = "billing:balance:" billingSubKeyPrefix = "billing:sub:" billingCacheTTL = 5 * time.Minute + billingCacheJitter = 30 * time.Second ) +// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩 +func jitteredTTL() time.Duration { + // 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。 + if billingCacheJitter <= 0 { + return billingCacheTTL + } + jitter := time.Duration(rand.IntN(int(billingCacheJitter))) + return billingCacheTTL - jitter +} + // billingBalanceKey generates the Redis key for user balance cache. func billingBalanceKey(userID int64) string { return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) @@ -82,14 +94,15 @@ func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float6 func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error { key := billingBalanceKey(userID) - return c.rdb.Set(ctx, key, balance, billingCacheTTL).Err() + return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err() } func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(billingCacheTTL.Seconds())).Result() + _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) + return err } return nil } @@ -163,16 +176,17 @@ func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID pipe := c.rdb.Pipeline() pipe.HSet(ctx, key, fields) - pipe.Expire(ctx, key, 
billingCacheTTL) + pipe.Expire(ctx, key, jitteredTTL()) _, err := pipe.Exec(ctx) return err } func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { key := billingSubKey(userID, groupID) - _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(billingCacheTTL.Seconds())).Result() + _, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err) + return err } return nil } diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 2f7c69a7..4b7377b1 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -278,6 +278,90 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } +// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { + tests := []struct { + name string + fn func(ctx context.Context, cache service.BillingCache) + expectErr bool + }{ + { + name: "key_not_exists_returns_nil", + fn: func(ctx context.Context, cache service.BillingCache) { + // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + err := cache.DeductUserBalance(ctx, 99999, 1.0) + require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + }, + }, + { + name: "existing_key_deducts_successfully", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0)) + err := cache.DeductUserBalance(ctx, 200, 10.0) + require.NoError(s.T(), err, "DeductUserBalance should succeed") + + bal, err := cache.GetUserBalance(ctx, 200) + require.NoError(s.T(), err) + 
require.Equal(s.T(), 40.0, bal, "余额应为 40.0") + }, + }, + { + name: "cancelled_context_propagates_error", + fn: func(ctx context.Context, cache service.BillingCache) { + require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() // 立即取消 + + err := cache.DeductUserBalance(cancelCtx, 201, 10.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }, + }, + } + + for _, tt := range tests { + s.Run(tt.name, func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + tt.fn(ctx, cache) + }) + } +} + +// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复: +// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() { + s.Run("key_not_exists_returns_nil", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0) + require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil") + }) + + s.Run("cancelled_context_propagates_error", func() { + rdb := testRedis(s.T()) + cache := NewBillingCache(rdb) + ctx := context.Background() + + data := &service.SubscriptionCacheData{ + Status: "active", + ExpiresAt: time.Now().Add(1 * time.Hour), + Version: 1, + } + require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data)) + + cancelCtx, cancel := context.WithCancel(ctx) + cancel() + + err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0) + require.Error(s.T(), err, "cancelled context should propagate error") + }) +} + func TestBillingCacheSuite(t *testing.T) { suite.Run(t, new(BillingCacheSuite)) } diff --git a/backend/internal/repository/billing_cache_jitter_test.go b/backend/internal/repository/billing_cache_jitter_test.go new file mode 100644 index 00000000..ba4f2873 --- /dev/null +++ b/backend/internal/repository/billing_cache_jitter_test.go @@ 
-0,0 +1,82 @@ +package repository + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 --- + +func TestJitteredTTL_WithinExpectedRange(t *testing.T) { + // jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter) + // 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内 + lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s + upperBound := billingCacheTTL // 5min + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound), + "TTL 不应低于 %v,实际得到 %v", lowerBound, ttl) + assert.LessOrEqual(t, int64(ttl), int64(upperBound), + "TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl) + } +} + +func TestJitteredTTL_NeverExceedsBase(t *testing.T) { + // 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL + for i := 0; i < 500; i++ { + ttl := jitteredTTL() + assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL), + "jitteredTTL 不应超过基础 TTL(上界预期不被打破)") + } +} + +func TestJitteredTTL_HasVariance(t *testing.T) { + // 验证抖动确实产生了不同的值 + results := make(map[time.Duration]bool) + for i := 0; i < 100; i++ { + ttl := jitteredTTL() + results[ttl] = true + } + + require.Greater(t, len(results), 1, + "jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同") +} + +func TestJitteredTTL_AverageNearCenter(t *testing.T) { + // 验证平均值大约在抖动范围中间 + var sum time.Duration + runs := 1000 + for i := 0; i < runs; i++ { + sum += jitteredTTL() + } + + avg := sum / time.Duration(runs) + expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s + + // 允许 ±5s 的误差 + tolerance := 5 * time.Second + assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance), + "平均 TTL 应接近抖动范围中心 %v", expectedCenter) +} + +func TestBillingKeyGeneration(t *testing.T) { + t.Run("balance_key", func(t *testing.T) { + key := billingBalanceKey(12345) + assert.Equal(t, "billing:balance:12345", key) + }) + + t.Run("sub_key", 
func(t *testing.T) { + key := billingSubKey(100, 200) + assert.Equal(t, "billing:sub:100:200", key) + }) +} + +func BenchmarkJitteredTTL(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = jitteredTTL() + } +} diff --git a/backend/internal/repository/billing_cache_test.go b/backend/internal/repository/billing_cache_test.go index 7d3fd19d..2de1da87 100644 --- a/backend/internal/repository/billing_cache_test.go +++ b/backend/internal/repository/billing_cache_test.go @@ -5,6 +5,7 @@ package repository import ( "math" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -85,3 +86,26 @@ func TestBillingSubKey(t *testing.T) { }) } } + +func TestJitteredTTL(t *testing.T) { + const ( + minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s + maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s + ) + + for i := 0; i < 200; i++ { + ttl := jitteredTTL() + require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl) + require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl) + } +} + +func TestJitteredTTL_HasVariation(t *testing.T) { + // 多次调用应该产生不同的值(验证抖动存在) + seen := make(map[time.Duration]struct{}, 50) + for i := 0; i < 50; i++ { + seen[jitteredTTL()] = struct{}{} + } + // 50 次调用中应该至少有 2 个不同的值 + require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值") +} diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index fc0d2918..77764881 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -4,12 +4,12 @@ import ( "context" "encoding/json" "fmt" - "log" "net/http" "net/url" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -41,7 +41,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } 
targetURL := s.baseURL + "/api/organizations" - log.Printf("[OAuth] Step 1: Getting organization UUID from %s", targetURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL) resp, err := client.R(). SetContext(ctx). @@ -53,11 +53,11 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey Get(targetURL) if err != nil { - log.Printf("[OAuth] Step 1 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 1 Response - Status: %d", resp.StatusCode) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String()) @@ -69,21 +69,21 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey // 如果只有一个组织,直接使用 if len(orgs) == 1 { - log.Printf("[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } // 如果有多个组织,优先选择 raven_type 为 "team" 的组织 for _, org := range orgs { if org.RavenType != nil && *org.RavenType == "team" { - log.Printf("[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s", org.UUID, org.Name, *org.RavenType) return org.UUID, nil } } // 如果没有 team 类型的组织,使用第一个 - log.Printf("[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, 
using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name) return orgs[0].UUID, nil } @@ -103,9 +103,9 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe "code_challenge_method": "S256", } - log.Printf("[OAuth] Step 2: Getting authorization code from %s", authURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON)) var result struct { RedirectURI string `json:"redirect_uri"` @@ -128,11 +128,11 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe Post(authURL) if err != nil { - log.Printf("[OAuth] Step 2 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err) return "", fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String()) @@ -160,7 +160,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe fullCode = authCode + "#" + responseState } - log.Printf("[OAuth] Step 2 SUCCESS - Got authorization code") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code") return fullCode, nil } @@ -192,9 +192,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds } - log.Printf("[OAuth] Step 3: 
Exchanging code for token at %s", s.tokenURL) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL) reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody)) - log.Printf("[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON)) var tokenResp oauth.TokenResponse @@ -208,17 +208,17 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod Post(s.tokenURL) if err != nil { - log.Printf("[OAuth] Step 3 FAILED - Request error: %v", err) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err) return nil, fmt.Errorf("request failed: %w", err) } - log.Printf("[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes())) if !resp.IsSuccessState() { return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String()) } - log.Printf("[OAuth] Step 3 SUCCESS - Got access token") + logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token") return &tokenResp, nil } diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index cc0c6db5..e047bff0 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -147,100 +147,6 @@ var ( return 1 `) - // getAccountsLoadBatchScript - batch load query with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... 
- getAccountsLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local accountID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:account:' .. accountID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'wait:account:' .. accountID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, accountID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - - // getUsersLoadBatchScript - batch load query for users with expired slot cleanup - // ARGV[1] = slot TTL (seconds) - // ARGV[2..n] = userID1, maxConcurrency1, userID2, maxConcurrency2, ... - getUsersLoadBatchScript = redis.NewScript(` - local result = {} - local slotTTL = tonumber(ARGV[1]) - - -- Get current server time - local timeResult = redis.call('TIME') - local nowSeconds = tonumber(timeResult[1]) - local cutoffTime = nowSeconds - slotTTL - - local i = 2 - while i <= #ARGV do - local userID = ARGV[i] - local maxConcurrency = tonumber(ARGV[i + 1]) - - local slotKey = 'concurrency:user:' .. userID - - -- Clean up expired slots before counting - redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) - local currentConcurrency = redis.call('ZCARD', slotKey) - - local waitKey = 'concurrency:wait:' .. 
userID - local waitingCount = redis.call('GET', waitKey) - if waitingCount == false then - waitingCount = 0 - else - waitingCount = tonumber(waitingCount) - end - - local loadRate = 0 - if maxConcurrency > 0 then - loadRate = math.floor((currentConcurrency + waitingCount) * 100 / maxConcurrency) - end - - table.insert(result, userID) - table.insert(result, currentConcurrency) - table.insert(result, waitingCount) - table.insert(result, loadRate) - - i = i + 2 - end - - return result - `) - // cleanupExpiredSlotsScript - remove expired slots // KEYS[1] = concurrency:account:{accountID} // ARGV[1] = TTL (seconds) @@ -399,29 +305,53 @@ func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts [] return map[int64]*service.AccountLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, acc := range accounts { - args = append(args, acc.ID, acc.MaxConcurrency) - } - - result, err := getAccountsLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。 + // 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type accountCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]accountCmds, 0, len(accounts)) + for _, acc := range accounts { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10) + waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + ac := accountCmds{ + id: acc.ID, + maxConcurrency: acc.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, ac) } - loadMap := make(map[int64]*service.AccountLoadInfo) - for i := 0; i < len(result); i 
+= 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts)) + for _, ac := range cmds { + currentConcurrency := int(ac.zcardCmd.Val()) + waitingCount := 0 + if v, err := ac.getCmd.Int(); err == nil { + waitingCount = v } - - accountID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[accountID] = &service.AccountLoadInfo{ - AccountID: accountID, + loadRate := 0 + if ac.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency + } + loadMap[ac.id] = &service.AccountLoadInfo{ + AccountID: ac.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, @@ -436,29 +366,52 @@ func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []servic return map[int64]*service.UserLoadInfo{}, nil } - args := []any{c.slotTTLSeconds} - for _, u := range users { - args = append(args, u.ID, u.MaxConcurrency) - } - - result, err := getUsersLoadBatchScript.Run(ctx, c.rdb, []string{}, args...).Slice() + // 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。 + now, err := c.rdb.Time(ctx).Result() if err != nil { - return nil, err + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + + type userCmds struct { + id int64 + maxConcurrency int + zcardCmd *redis.IntCmd + getCmd *redis.StringCmd + } + cmds := make([]userCmds, 0, len(users)) + for _, u := range users { + slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10) + waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", 
strconv.FormatInt(cutoffTime, 10)) + uc := userCmds{ + id: u.ID, + maxConcurrency: u.MaxConcurrency, + zcardCmd: pipe.ZCard(ctx, slotKey), + getCmd: pipe.Get(ctx, waitKey), + } + cmds = append(cmds, uc) } - loadMap := make(map[int64]*service.UserLoadInfo) - for i := 0; i < len(result); i += 4 { - if i+3 >= len(result) { - break + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + loadMap := make(map[int64]*service.UserLoadInfo, len(users)) + for _, uc := range cmds { + currentConcurrency := int(uc.zcardCmd.Val()) + waitingCount := 0 + if v, err := uc.getCmd.Int(); err == nil { + waitingCount = v } - - userID, _ := strconv.ParseInt(fmt.Sprintf("%v", result[i]), 10, 64) - currentConcurrency, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+1])) - waitingCount, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+2])) - loadRate, _ := strconv.Atoi(fmt.Sprintf("%v", result[i+3])) - - loadMap[userID] = &service.UserLoadInfo{ - UserID: userID, + loadRate := 0 + if uc.maxConcurrency > 0 { + loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency + } + loadMap[uc.id] = &service.UserLoadInfo{ + UserID: uc.id, CurrentConcurrency: currentConcurrency, WaitingCount: waitingCount, LoadRate: loadRate, diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index d7d574e8..5f3f5a84 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -5,6 +5,7 @@ package repository import ( "context" "database/sql" + "fmt" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -66,6 +67,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) + // 启动阶段:从配置或数据库中确保系统密钥可用。 + if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil { + _ = client.Close() + return nil, nil, err + } + + // 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。 + if err := cfg.Validate(); err != nil 
{ + _ = client.Close() + return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err) + } + // SIMPLE 模式:启动时补齐各平台默认分组。 // - anthropic/openai/gemini: 确保存在 -default // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景) diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index 03f8cc66..28efe914 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -18,14 +18,21 @@ type githubReleaseClient struct { downloadHTTPClient *http.Client } +type githubReleaseClientError struct { + err error +} + // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 -func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { +func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { + if proxyURL != "" && !allowDirectOnProxyError { + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } sharedClient = &http.Client{Timeout: 30 * time.Second} } @@ -35,6 +42,9 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { ProxyURL: proxyURL, }) if err != nil { + if proxyURL != "" && !allowDirectOnProxyError { + return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} + } downloadClient = &http.Client{Timeout: 10 * time.Minute} } @@ -44,6 +54,18 @@ func NewGitHubReleaseClient(proxyURL string) service.GitHubReleaseClient { } } +func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, 
repo string) (*service.GitHubRelease, error) { + return nil, c.err +} + +func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error { + return c.err +} + +func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) { + return nil, c.err +} + func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo) diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4e7a836f..fd239996 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -4,11 +4,11 @@ import ( "context" "database/sql" "errors" - "log" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/lib/pq" @@ -47,6 +47,10 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetNillableFallbackGroupID(groupIn.FallbackGroupID). 
@@ -68,7 +72,7 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er groupIn.CreatedAt = created.CreatedAt groupIn.UpdatedAt = created.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err) } } return translatePersistenceError(err, nil, service.ErrGroupExists) @@ -110,6 +114,10 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetNillableImagePrice1k(groupIn.ImagePrice1K). SetNillableImagePrice2k(groupIn.ImagePrice2K). SetNillableImagePrice4k(groupIn.ImagePrice4K). + SetNillableSoraImagePrice360(groupIn.SoraImagePrice360). + SetNillableSoraImagePrice540(groupIn.SoraImagePrice540). + SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest). + SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD). SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). 
@@ -144,7 +152,7 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er } groupIn.UpdatedAt = updated.UpdatedAt if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err) } return nil } @@ -155,7 +163,7 @@ func (r *groupRepository) Delete(ctx context.Context, id int64) error { return translatePersistenceError(err, service.ErrGroupNotFound, nil) } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err) } return nil } @@ -183,7 +191,7 @@ func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination q = q.Where(group.IsExclusiveEQ(*isExclusive)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -288,7 +296,7 @@ func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, grou } affected, _ := res.RowsAffected() if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err) } return affected, nil } @@ -398,7 +406,7 @@ func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, } } if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { - 
log.Printf("[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err) } return affectedUserIDs, nil @@ -492,7 +500,7 @@ func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64 // 发送调度器事件 if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil { - log.Printf("[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err) } return nil diff --git a/backend/internal/repository/idempotency_repo.go b/backend/internal/repository/idempotency_repo.go new file mode 100644 index 00000000..32f2faae --- /dev/null +++ b/backend/internal/repository/idempotency_repo.go @@ -0,0 +1,237 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type idempotencyRepository struct { + sql sqlExecutor +} + +func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository { + return &idempotencyRepository{sql: sqlDB} +} + +func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) { + if record == nil { + return false, nil + } + query := ` + INSERT INTO idempotency_records ( + scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at + ) VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (scope, idempotency_key_hash) DO NOTHING + RETURNING id, created_at, updated_at + ` + var createdAt time.Time + var updatedAt time.Time + err := scanSingleRow(ctx, r.sql, query, []any{ + record.Scope, + record.IdempotencyKeyHash, + record.RequestFingerprint, + record.Status, + 
record.LockedUntil, + record.ExpiresAt, + }, &record.ID, &createdAt, &updatedAt) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + record.CreatedAt = createdAt + record.UpdatedAt = updatedAt + return true, nil +} + +func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) { + query := ` + SELECT + id, scope, idempotency_key_hash, request_fingerprint, status, response_status, + response_body, error_reason, locked_until, expires_at, created_at, updated_at + FROM idempotency_records + WHERE scope = $1 AND idempotency_key_hash = $2 + ` + record := &service.IdempotencyRecord{} + var responseStatus sql.NullInt64 + var responseBody sql.NullString + var errorReason sql.NullString + var lockedUntil sql.NullTime + err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash}, + &record.ID, + &record.Scope, + &record.IdempotencyKeyHash, + &record.RequestFingerprint, + &record.Status, + &responseStatus, + &responseBody, + &errorReason, + &lockedUntil, + &record.ExpiresAt, + &record.CreatedAt, + &record.UpdatedAt, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, err + } + if responseStatus.Valid { + v := int(responseStatus.Int64) + record.ResponseStatus = &v + } + if responseBody.Valid { + v := responseBody.String + record.ResponseBody = &v + } + if errorReason.Valid { + v := errorReason.String + record.ErrorReason = &v + } + if lockedUntil.Valid { + v := lockedUntil.Time + record.LockedUntil = &v + } + return record, nil +} + +func (r *idempotencyRepository) TryReclaim( + ctx context.Context, + id int64, + fromStatus string, + now, newLockedUntil, newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET status = $2, + locked_until = $3, + error_reason = NULL, + updated_at = NOW(), + expires_at = $4 + WHERE id = $1 + AND status = $5 + AND (locked_until IS NULL OR 
locked_until <= $6) + ` + res, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusProcessing, + newLockedUntil, + newExpiresAt, + fromStatus, + now, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) ExtendProcessingLock( + ctx context.Context, + id int64, + requestFingerprint string, + newLockedUntil, + newExpiresAt time.Time, +) (bool, error) { + query := ` + UPDATE idempotency_records + SET locked_until = $2, + expires_at = $3, + updated_at = NOW() + WHERE id = $1 + AND status = $4 + AND request_fingerprint = $5 + ` + res, err := r.sql.ExecContext( + ctx, + query, + id, + newLockedUntil, + newExpiresAt, + service.IdempotencyStatusProcessing, + requestFingerprint, + ) + if err != nil { + return false, err + } + affected, err := res.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + response_status = $3, + response_body = $4, + error_reason = NULL, + locked_until = NULL, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusSucceeded, + responseStatus, + responseBody, + expiresAt, + ) + return err +} + +func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + query := ` + UPDATE idempotency_records + SET status = $2, + error_reason = $3, + locked_until = $4, + expires_at = $5, + updated_at = NOW() + WHERE id = $1 + ` + _, err := r.sql.ExecContext(ctx, query, + id, + service.IdempotencyStatusFailedRetryable, + errorReason, + lockedUntil, + expiresAt, + ) + return err +} + +func (r 
*idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) { + if limit <= 0 { + limit = 500 + } + query := ` + WITH victims AS ( + SELECT id + FROM idempotency_records + WHERE expires_at <= $1 + ORDER BY expires_at ASC + LIMIT $2 + ) + DELETE FROM idempotency_records + WHERE id IN (SELECT id FROM victims) + ` + res, err := r.sql.ExecContext(ctx, query, now, limit) + if err != nil { + return 0, err + } + return res.RowsAffected() +} diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go new file mode 100644 index 00000000..23b52726 --- /dev/null +++ b/backend/internal/repository/idempotency_repo_integration_test.go @@ -0,0 +1,150 @@ +//go:build integration + +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns. 
+func hashedTestValue(t *testing.T, prefix string) string { + t.Helper() + sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix))) + return hex.EncodeToString(sum[:]) +} + +func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-create"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash"), + RequestFingerprint: hashedTestValue(t, "idem-fp"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + require.NotZero(t, record.ID) + + duplicate := &service.IdempotencyRecord{ + Scope: record.Scope, + IdempotencyKeyHash: record.IdempotencyKeyHash, + RequestFingerprint: hashedTestValue(t, "idem-fp-other"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(30 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err = repo.CreateProcessing(ctx, duplicate) + require.NoError(t, err) + require.False(t, owner, "same scope+key hash should be de-duplicated") +} + +func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-reclaim"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"), + RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, 
repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(-2*time.Second), + now.Add(24*time.Hour), + )) + + newLockedUntil := now.Add(20 * time.Second) + reclaimed, err := repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + newLockedUntil, + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim") + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusProcessing, got.Status) + require.NotNil(t, got.LockedUntil) + require.True(t, got.LockedUntil.After(now)) + + require.NoError(t, repo.MarkFailedRetryable( + ctx, + record.ID, + "RETRYABLE_FAILURE", + now.Add(20*time.Second), + now.Add(24*time.Hour), + )) + + reclaimed, err = repo.TryReclaim( + ctx, + record.ID, + service.IdempotencyStatusFailedRetryable, + now, + now.Add(40*time.Second), + now.Add(24*time.Hour), + ) + require.NoError(t, err) + require.False(t, reclaimed, "within lock window should not reclaim") +} + +func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { + tx := testTx(t) + repo := &idempotencyRepository{sql: tx} + ctx := context.Background() + + now := time.Now().UTC() + record := &service.IdempotencyRecord{ + Scope: uniqueTestValue(t, "idem-scope-success"), + IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"), + RequestFingerprint: hashedTestValue(t, "idem-fp-success"), + Status: service.IdempotencyStatusProcessing, + LockedUntil: ptrTime(now.Add(10 * time.Second)), + ExpiresAt: now.Add(24 * time.Hour), + } + owner, err := repo.CreateProcessing(ctx, record) + require.NoError(t, err) + require.True(t, owner) + + require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour))) + + got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash) + 
require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, service.IdempotencyStatusSucceeded, got.Status) + require.NotNil(t, got.ResponseStatus) + require.Equal(t, 200, *got.ResponseStatus) + require.NotNil(t, got.ResponseBody) + require.Equal(t, `{"ok":true}`, *got.ResponseBody) + require.Nil(t, got.LockedUntil) +} + diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index bc37ee72..f50d2b26 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -48,6 +48,11 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass)) require.True(t, settingsRegclass.Valid, "expected settings table to exist") + // security_secrets table should exist + var securitySecretsRegclass sql.NullString + require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass)) + require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist") + // user_allowed_groups table should exist var uagRegclass sql.NullString require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass)) diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 394d3a1a..088e7d7f 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "strings" "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" @@ -56,12 +57,49 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie } func (s 
*openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + if strings.TrimSpace(clientID) != "" { + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) + } + + clientIDs := []string{ + openai.ClientID, + openai.SoraClientID, + } + seen := make(map[string]struct{}, len(clientIDs)) + var lastErr error + for _, clientID := range clientIDs { + clientID = strings.TrimSpace(clientID) + if clientID == "" { + continue + } + if _, ok := seen[clientID]; ok { + continue + } + seen[clientID] = struct{}{} + + tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) + if err == nil { + return tokenResp, nil + } + lastErr = err + } + if lastErr != nil { + return nil, lastErr + } + return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") +} + +func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { client := createOpenAIReqClient(proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") formData.Set("refresh_token", refreshToken) - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("scope", openai.RefreshScopes) var tokenResp openai.TokenResponse diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index f9df08c8..5938272a 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -136,6 +136,60 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { 
require.Equal(s.T(), "rt2", resp.RefreshToken) } +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID == openai.ClientID { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, "invalid_grant") + return + } + if clientID == openai.SoraClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) + return + } + w.WriteHeader(http.StatusBadRequest) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at-sora", resp.AccessToken) + require.Equal(s.T(), "rt-sora", resp.RefreshToken) + require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs) +} + +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { + const customClientID = "custom-client-id" + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + clientID := r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) + if clientID != customClientID { + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") + require.Equal(s.T(), "at-custom", resp.AccessToken) + 
require.Equal(s.T(), "rt-custom", resp.RefreshToken) + require.Equal(s.T(), []string{customClientID}, seenClientIDs) +} + func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index b04154b7..989573f2 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -3,6 +3,7 @@ package repository import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "time" @@ -55,6 +56,10 @@ INSERT INTO ops_error_logs ( upstream_error_message, upstream_error_detail, upstream_errors, + auth_latency_ms, + routing_latency_ms, + upstream_latency_ms, + response_latency_ms, time_to_first_token_ms, request_body, request_body_truncated, @@ -64,7 +69,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 ) RETURNING id` var id int64 @@ -97,6 +102,10 @@ INSERT INTO ops_error_logs ( opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorsJSON), + opsNullInt64(input.AuthLatencyMs), + opsNullInt64(input.RoutingLatencyMs), + opsNullInt64(input.UpstreamLatencyMs), + opsNullInt64(input.ResponseLatencyMs), opsNullInt64(input.TimeToFirstTokenMs), opsNullString(input.RequestBodyJSON), input.RequestBodyTruncated, @@ -930,6 +939,243 @@ WHERE id = $1` return err } +func (r *opsRepository) BatchInsertSystemLogs(ctx context.Context, inputs []*service.OpsInsertSystemLogInput) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if len(inputs) 
== 0 { + return 0, nil + } + + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return 0, err + } + stmt, err := tx.PrepareContext(ctx, pq.CopyIn( + "ops_system_logs", + "created_at", + "level", + "component", + "message", + "request_id", + "client_request_id", + "user_id", + "account_id", + "platform", + "model", + "extra", + )) + if err != nil { + _ = tx.Rollback() + return 0, err + } + + var inserted int64 + for _, input := range inputs { + if input == nil { + continue + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + component := strings.TrimSpace(input.Component) + level := strings.ToLower(strings.TrimSpace(input.Level)) + message := strings.TrimSpace(input.Message) + if level == "" || message == "" { + continue + } + if component == "" { + component = "app" + } + extra := strings.TrimSpace(input.ExtraJSON) + if extra == "" { + extra = "{}" + } + if _, err := stmt.ExecContext( + ctx, + createdAt.UTC(), + level, + component, + message, + opsNullString(input.RequestID), + opsNullString(input.ClientRequestID), + opsNullInt64(input.UserID), + opsNullInt64(input.AccountID), + opsNullString(input.Platform), + opsNullString(input.Model), + extra, + ); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + inserted++ + } + + if _, err := stmt.ExecContext(ctx); err != nil { + _ = stmt.Close() + _ = tx.Rollback() + return inserted, err + } + if err := stmt.Close(); err != nil { + _ = tx.Rollback() + return inserted, err + } + if err := tx.Commit(); err != nil { + return inserted, err + } + return inserted, nil +} + +func (r *opsRepository) ListSystemLogs(ctx context.Context, filter *service.OpsSystemLogFilter) (*service.OpsSystemLogList, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogFilter{} + } + + page := filter.Page + if page <= 0 { + page = 1 + } + pageSize := filter.PageSize + if pageSize <= 
0 { + pageSize = 50 + } + if pageSize > 200 { + pageSize = 200 + } + + where, args, _ := buildOpsSystemLogsWhere(filter) + countSQL := "SELECT COUNT(*) FROM ops_system_logs l " + where + var total int + if err := r.db.QueryRowContext(ctx, countSQL, args...).Scan(&total); err != nil { + return nil, err + } + + offset := (page - 1) * pageSize + argsWithLimit := append(args, pageSize, offset) + query := ` +SELECT + l.id, + l.created_at, + l.level, + COALESCE(l.component, ''), + COALESCE(l.message, ''), + COALESCE(l.request_id, ''), + COALESCE(l.client_request_id, ''), + l.user_id, + l.account_id, + COALESCE(l.platform, ''), + COALESCE(l.model, ''), + COALESCE(l.extra::text, '{}') +FROM ops_system_logs l +` + where + ` +ORDER BY l.created_at DESC, l.id DESC +LIMIT $` + itoa(len(args)+1) + ` OFFSET $` + itoa(len(args)+2) + + rows, err := r.db.QueryContext(ctx, query, argsWithLimit...) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + logs := make([]*service.OpsSystemLog, 0, pageSize) + for rows.Next() { + item := &service.OpsSystemLog{} + var userID sql.NullInt64 + var accountID sql.NullInt64 + var extraRaw string + if err := rows.Scan( + &item.ID, + &item.CreatedAt, + &item.Level, + &item.Component, + &item.Message, + &item.RequestID, + &item.ClientRequestID, + &userID, + &accountID, + &item.Platform, + &item.Model, + &extraRaw, + ); err != nil { + return nil, err + } + if userID.Valid { + v := userID.Int64 + item.UserID = &v + } + if accountID.Valid { + v := accountID.Int64 + item.AccountID = &v + } + extraRaw = strings.TrimSpace(extraRaw) + if extraRaw != "" && extraRaw != "null" && extraRaw != "{}" { + extra := make(map[string]any) + if err := json.Unmarshal([]byte(extraRaw), &extra); err == nil { + item.Extra = extra + } + } + logs = append(logs, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + return &service.OpsSystemLogList{ + Logs: logs, + Total: total, + Page: page, + PageSize: pageSize, + }, nil +} + 
+func (r *opsRepository) DeleteSystemLogs(ctx context.Context, filter *service.OpsSystemLogCleanupFilter) (int64, error) { + if r == nil || r.db == nil { + return 0, fmt.Errorf("nil ops repository") + } + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if !hasConstraint { + return 0, fmt.Errorf("cleanup requires at least one filter condition") + } + + query := "DELETE FROM ops_system_logs l " + where + res, err := r.db.ExecContext(ctx, query, args...) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func (r *opsRepository) InsertSystemLogCleanupAudit(ctx context.Context, input *service.OpsSystemLogCleanupAudit) error { + if r == nil || r.db == nil { + return fmt.Errorf("nil ops repository") + } + if input == nil { + return fmt.Errorf("nil input") + } + createdAt := input.CreatedAt + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + _, err := r.db.ExecContext(ctx, ` +INSERT INTO ops_system_log_cleanup_audits ( + created_at, + operator_id, + conditions, + deleted_rows +) VALUES ($1,$2,$3,$4) +`, createdAt.UTC(), input.OperatorID, input.Conditions, input.DeletedRows) + return err +} + func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { clauses := make([]string, 0, 12) args := make([]any, 0, 12) @@ -948,7 +1194,7 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } // Keep list endpoints scoped to client errors unless explicitly filtering upstream phase. 
if phaseFilter != "upstream" { - clauses = append(clauses, "COALESCE(status_code, 0) >= 400") + clauses = append(clauses, "COALESCE(e.status_code, 0) >= 400") } if filter.StartTime != nil && !filter.StartTime.IsZero() { @@ -962,33 +1208,33 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } if p := strings.TrimSpace(filter.Platform); p != "" { args = append(args, p) - clauses = append(clauses, "platform = $"+itoa(len(args))) + clauses = append(clauses, "e.platform = $"+itoa(len(args))) } if filter.GroupID != nil && *filter.GroupID > 0 { args = append(args, *filter.GroupID) - clauses = append(clauses, "group_id = $"+itoa(len(args))) + clauses = append(clauses, "e.group_id = $"+itoa(len(args))) } if filter.AccountID != nil && *filter.AccountID > 0 { args = append(args, *filter.AccountID) - clauses = append(clauses, "account_id = $"+itoa(len(args))) + clauses = append(clauses, "e.account_id = $"+itoa(len(args))) } if phase := phaseFilter; phase != "" { args = append(args, phase) - clauses = append(clauses, "error_phase = $"+itoa(len(args))) + clauses = append(clauses, "e.error_phase = $"+itoa(len(args))) } if filter != nil { if owner := strings.TrimSpace(strings.ToLower(filter.Owner)); owner != "" { args = append(args, owner) - clauses = append(clauses, "LOWER(COALESCE(error_owner,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_owner,'')) = $"+itoa(len(args))) } if source := strings.TrimSpace(strings.ToLower(filter.Source)); source != "" { args = append(args, source) - clauses = append(clauses, "LOWER(COALESCE(error_source,'')) = $"+itoa(len(args))) + clauses = append(clauses, "LOWER(COALESCE(e.error_source,'')) = $"+itoa(len(args))) } } if resolvedFilter != nil { args = append(args, *resolvedFilter) - clauses = append(clauses, "COALESCE(resolved,false) = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.resolved,false) = $"+itoa(len(args))) } // View filter: errors vs excluded vs all. 
@@ -1000,51 +1246,140 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) { } switch view { case "", "errors": - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") case "excluded": - clauses = append(clauses, "COALESCE(is_business_limited,false) = true") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = true") case "all": // no-op default: // treat unknown as default 'errors' - clauses = append(clauses, "COALESCE(is_business_limited,false) = false") + clauses = append(clauses, "COALESCE(e.is_business_limited,false) = false") } if len(filter.StatusCodes) > 0 { args = append(args, pq.Array(filter.StatusCodes)) - clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+")") + clauses = append(clauses, "COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+")") } else if filter.StatusCodesOther { // "Other" means: status codes not in the common list. known := []int{400, 401, 403, 404, 409, 422, 429, 500, 502, 503, 504, 529} args = append(args, pq.Array(known)) - clauses = append(clauses, "NOT (COALESCE(upstream_status_code, status_code, 0) = ANY($"+itoa(len(args))+"))") + clauses = append(clauses, "NOT (COALESCE(e.upstream_status_code, e.status_code, 0) = ANY($"+itoa(len(args))+"))") } // Exact correlation keys (preferred for request↔upstream linkage). 
if rid := strings.TrimSpace(filter.RequestID); rid != "" { args = append(args, rid) - clauses = append(clauses, "COALESCE(request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.request_id,'') = $"+itoa(len(args))) } if crid := strings.TrimSpace(filter.ClientRequestID); crid != "" { args = append(args, crid) - clauses = append(clauses, "COALESCE(client_request_id,'') = $"+itoa(len(args))) + clauses = append(clauses, "COALESCE(e.client_request_id,'') = $"+itoa(len(args))) } if q := strings.TrimSpace(filter.Query); q != "" { like := "%" + q + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "(request_id ILIKE $"+n+" OR client_request_id ILIKE $"+n+" OR error_message ILIKE $"+n+")") + clauses = append(clauses, "(e.request_id ILIKE $"+n+" OR e.client_request_id ILIKE $"+n+" OR e.error_message ILIKE $"+n+")") } if userQuery := strings.TrimSpace(filter.UserQuery); userQuery != "" { like := "%" + userQuery + "%" args = append(args, like) n := itoa(len(args)) - clauses = append(clauses, "u.email ILIKE $"+n) + clauses = append(clauses, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $"+n+")") } return "WHERE " + strings.Join(clauses, " AND "), args } +func buildOpsSystemLogsWhere(filter *service.OpsSystemLogFilter) (string, []any, bool) { + clauses := make([]string, 0, 10) + args := make([]any, 0, 10) + clauses = append(clauses, "1=1") + hasConstraint := false + + if filter != nil && filter.StartTime != nil && !filter.StartTime.IsZero() { + args = append(args, filter.StartTime.UTC()) + clauses = append(clauses, "l.created_at >= $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil && filter.EndTime != nil && !filter.EndTime.IsZero() { + args = append(args, filter.EndTime.UTC()) + clauses = append(clauses, "l.created_at < $"+itoa(len(args))) + hasConstraint = true + } + if filter != nil { + if v := strings.ToLower(strings.TrimSpace(filter.Level)); v != "" { + args = append(args, v) + 
clauses = append(clauses, "LOWER(COALESCE(l.level,'')) = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Component); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.component,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.RequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.ClientRequestID); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.client_request_id,'') = $"+itoa(len(args))) + hasConstraint = true + } + if filter.UserID != nil && *filter.UserID > 0 { + args = append(args, *filter.UserID) + clauses = append(clauses, "l.user_id = $"+itoa(len(args))) + hasConstraint = true + } + if filter.AccountID != nil && *filter.AccountID > 0 { + args = append(args, *filter.AccountID) + clauses = append(clauses, "l.account_id = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Platform); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.platform,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Model); v != "" { + args = append(args, v) + clauses = append(clauses, "COALESCE(l.model,'') = $"+itoa(len(args))) + hasConstraint = true + } + if v := strings.TrimSpace(filter.Query); v != "" { + like := "%" + v + "%" + args = append(args, like) + n := itoa(len(args)) + clauses = append(clauses, "(l.message ILIKE $"+n+" OR COALESCE(l.request_id,'') ILIKE $"+n+" OR COALESCE(l.client_request_id,'') ILIKE $"+n+" OR COALESCE(l.extra::text,'') ILIKE $"+n+")") + hasConstraint = true + } + } + + return "WHERE " + strings.Join(clauses, " AND "), args, hasConstraint +} + +func buildOpsSystemLogsCleanupWhere(filter *service.OpsSystemLogCleanupFilter) (string, []any, bool) { + if filter == nil { + filter = &service.OpsSystemLogCleanupFilter{} 
+ } + listFilter := &service.OpsSystemLogFilter{ + StartTime: filter.StartTime, + EndTime: filter.EndTime, + Level: filter.Level, + Component: filter.Component, + RequestID: filter.RequestID, + ClientRequestID: filter.ClientRequestID, + UserID: filter.UserID, + AccountID: filter.AccountID, + Platform: filter.Platform, + Model: filter.Model, + Query: filter.Query, + } + return buildOpsSystemLogsWhere(listFilter) +} + // Helpers for nullable args func opsNullString(v any) any { switch s := v.(type) { diff --git a/backend/internal/repository/ops_repo_error_where_test.go b/backend/internal/repository/ops_repo_error_where_test.go new file mode 100644 index 00000000..9ab1a89a --- /dev/null +++ b/backend/internal/repository/ops_repo_error_where_test.go @@ -0,0 +1,48 @@ +package repository + +import ( + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + Query: "ACCESS_DENIED", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if !strings.Contains(where, "e.request_id ILIKE $") { + t.Fatalf("where should include qualified request_id condition: %s", where) + } + if !strings.Contains(where, "e.client_request_id ILIKE $") { + t.Fatalf("where should include qualified client_request_id condition: %s", where) + } + if !strings.Contains(where, "e.error_message ILIKE $") { + t.Fatalf("where should include qualified error_message condition: %s", where) + } +} + +func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) { + filter := &service.OpsErrorLogFilter{ + UserQuery: "admin@", + } + + where, args := buildOpsErrorLogsWhere(filter) + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 1 { + t.Fatalf("args len = %d, want 1", len(args)) + } + if 
!strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") { + t.Fatalf("where should include EXISTS user email condition: %s", where) + } +} diff --git a/backend/internal/repository/ops_repo_openai_token_stats.go b/backend/internal/repository/ops_repo_openai_token_stats.go new file mode 100644 index 00000000..6aea416e --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats.go @@ -0,0 +1,145 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) { + if r == nil || r.db == nil { + return nil, fmt.Errorf("nil ops repository") + } + if filter == nil { + return nil, fmt.Errorf("nil filter") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, fmt.Errorf("start_time/end_time required") + } + // 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。 + if filter.StartTime.After(filter.EndTime) { + return nil, fmt.Errorf("start_time must be <= end_time") + } + + dashboardFilter := &service.OpsDashboardFilter{ + StartTime: filter.StartTime.UTC(), + EndTime: filter.EndTime.UTC(), + Platform: strings.TrimSpace(strings.ToLower(filter.Platform)), + GroupID: filter.GroupID, + } + + join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1) + where += " AND ul.model LIKE 'gpt%'" + + baseCTE := ` +WITH stats AS ( + SELECT + ul.model AS model, + COUNT(*)::bigint AS request_count, + ROUND( + AVG( + CASE + WHEN ul.duration_ms > 0 AND ul.output_tokens > 0 + THEN ul.output_tokens * 1000.0 / ul.duration_ms + END + )::numeric, + 2 + )::float8 AS avg_tokens_per_sec, + ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms, + COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens, + 
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms, + COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token + FROM usage_logs ul + ` + join + ` + ` + where + ` + GROUP BY ul.model +) +` + + countSQL := baseCTE + `SELECT COUNT(*) FROM stats` + var total int64 + if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil { + return nil, err + } + + querySQL := baseCTE + ` +SELECT + model, + request_count, + avg_tokens_per_sec, + avg_first_token_ms, + total_output_tokens, + avg_duration_ms, + requests_with_first_token +FROM stats +ORDER BY request_count DESC, model ASC` + + args := make([]any, 0, len(baseArgs)+2) + args = append(args, baseArgs...) + + if filter.IsTopNMode() { + querySQL += fmt.Sprintf("\nLIMIT $%d", next) + args = append(args, filter.TopN) + } else { + offset := (filter.Page - 1) * filter.PageSize + querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1) + args = append(args, filter.PageSize, offset) + } + + rows, err := r.db.QueryContext(ctx, querySQL, args...) 
+ if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + items := make([]*service.OpsOpenAITokenStatsItem, 0, 32) + for rows.Next() { + item := &service.OpsOpenAITokenStatsItem{} + var avgTPS sql.NullFloat64 + var avgFirstToken sql.NullFloat64 + if err := rows.Scan( + &item.Model, + &item.RequestCount, + &avgTPS, + &avgFirstToken, + &item.TotalOutputTokens, + &item.AvgDurationMs, + &item.RequestsWithFirstToken, + ); err != nil { + return nil, err + } + if avgTPS.Valid { + v := avgTPS.Float64 + item.AvgTokensPerSec = &v + } + if avgFirstToken.Valid { + v := avgFirstToken.Float64 + item.AvgFirstTokenMs = &v + } + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, err + } + + resp := &service.OpsOpenAITokenStatsResponse{ + TimeRange: strings.TrimSpace(filter.TimeRange), + StartTime: dashboardFilter.StartTime, + EndTime: dashboardFilter.EndTime, + Platform: dashboardFilter.Platform, + GroupID: dashboardFilter.GroupID, + Items: items, + Total: total, + } + if filter.IsTopNMode() { + topN := filter.TopN + resp.TopN = &topN + } else { + resp.Page = filter.Page + resp.PageSize = filter.PageSize + } + return resp, nil +} diff --git a/backend/internal/repository/ops_repo_openai_token_stats_test.go b/backend/internal/repository/ops_repo_openai_token_stats_test.go new file mode 100644 index 00000000..bb01d820 --- /dev/null +++ b/backend/internal/repository/ops_repo_openai_token_stats_test.go @@ -0,0 +1,156 @@ +package repository + +import ( + "context" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + groupID := int64(9) + + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1d", + 
StartTime: start, + EndTime: end, + Platform: " OpenAI ", + GroupID: &groupID, + Page: 2, + PageSize: 10, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end, groupID, "openai"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)). + AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`). + WithArgs(start, end, groupID, "openai", 10, 10). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(3), resp.Total) + require.Equal(t, 2, resp.Page) + require.Equal(t, 10, resp.PageSize) + require.Nil(t, resp.TopN) + require.Equal(t, "openai", resp.Platform) + require.NotNil(t, resp.GroupID) + require.Equal(t, groupID, *resp.GroupID) + require.Len(t, resp.Items, 2) + require.Equal(t, "gpt-4o-mini", resp.Items[0].Model) + require.NotNil(t, resp.Items[0].AvgTokensPerSec) + require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001) + require.NotNil(t, resp.Items[0].AvgFirstTokenMs) + require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC) + end := start.Add(time.Hour) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: start, + EndTime: end, + TopN: 5, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). 
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + + rows := sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + }). + AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0)) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`). + WithArgs(start, end, 5). + WillReturnRows(rows) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotNil(t, resp.TopN) + require.Equal(t, 5, *resp.TopN) + require.Equal(t, 0, resp.Page) + require.Equal(t, 0, resp.PageSize) + require.Len(t, resp.Items, 1) + require.Nil(t, resp.Items[0].AvgTokensPerSec) + require.Nil(t, resp.Items[0].AvgFirstTokenMs) + + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) { + db, mock := newSQLMock(t) + repo := &opsRepository{db: db} + + start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC) + end := start.Add(30 * time.Minute) + filter := &service.OpsOpenAITokenStatsFilter{ + TimeRange: "30m", + StartTime: start, + EndTime: end, + Page: 1, + PageSize: 20, + } + + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`). + WithArgs(start, end). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + + mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`). + WithArgs(start, end, 20, 0). 
+ WillReturnRows(sqlmock.NewRows([]string{ + "model", + "request_count", + "avg_tokens_per_sec", + "avg_first_token_ms", + "total_output_tokens", + "avg_duration_ms", + "requests_with_first_token", + })) + + resp, err := repo.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, int64(0), resp.Total) + require.Len(t, resp.Items, 0) + require.Equal(t, 1, resp.Page) + require.Equal(t, 20, resp.PageSize) + + require.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/backend/internal/repository/ops_repo_system_logs_test.go b/backend/internal/repository/ops_repo_system_logs_test.go new file mode 100644 index 00000000..c3524fe4 --- /dev/null +++ b/backend/internal/repository/ops_repo_system_logs_test.go @@ -0,0 +1,86 @@ +package repository + +import ( + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) { + start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC) + userID := int64(12) + accountID := int64(34) + + filter := &service.OpsSystemLogFilter{ + StartTime: &start, + EndTime: &end, + Level: "warn", + Component: "http.access", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + AccountID: &accountID, + Platform: "openai", + Model: "gpt-5", + Query: "timeout", + } + + where, args, hasConstraint := buildOpsSystemLogsWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 11 { + t.Fatalf("args len = %d, want 11", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func 
TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) { + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{}) + if hasConstraint { + t.Fatalf("expected hasConstraint=false") + } + if where == "" { + t.Fatalf("where should not be empty") + } + if len(args) != 0 { + t.Fatalf("args len = %d, want 0", len(args)) + } +} + +func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) { + userID := int64(9) + filter := &service.OpsSystemLogCleanupFilter{ + ClientRequestID: "creq-9", + UserID: &userID, + } + + where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter) + if !hasConstraint { + t.Fatalf("expected hasConstraint=true") + } + if len(args) != 2 { + t.Fatalf("args len = %d, want 2", len(args)) + } + if !contains(where, "COALESCE(l.client_request_id,'') = $") { + t.Fatalf("where should include client_request_id condition: %s", where) + } + if !contains(where, "l.user_id = $") { + t.Fatalf("where should include user_id condition: %s", where) + } +} + +func contains(s string, sub string) bool { + return strings.Contains(s, sub) +} diff --git a/backend/internal/repository/promo_code_repo.go b/backend/internal/repository/promo_code_repo.go index 98b422e0..95ce687a 100644 --- a/backend/internal/repository/promo_code_repo.go +++ b/backend/internal/repository/promo_code_repo.go @@ -132,7 +132,7 @@ func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagina q = q.Where(promocode.CodeContainsFold(search)) } - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } @@ -187,7 +187,7 @@ func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCo q := r.client.PromoCodeUsage.Query(). 
Where(promocodeusage.PromoCodeIDEQ(promoCodeID)) - total, err := q.Count(ctx) + total, err := q.Clone().Count(ctx) if err != nil { return nil, nil, err } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index 513e929c..54de2897 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -19,10 +19,14 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecure := false allowPrivate := false validateResolvedIP := true + maxResponseBytes := defaultProxyProbeResponseMaxBytes if cfg != nil { insecure = cfg.Security.ProxyProbe.InsecureSkipVerify allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts validateResolvedIP = cfg.Security.URLAllowlist.Enabled + if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 { + maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes + } } if insecure { log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.") @@ -31,11 +35,13 @@ func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber { insecureSkipVerify: insecure, allowPrivateHosts: allowPrivate, validateResolvedIP: validateResolvedIP, + maxResponseBytes: maxResponseBytes, } } const ( - defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeTimeout = 30 * time.Second + defaultProxyProbeResponseMaxBytes = int64(1024 * 1024) ) // probeURLs 按优先级排列的探测 URL 列表 @@ -52,6 +58,7 @@ type proxyProbeService struct { insecureSkipVerify bool allowPrivateHosts bool validateResolvedIP bool + maxResponseBytes int64 } func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) { @@ -98,10 +105,17 @@ func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Clien return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode) } - body, err := io.ReadAll(resp.Body) + 
maxResponseBytes := s.maxResponseBytes + if maxResponseBytes <= 0 { + maxResponseBytes = defaultProxyProbeResponseMaxBytes + } + body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1)) if err != nil { return nil, latencyMs, fmt.Errorf("failed to read response: %w", err) } + if int64(len(body)) > maxResponseBytes { + return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes) + } switch parser { case "ip-api": diff --git a/backend/internal/repository/security_secret_bootstrap.go b/backend/internal/repository/security_secret_bootstrap.go new file mode 100644 index 00000000..e773c238 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap.go @@ -0,0 +1,177 @@ +package repository + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "log" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + securitySecretKeyJWT = "jwt_secret" + securitySecretReadRetryMax = 5 + securitySecretReadRetryWait = 10 * time.Millisecond +) + +var readRandomBytes = rand.Read + +func ensureBootstrapSecrets(ctx context.Context, client *ent.Client, cfg *config.Config) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + if cfg == nil { + return fmt.Errorf("nil config") + } + + cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) + if cfg.JWT.Secret != "" { + storedSecret, err := createSecuritySecretIfAbsent(ctx, client, securitySecretKeyJWT, cfg.JWT.Secret) + if err != nil { + return fmt.Errorf("persist jwt secret: %w", err) + } + if storedSecret != cfg.JWT.Secret { + log.Println("Warning: configured JWT secret mismatches persisted value; using persisted secret for cross-instance consistency.") + } + cfg.JWT.Secret = storedSecret + return nil + } + + secret, created, err := getOrCreateGeneratedSecuritySecret(ctx, client, securitySecretKeyJWT, 32) + if err 
!= nil { + return fmt.Errorf("ensure jwt secret: %w", err) + } + cfg.JWT.Secret = secret + + if created { + log.Println("Warning: JWT secret auto-generated and persisted to database. Consider rotating to a managed secret for production.") + } + return nil +} + +func getOrCreateGeneratedSecuritySecret(ctx context.Context, client *ent.Client, key string, byteLength int) (string, bool, error) { + existing, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + value := strings.TrimSpace(existing.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, false, nil + } + if !ent.IsNotFound(err) { + return "", false, err + } + + generated, err := generateHexSecret(byteLength) + if err != nil { + return "", false, err + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(generated). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). + Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", false, err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", false, err + } + value := strings.TrimSpace(stored.Value) + if len([]byte(value)) < 32 { + return "", false, fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return value, value == generated, nil +} + +func createSecuritySecretIfAbsent(ctx context.Context, client *ent.Client, key, value string) (string, error) { + value = strings.TrimSpace(value) + if len([]byte(value)) < 32 { + return "", fmt.Errorf("secret %q must be at least 32 bytes", key) + } + + if err := client.SecuritySecret.Create(). + SetKey(key). + SetValue(value). + OnConflictColumns(securitysecret.FieldKey). + DoNothing(). 
+ Exec(ctx); err != nil { + if !isSQLNoRowsError(err) { + return "", err + } + } + + stored, err := querySecuritySecretWithRetry(ctx, client, key) + if err != nil { + return "", err + } + storedValue := strings.TrimSpace(stored.Value) + if len([]byte(storedValue)) < 32 { + return "", fmt.Errorf("stored secret %q must be at least 32 bytes", key) + } + return storedValue, nil +} + +func querySecuritySecretWithRetry(ctx context.Context, client *ent.Client, key string) (*ent.SecuritySecret, error) { + var lastErr error + for attempt := 0; attempt <= securitySecretReadRetryMax; attempt++ { + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Only(ctx) + if err == nil { + return stored, nil + } + if !isSecretNotFoundError(err) { + return nil, err + } + lastErr = err + if attempt == securitySecretReadRetryMax { + break + } + + timer := time.NewTimer(securitySecretReadRetryWait) + select { + case <-ctx.Done(): + timer.Stop() + return nil, ctx.Err() + case <-timer.C: + } + } + return nil, lastErr +} + +func isSecretNotFoundError(err error) bool { + if err == nil { + return false + } + return ent.IsNotFound(err) || isSQLNoRowsError(err) +} + +func isSQLNoRowsError(err error) bool { + if err == nil { + return false + } + return errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), "no rows in result set") +} + +func generateHexSecret(byteLength int) (string, error) { + if byteLength <= 0 { + byteLength = 32 + } + buf := make([]byte, byteLength) + if _, err := readRandomBytes(buf); err != nil { + return "", fmt.Errorf("generate random secret: %w", err) + } + return hex.EncodeToString(buf), nil +} diff --git a/backend/internal/repository/security_secret_bootstrap_test.go b/backend/internal/repository/security_secret_bootstrap_test.go new file mode 100644 index 00000000..288edf33 --- /dev/null +++ b/backend/internal/repository/security_secret_bootstrap_test.go @@ -0,0 +1,337 @@ +package repository + +import ( + "context" + "database/sql" + 
"encoding/hex" + "errors" + "fmt" + "strings" + "sync" + "testing" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + "github.com/Wei-Shaw/sub2api/ent/securitysecret" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newSecuritySecretTestClient(t *testing.T) *dbent.Client { + t.Helper() + name := strings.ReplaceAll(t.Name(), "/", "_") + dsn := fmt.Sprintf("file:%s?mode=memory&cache=shared&_fk=1", name) + + db, err := sql.Open("sqlite", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + return client +} + +func TestEnsureBootstrapSecretsNilInputs(t *testing.T) { + err := ensureBootstrapSecrets(context.Background(), nil, &config.Config{}) + require.Error(t, err) + require.Contains(t, err.Error(), "nil ent client") + + client := newSecuritySecretTestClient(t) + err = ensureBootstrapSecrets(context.Background(), client, nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil config") +} + +func TestEnsureBootstrapSecretsGenerateAndPersistJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.NotEmpty(t, cfg.JWT.Secret) + require.GreaterOrEqual(t, len([]byte(cfg.JWT.Secret)), 32) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, cfg.JWT.Secret, stored.Value) +} + +func TestEnsureBootstrapSecretsLoadExistingJWTSecret(t *testing.T) { + client := 
newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("existing-jwt-secret-32bytes-long!!!!").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestEnsureBootstrapSecretsRejectInvalidStoredSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create().SetKey(securitySecretKeyJWT).SetValue("too-short").Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsPersistConfiguredJWTSecret(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{ + JWT: config.JWTConfig{Secret: "configured-jwt-secret-32bytes-long!!"}, + } + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "configured-jwt-secret-32bytes-long!!", stored.Value) +} + +func TestEnsureBootstrapSecretsConfiguredSecretTooShort(t *testing.T) { + client := newSecuritySecretTestClient(t) + cfg := &config.Config{JWT: config.JWTConfig{Secret: "short"}} + + err := ensureBootstrapSecrets(context.Background(), client, cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") +} + +func TestEnsureBootstrapSecretsConfiguredSecretDuplicateIgnored(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey(securitySecretKeyJWT). + SetValue("existing-jwt-secret-32bytes-long!!!!"). 
+ Save(context.Background()) + require.NoError(t, err) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "another-configured-jwt-secret-32!!!!"}} + err = ensureBootstrapSecrets(context.Background(), client, cfg) + require.NoError(t, err) + + stored, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(securitySecretKeyJWT)).Only(context.Background()) + require.NoError(t, err) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", stored.Value) + require.Equal(t, "existing-jwt-secret-32bytes-long!!!!", cfg.JWT.Secret) +} + +func TestGetOrCreateGeneratedSecuritySecretTrimmedExistingValue(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := client.SecuritySecret.Create(). + SetKey("trimmed_key"). + SetValue(" existing-trimmed-secret-32bytes-long!! "). + Save(context.Background()) + require.NoError(t, err) + + value, created, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "trimmed_key", 32) + require.NoError(t, err) + require.False(t, created) + require.Equal(t, "existing-trimmed-secret-32bytes-long!!", value) +} + +func TestGetOrCreateGeneratedSecuritySecretQueryError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "closed_client_key", 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretCreateValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + tooLongKey := strings.Repeat("k", 101) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, tooLongKey, 32) + require.Error(t, err) +} + +func TestGetOrCreateGeneratedSecuritySecretConcurrentCreation(t *testing.T) { + client := newSecuritySecretTestClient(t) + const goroutines = 8 + key := "concurrent_bootstrap_key" + + values := make([]string, goroutines) + createdFlags := make([]bool, goroutines) + errs := make([]error, goroutines) + + var wg sync.WaitGroup + 
for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + values[idx], createdFlags[idx], errs[idx] = getOrCreateGeneratedSecuritySecret(context.Background(), client, key, 32) + }(i) + } + wg.Wait() + + for i := range errs { + require.NoError(t, errs[i]) + require.NotEmpty(t, values[i]) + } + for i := 1; i < len(values); i++ { + require.Equal(t, values[0], values[i]) + } + + createdCount := 0 + for _, created := range createdFlags { + if created { + createdCount++ + } + } + require.GreaterOrEqual(t, createdCount, 1) + require.LessOrEqual(t, createdCount, 1) + + count, err := client.SecuritySecret.Query().Where(securitysecret.KeyEQ(key)).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestGetOrCreateGeneratedSecuritySecretGenerateError(t *testing.T) { + client := newSecuritySecretTestClient(t) + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("boom") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, _, err := getOrCreateGeneratedSecuritySecret(context.Background(), client, "gen_error_key", 32) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") +} + +func TestCreateSecuritySecretIfAbsent(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "short") + require.Error(t, err) + require.Contains(t, err.Error(), "at least 32 bytes") + + stored, err := createSecuritySecretIfAbsent(context.Background(), client, "abc", "valid-jwt-secret-value-32bytes-long") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + stored, err = createSecuritySecretIfAbsent(context.Background(), client, "abc", "another-valid-secret-value-32bytes") + require.NoError(t, err) + require.Equal(t, "valid-jwt-secret-value-32bytes-long", stored) + + count, err := 
client.SecuritySecret.Query().Where(securitysecret.KeyEQ("abc")).Count(context.Background()) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestCreateSecuritySecretIfAbsentValidationError(t *testing.T) { + client := newSecuritySecretTestClient(t) + _, err := createSecuritySecretIfAbsent( + context.Background(), + client, + strings.Repeat("k", 101), + "valid-jwt-secret-value-32bytes-long", + ) + require.Error(t, err) +} + +func TestCreateSecuritySecretIfAbsentExecError(t *testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := createSecuritySecretIfAbsent(context.Background(), client, "closed-client-key", "valid-jwt-secret-value-32bytes-long") + require.Error(t, err) +} + +func TestQuerySecuritySecretWithRetrySuccess(t *testing.T) { + client := newSecuritySecretTestClient(t) + created, err := client.SecuritySecret.Create(). + SetKey("retry_success_key"). + SetValue("retry-success-jwt-secret-value-32!!"). + Save(context.Background()) + require.NoError(t, err) + + got, err := querySecuritySecretWithRetry(context.Background(), client, "retry_success_key") + require.NoError(t, err) + require.Equal(t, created.ID, got.ID) + require.Equal(t, "retry-success-jwt-secret-value-32!!", got.Value) +} + +func TestQuerySecuritySecretWithRetryExhausted(t *testing.T) { + client := newSecuritySecretTestClient(t) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_missing_key") + require.Error(t, err) + require.True(t, isSecretNotFoundError(err)) +} + +func TestQuerySecuritySecretWithRetryContextCanceled(t *testing.T) { + client := newSecuritySecretTestClient(t) + ctx, cancel := context.WithTimeout(context.Background(), securitySecretReadRetryWait/2) + defer cancel() + + _, err := querySecuritySecretWithRetry(ctx, client, "retry_ctx_cancel_key") + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestQuerySecuritySecretWithRetryNonNotFoundError(t 
*testing.T) { + client := newSecuritySecretTestClient(t) + require.NoError(t, client.Close()) + + _, err := querySecuritySecretWithRetry(context.Background(), client, "retry_closed_client_key") + require.Error(t, err) + require.False(t, isSecretNotFoundError(err)) +} + +func TestSecretNotFoundHelpers(t *testing.T) { + require.False(t, isSecretNotFoundError(nil)) + require.False(t, isSQLNoRowsError(nil)) + + require.True(t, isSQLNoRowsError(sql.ErrNoRows)) + require.True(t, isSQLNoRowsError(fmt.Errorf("wrapped: %w", sql.ErrNoRows))) + require.True(t, isSQLNoRowsError(errors.New("sql: no rows in result set"))) + + require.True(t, isSecretNotFoundError(sql.ErrNoRows)) + require.True(t, isSecretNotFoundError(errors.New("sql: no rows in result set"))) + require.False(t, isSecretNotFoundError(errors.New("some other error"))) +} + +func TestGenerateHexSecretReadError(t *testing.T) { + originalRead := readRandomBytes + readRandomBytes = func([]byte) (int, error) { + return 0, errors.New("read random failed") + } + t.Cleanup(func() { + readRandomBytes = originalRead + }) + + _, err := generateHexSecret(32) + require.Error(t, err) + require.Contains(t, err.Error(), "read random failed") +} + +func TestGenerateHexSecretLengths(t *testing.T) { + v1, err := generateHexSecret(0) + require.NoError(t, err) + require.Len(t, v1, 64) + _, err = hex.DecodeString(v1) + require.NoError(t, err) + + v2, err := generateHexSecret(16) + require.NoError(t, err) + require.Len(t, v2, 32) + _, err = hex.DecodeString(v2) + require.NoError(t, err) + + require.NotEqual(t, v1, v2) +} diff --git a/backend/internal/repository/sora_account_repo.go b/backend/internal/repository/sora_account_repo.go new file mode 100644 index 00000000..ad2ae638 --- /dev/null +++ b/backend/internal/repository/sora_account_repo.go @@ -0,0 +1,98 @@ +package repository + +import ( + "context" + "database/sql" + "errors" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraAccountRepository 实现 
service.SoraAccountRepository 接口。 +// 使用原生 SQL 操作 sora_accounts 表,因为该表不在 Ent ORM 管理范围内。 +// +// 设计说明: +// - sora_accounts 表是独立迁移创建的,不通过 Ent Schema 管理 +// - 使用 ON CONFLICT (account_id) DO UPDATE 实现 Upsert 语义 +// - 与 accounts 主表通过外键关联,ON DELETE CASCADE 确保级联删除 +type soraAccountRepository struct { + sql *sql.DB +} + +// NewSoraAccountRepository 创建 Sora 账号扩展表仓储实例 +func NewSoraAccountRepository(sqlDB *sql.DB) service.SoraAccountRepository { + return &soraAccountRepository{sql: sqlDB} +} + +// Upsert 创建或更新 Sora 账号扩展信息 +// 使用 PostgreSQL ON CONFLICT ... DO UPDATE 实现原子性 upsert +func (r *soraAccountRepository) Upsert(ctx context.Context, accountID int64, updates map[string]any) error { + accessToken, accessOK := updates["access_token"].(string) + refreshToken, refreshOK := updates["refresh_token"].(string) + sessionToken, sessionOK := updates["session_token"].(string) + + if !accessOK || accessToken == "" || !refreshOK || refreshToken == "" { + if !sessionOK { + return errors.New("缺少 access_token/refresh_token,且未提供可更新字段") + } + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_accounts + SET session_token = CASE WHEN $2 = '' THEN session_token ELSE $2 END, + updated_at = NOW() + WHERE account_id = $1 + `, accountID, sessionToken) + if err != nil { + return err + } + rows, err := result.RowsAffected() + if err != nil { + return err + } + if rows == 0 { + return errors.New("sora_accounts 记录不存在,无法仅更新 session_token") + } + return nil + } + + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO sora_accounts (account_id, access_token, refresh_token, session_token, created_at, updated_at) + VALUES ($1, $2, $3, $4, NOW(), NOW()) + ON CONFLICT (account_id) DO UPDATE SET + access_token = EXCLUDED.access_token, + refresh_token = EXCLUDED.refresh_token, + session_token = CASE WHEN EXCLUDED.session_token = '' THEN sora_accounts.session_token ELSE EXCLUDED.session_token END, + updated_at = NOW() + `, accountID, accessToken, refreshToken, sessionToken) + return err +} + +// GetByAccountID 
根据账号 ID 获取 Sora 扩展信息 +func (r *soraAccountRepository) GetByAccountID(ctx context.Context, accountID int64) (*service.SoraAccount, error) { + rows, err := r.sql.QueryContext(ctx, ` + SELECT account_id, access_token, refresh_token, COALESCE(session_token, '') + FROM sora_accounts + WHERE account_id = $1 + `, accountID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + return nil, nil // 记录不存在 + } + + var sa service.SoraAccount + if err := rows.Scan(&sa.AccountID, &sa.AccessToken, &sa.RefreshToken, &sa.SessionToken); err != nil { + return nil, err + } + return &sa, nil +} + +// Delete 删除 Sora 账号扩展信息 +func (r *soraAccountRepository) Delete(ctx context.Context, accountID int64) error { + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM sora_accounts WHERE account_id = $1 + `, accountID) + return err +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index c3e5ae85..ce67ba4d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,23 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, 
rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" + +// dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL +var dateFormatWhitelist = map[string]string{ + "hour": "YYYY-MM-DD HH24:00", + "day": "YYYY-MM-DD", + "week": "IYYY-IW", + "month": "YYYY-MM", +} + +// safeDateFormat 根据白名单获取 dateFormat,未匹配时返回默认值 +func safeDateFormat(granularity string) string { + if f, ok := dateFormatWhitelist[granularity]; ok { + return f + } + return "YYYY-MM-DD" +} type usageLogRepository struct { client *dbent.Client @@ -111,23 +127,24 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) duration_ms, first_token_ms, user_agent, - ip_address, - image_count, - image_size, - reasoning_effort, - cache_ttl_overridden, - created_at - ) VALUES ( - $1, $2, $3, $4, $5, - $6, $7, - $8, $9, $10, $11, - $12, $13, - $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32 - ) - ON CONFLICT (request_id, api_key_id) DO NOTHING - RETURNING id, created_at - ` + ip_address, + image_count, + image_size, + media_type, + reasoning_effort, + cache_ttl_overridden, + created_at + ) VALUES ( + $1, $2, $3, $4, $5, + $6, $7, + $8, $9, $10, $11, + $12, $13, + $14, $15, $16, $17, $18, $19, + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33 + ) + ON CONFLICT (request_id, api_key_id) DO NOTHING + RETURNING id, created_at + ` groupID := nullInt64(log.GroupID) subscriptionID := nullInt64(log.SubscriptionID) @@ -136,6 +153,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) userAgent := nullString(log.UserAgent) ipAddress := nullString(log.IPAddress) imageSize := nullString(log.ImageSize) + mediaType := nullString(log.MediaType) reasoningEffort := nullString(log.ReasoningEffort) var requestIDArg any @@ -173,6 +191,7 @@ 
func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) ipAddress, log.ImageCount, imageSize, + mediaType, reasoningEffort, log.CacheTTLOverridden, createdAt, @@ -566,7 +585,7 @@ func (r *usageLogRepository) ListByAccount(ctx context.Context, accountID int64, } func (r *usageLogRepository) ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE user_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, userID, startTime, endTime) return logs, nil, err } @@ -812,19 +831,19 @@ func resolveUsageStatsTimezone() string { } func (r *usageLogRepository) ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE api_key_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, apiKeyID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err 
:= r.queryUsageLogs(ctx, query, accountID, startTime, endTime) return logs, nil, err } func (r *usageLogRepository) ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]service.UsageLog, *pagination.PaginationResult, error) { - query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC" + query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE model = $1 AND created_at >= $2 AND created_at < $3 ORDER BY id DESC LIMIT 10000" logs, err := r.queryUsageLogs(ctx, query, modelName, startTime, endTime) return logs, nil, err } @@ -896,6 +915,59 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI return stats, nil } +// GetAccountWindowStatsBatch 批量获取同一窗口起点下多个账号的统计数据。 +// 返回 map[accountID]*AccountStats,未命中的账号会返回零值统计,便于上层直接复用。 +func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + result := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COUNT(*) as requests, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, + COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(total_cost), 0) as standard_cost, + COALESCE(SUM(actual_cost), 0) as user_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + stats := &usagestats.AccountStats{} + if err := rows.Scan( + &accountID, + &stats.Requests, + &stats.Tokens, + &stats.Cost, + &stats.StandardCost, + &stats.UserCost, + ); err != nil 
{ + return nil, err + } + result[accountID] = stats + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &usagestats.AccountStats{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint @@ -910,10 +982,7 @@ type APIKeyUsageTrendPoint = usagestats.APIKeyUsageTrendPoint // GetAPIKeyUsageTrend returns usage trend data grouped by API key and date func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []APIKeyUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_keys AS ( @@ -968,10 +1037,7 @@ func (r *usageLogRepository) GetAPIKeyUsageTrend(ctx context.Context, startTime, // GetUserUsageTrend returns usage trend data grouped by user and date func (r *usageLogRepository) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) (results []UserUsageTrendPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` WITH top_users AS ( @@ -1230,10 +1296,7 @@ func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKey // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -1371,13 +1434,22 @@ type UsageStats = 
usagestats.UsageStats // BatchUserUsageStats represents usage stats for a single user type BatchUserUsageStats = usagestats.BatchUserUsageStats -// GetBatchUserUsageStats gets today and total actual_cost for multiple users -func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*BatchUserUsageStats, error) { +// GetBatchUserUsageStats gets today and total actual_cost for multiple users within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*BatchUserUsageStats, error) { result := make(map[int64]*BatchUserUsageStats) if len(userIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range userIDs { result[id] = &BatchUserUsageStats{UserID: id} } @@ -1385,10 +1457,10 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs query := ` SELECT user_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY user_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(userIDs), startTime, endTime) if err != nil { return nil, err } @@ -1445,13 +1517,22 @@ func (r *usageLogRepository) GetBatchUserUsageStats(ctx context.Context, userIDs // BatchAPIKeyUsageStats represents usage stats for a single API key type BatchAPIKeyUsageStats = usagestats.BatchAPIKeyUsageStats -// GetBatchAPIKeyUsageStats gets today and total actual_cost for multiple API keys -func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*BatchAPIKeyUsageStats, error) { +// GetBatchAPIKeyUsageStats gets today and total actual_cost for 
multiple API keys within a time range. +// If startTime is zero, defaults to 30 days ago. +func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*BatchAPIKeyUsageStats, error) { result := make(map[int64]*BatchAPIKeyUsageStats) if len(apiKeyIDs) == 0 { return result, nil } + // 默认最近 30 天 + if startTime.IsZero() { + startTime = time.Now().AddDate(0, 0, -30) + } + if endTime.IsZero() { + endTime = time.Now() + } + for _, id := range apiKeyIDs { result[id] = &BatchAPIKeyUsageStats{APIKeyID: id} } @@ -1459,10 +1540,10 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe query := ` SELECT api_key_id, COALESCE(SUM(actual_cost), 0) as total_cost FROM usage_logs - WHERE api_key_id = ANY($1) + WHERE api_key_id = ANY($1) AND created_at >= $2 AND created_at < $3 GROUP BY api_key_id ` - rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs)) + rows, err := r.sql.QueryContext(ctx, query, pq.Array(apiKeyIDs), startTime, endTime) if err != nil { return nil, err } @@ -1518,10 +1599,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe // GetUsageTrendWithFilters returns usage trend data with optional filters func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { - dateFormat := "YYYY-MM-DD" - if granularity == "hour" { - dateFormat = "YYYY-MM-DD HH24:00" - } + dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` SELECT @@ -2196,6 +2274,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e ipAddress sql.NullString imageCount int imageSize sql.NullString + mediaType sql.NullString reasoningEffort sql.NullString cacheTTLOverridden bool createdAt time.Time @@ -2232,6 +2311,7 @@ func 
scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &ipAddress, &imageCount, &imageSize, + &mediaType, &reasoningEffort, &cacheTTLOverridden, &createdAt, @@ -2294,6 +2374,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if imageSize.Valid { log.ImageSize = &imageSize.String } + if mediaType.Valid { + log.MediaType = &mediaType.String + } if reasoningEffort.Valid { log.ReasoningEffort = &reasoningEffort.String } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index eb220f22..8cb3aab1 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -648,7 +648,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { s.createUsageLog(user1, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user2, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{user1.ID, user2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchUserUsageStats") s.Require().Len(stats, 2) s.Require().NotNil(stats[user1.ID]) @@ -656,7 +656,7 @@ func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats() { } func (s *UsageLogRepoSuite) TestGetBatchUserUsageStats_Empty() { - stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchUserUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } @@ -672,13 +672,13 @@ func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats() { s.createUsageLog(user, apiKey1, account, 10, 20, 0.5, time.Now()) s.createUsageLog(user, apiKey2, account, 15, 25, 0.6, time.Now()) - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}) + stats, err := 
s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{apiKey1.ID, apiKey2.ID}, time.Time{}, time.Time{}) s.Require().NoError(err, "GetBatchAPIKeyUsageStats") s.Require().Len(stats, 2) } func (s *UsageLogRepoSuite) TestGetBatchApiKeyUsageStats_Empty() { - stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}) + stats, err := s.repo.GetBatchAPIKeyUsageStats(s.ctx, []int64{}, time.Time{}, time.Time{}) s.Require().NoError(err) s.Require().Empty(stats) } diff --git a/backend/internal/repository/usage_log_repo_unit_test.go b/backend/internal/repository/usage_log_repo_unit_test.go new file mode 100644 index 00000000..d0e14ffd --- /dev/null +++ b/backend/internal/repository/usage_log_repo_unit_test.go @@ -0,0 +1,41 @@ +//go:build unit + +package repository + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSafeDateFormat(t *testing.T) { + tests := []struct { + name string + granularity string + expected string + }{ + // 合法值 + {"hour", "hour", "YYYY-MM-DD HH24:00"}, + {"day", "day", "YYYY-MM-DD"}, + {"week", "week", "IYYY-IW"}, + {"month", "month", "YYYY-MM"}, + + // 非法值回退到默认 + {"空字符串", "", "YYYY-MM-DD"}, + {"未知粒度 year", "year", "YYYY-MM-DD"}, + {"未知粒度 minute", "minute", "YYYY-MM-DD"}, + + // 恶意字符串 + {"SQL 注入尝试", "'; DROP TABLE users; --", "YYYY-MM-DD"}, + {"带引号", "day'", "YYYY-MM-DD"}, + {"带括号", "day)", "YYYY-MM-DD"}, + {"Unicode", "日", "YYYY-MM-DD"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := safeDateFormat(tc.granularity) + require.Equal(t, tc.expected, got, "safeDateFormat(%q)", tc.granularity) + }) + } +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 3aed9d9c..0878c43d 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -28,7 +28,7 @@ func ProvideConcurrencyCache(rdb *redis.Client, cfg *config.Config) service.Conc // ProvideGitHubReleaseClient 创建 GitHub Release 客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub func 
ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient { - return NewGitHubReleaseClient(cfg.Update.ProxyURL) + return NewGitHubReleaseClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) } // ProvidePricingRemoteClient 创建定价数据远程客户端 @@ -53,12 +53,14 @@ var ProviderSet = wire.NewSet( NewAPIKeyRepository, NewGroupRepository, NewAccountRepository, + NewSoraAccountRepository, // Sora 账号扩展表仓储 NewProxyRepository, NewRedeemCodeRepository, NewPromoCodeRepository, NewAnnouncementRepository, NewAnnouncementReadRepository, NewUsageLogRepository, + NewIdempotencyRepository, NewUsageCleanupRepository, NewDashboardAggregationRepository, NewSettingRepository, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index a8040f82..76897bc1 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -83,6 +83,7 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, "expires_at": null, @@ -122,6 +123,7 @@ func TestAPIContracts(t *testing.T) { "status": "active", "ip_whitelist": null, "ip_blacklist": null, + "last_used_at": null, "quota": 0, "quota_used": 0, "expires_at": null, @@ -184,6 +186,10 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, @@ -401,6 +407,7 @@ func TestAPIContracts(t *testing.T) { "first_token_ms": 50, "image_count": 0, "image_size": null, + "media_type": null, "cache_ttl_overridden": false, "created_at": "2025-01-02T03:04:05Z", "user_agent": null @@ -593,13 +600,13 @@ func newContractDeps(t *testing.T) *contractDeps { 
RunMode: config.RunModeStandard, } - userService := service.NewUserService(userRepo, nil) + userService := service.NewUserService(userRepo, nil, nil) apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) - subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil, nil, cfg) subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) @@ -608,7 +615,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) @@ -925,6 +932,10 @@ func (s *stubAccountRepo) GetByCRSAccountID(ctx context.Context, crsAccountID st return nil, errors.New("not implemented") } +func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, value any) ([]service.Account, error) { + return nil, errors.New("not implemented") +} + func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } @@ -1462,6 +1473,20 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r 
*stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + key, ok := r.byID[id] + if !ok { + return service.ErrAPIKeyNotFound + } + ts := usedAt + key.LastUsedAt = &ts + key.UpdatedAt = usedAt + clone := *key + r.byID[id] = &clone + r.byKey[clone.Key] = &clone + return nil +} + type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1607,11 +1632,11 @@ func (r *stubUsageLogRepo) GetDailyStatsAggregated(ctx context.Context, userID i return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { +func (r *stubUsageLogRepo) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index d2d8ed40..a8034e98 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -51,6 +51,9 @@ func ProvideRouter( if err := r.SetTrustedProxies(nil); err != nil { log.Printf("Failed to disable trusted proxies: %v", err) } + if cfg.Server.Mode == "release" { + log.Printf("Warning: server.trusted_proxies is empty in release mode; client IP trust chain is disabled") + } } return SetupRouter(r, handlers, jwtAuth, adminAuth, apiKeyAuth, apiKeyService, subscriptionService, opsService, settingService, cfg, redisClient) diff --git a/backend/internal/server/middleware/admin_auth.go 
b/backend/internal/server/middleware/admin_auth.go index 8f30107c..6f294ff0 100644 --- a/backend/internal/server/middleware/admin_auth.go +++ b/backend/internal/server/middleware/admin_auth.go @@ -58,8 +58,13 @@ func adminAuth( authHeader := c.GetHeader("Authorization") if authHeader != "" { parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - if !validateJWTForAdmin(c, parts[1], authService, userService) { + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { + token := strings.TrimSpace(parts[1]) + if token == "" { + AbortWithError(c, 401, "UNAUTHORIZED", "Authorization required") + return + } + if !validateJWTForAdmin(c, token, authService, userService) { return } c.Next() @@ -176,6 +181,12 @@ func validateJWTForAdmin( return false } + // 校验 TokenVersion,确保管理员改密后旧 token 失效 + if claims.TokenVersion != user.TokenVersion { + AbortWithError(c, 401, "TOKEN_REVOKED", "Token has been revoked (password changed)") + return false + } + // 检查管理员权限 if !user.IsAdmin() { AbortWithError(c, 403, "FORBIDDEN", "Admin access required") diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go new file mode 100644 index 00000000..7b6d4ce8 --- /dev/null +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -0,0 +1,194 @@ +//go:build unit + +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} + authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil) + + admin := &service.User{ + ID: 1, + Email: 
"admin@example.com", + Role: service.RoleAdmin, + Status: service.StatusActive, + TokenVersion: 2, + Concurrency: 1, + } + + userRepo := &stubUserRepo{ + getByID: func(ctx context.Context, id int64) (*service.User, error) { + if id != admin.ID { + return nil, service.ErrUserNotFound + } + clone := *admin + return &clone, nil + }, + } + userService := service.NewUserService(userRepo, nil, nil) + + router := gin.New() + router.Use(gin.HandlerFunc(NewAdminAuthMiddleware(authService, userService, nil))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + t.Run("token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("websocket_token_version_mismatch_rejected", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion - 1, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + 
req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Contains(t, w.Body.String(), "TOKEN_REVOKED") + }) + + t.Run("websocket_token_version_match_allows", func(t *testing.T) { + token, err := authService.GenerateToken(&service.User{ + ID: admin.ID, + Email: admin.Email, + Role: admin.Role, + TokenVersion: admin.TokenVersion, + }) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Sec-WebSocket-Protocol", "sub2api-admin, jwt."+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) +} + +type stubUserRepo struct { + getByID func(ctx context.Context, id int64) (*service.User, error) +} + +func (s *stubUserRepo) Create(ctx context.Context, user *service.User) error { + panic("unexpected Create call") +} + +func (s *stubUserRepo) GetByID(ctx context.Context, id int64) (*service.User, error) { + if s.getByID == nil { + panic("GetByID not stubbed") + } + return s.getByID(ctx, id) +} + +func (s *stubUserRepo) GetByEmail(ctx context.Context, email string) (*service.User, error) { + panic("unexpected GetByEmail call") +} + +func (s *stubUserRepo) GetFirstAdmin(ctx context.Context) (*service.User, error) { + panic("unexpected GetFirstAdmin call") +} + +func (s *stubUserRepo) Update(ctx context.Context, user *service.User) error { + panic("unexpected Update call") +} + +func (s *stubUserRepo) Delete(ctx context.Context, id int64) error { + panic("unexpected Delete call") +} + +func (s *stubUserRepo) List(ctx context.Context, params pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected List call") +} + +func (s *stubUserRepo) ListWithFilters(ctx context.Context, params 
pagination.PaginationParams, filters service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) { + panic("unexpected ListWithFilters call") +} + +func (s *stubUserRepo) UpdateBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected UpdateBalance call") +} + +func (s *stubUserRepo) DeductBalance(ctx context.Context, id int64, amount float64) error { + panic("unexpected DeductBalance call") +} + +func (s *stubUserRepo) UpdateConcurrency(ctx context.Context, id int64, amount int) error { + panic("unexpected UpdateConcurrency call") +} + +func (s *stubUserRepo) ExistsByEmail(ctx context.Context, email string) (bool, error) { + panic("unexpected ExistsByEmail call") +} + +func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) { + panic("unexpected RemoveGroupFromAllowedGroups call") +} + +func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error { + panic("unexpected UpdateTotpSecret call") +} + +func (s *stubUserRepo) EnableTotp(ctx context.Context, userID int64) error { + panic("unexpected EnableTotp call") +} + +func (s *stubUserRepo) DisableTotp(ctx context.Context, userID int64) error { + panic("unexpected DisableTotp call") +} diff --git a/backend/internal/server/middleware/api_key_auth.go b/backend/internal/server/middleware/api_key_auth.go index 2f739357..8fa3517a 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -3,7 +3,6 @@ package middleware import ( "context" "errors" - "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -36,8 +35,8 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti if authHeader != "" { // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) == 2 && parts[0] == "Bearer" { - apiKeyString = parts[1] + if len(parts) == 2 && strings.EqualFold(parts[0], "Bearer") { 
+ apiKeyString = strings.TrimSpace(parts[1]) } } @@ -97,7 +96,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 检查 IP 限制(白名单/黑名单) // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { - clientIP := ip.GetClientIP(c) + clientIP := ip.GetTrustedClientIP(c) allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") @@ -126,6 +125,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } @@ -134,7 +134,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti isSubscriptionType := apiKey.Group != nil && apiKey.Group.IsSubscriptionType() if isSubscriptionType && subscriptionService != nil { - // 订阅模式:验证订阅 + // 订阅模式:获取订阅(L1 缓存 + singleflight) subscription, err := subscriptionService.GetActiveSubscription( c.Request.Context(), apiKey.User.ID, @@ -145,30 +145,30 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti return } - // 验证订阅状态(是否过期、暂停等) - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - AbortWithError(c, 403, "SUBSCRIPTION_INVALID", err.Error()) - return - } - - // 激活滑动窗口(首次使用时) - if err := subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to activate subscription windows: %v", err) - } - - // 检查并重置过期窗口 - if err := subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription); err != nil { - log.Printf("Failed to reset subscription windows: %v", err) - } - - // 预检查用量限制(使用0作为额外费用进行预检查) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - 
AbortWithError(c, 429, "USAGE_LIMIT_EXCEEDED", err.Error()) + // 合并验证 + 限额检查(纯内存操作) + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + code := "SUBSCRIPTION_INVALID" + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + code = "USAGE_LIMIT_EXCEEDED" + status = 429 + } + AbortWithError(c, status, code, err.Error()) return } // 将订阅信息存入上下文 c.Set(string(ContextKeySubscription), subscription) + + // 窗口维护异步化(不阻塞请求) + // 传递独立拷贝,避免与 handler 读取 context 中的 subscription 产生 data race + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { // 余额模式:检查用户余额 if apiKey.User.Balance <= 0 { @@ -185,6 +185,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 38fbe38b..9da1b1c6 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -64,6 +64,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() return } @@ -104,6 +105,7 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs }) c.Set(string(ContextKeyUserRole), apiKey.User.Role) setGroupContext(c, apiKey.Group) + _ = apiKeyService.TouchLastUsed(c.Request.Context(), apiKey.ID) c.Next() } } diff --git 
a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index 38b93cb2..e4e0e253 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -7,6 +7,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -18,7 +19,8 @@ import ( ) type fakeAPIKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { @@ -78,6 +80,12 @@ func (f fakeAPIKeyRepo) ListKeysByGroupID(ctx context.Context, groupID int64) ([ func (f fakeAPIKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) { return 0, errors.New("not implemented") } +func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if f.updateLastUsed != nil { + return f.updateLastUsed(ctx, id, usedAt) + } + return nil +} type googleErrorResponse struct { Error struct { @@ -356,3 +364,144 @@ func TestApiKeyAuthWithSubscriptionGoogle_InsufficientBalance(t *testing.T) { require.Equal(t, "Insufficient account balance", resp.Error.Message) require.Equal(t, "PERMISSION_DENIED", resp.Error.Status) } + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 11, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 201, + UserID: user.ID, + Key: "google-touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + r := 
gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero()) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 12, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 202, + UserID: user.ID, + Key: "google-touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("write failed") + }, + }) + cfg := &config.Config{RunMode: config.RunModeSimple} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + 
req.Header.Set("x-goog-api-key", apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} + +func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 13, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 203, + UserID: user.ID, + Key: "google-touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + r := gin.New() + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + }) + cfg := &config.Config{RunMode: config.RunModeStandard} + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, nil, cfg)) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("Authorization", "Bearer "+apiKey.Key) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, 1, touchCalls) +} diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index 9d514818..0d331761 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -57,10 +57,41 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { }, } - t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { - cfg := &config.Config{RunMode: config.RunModeSimple} + t.Run("standard_mode_needs_maintenance_does_not_block_request", 
func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeStandard} + cfg.SubscriptionMaintenance.WorkerCount = 1 + cfg.SubscriptionMaintenance.QueueSize = 1 + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) - subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) + + past := time.Now().Add(-48 * time.Hour) + sub := &service.UserSubscription{ + ID: 55, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: time.Now().Add(24 * time.Hour), + DailyWindowStart: &past, + DailyUsageUSD: 0, + } + maintenanceCalled := make(chan struct{}, 1) + subscriptionRepo := &stubUserSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { + maintenanceCalled <- struct{}{} + return nil + }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + } + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) + t.Cleanup(subscriptionService.Stop) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -68,6 +99,40 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { req.Header.Set("x-api-key", apiKey.Key) router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) + select { + case <-maintenanceCalled: + // ok + case <-time.After(time.Second): + t.Fatalf("expected maintenance to be scheduled") + } + }) + + t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { + cfg := 
&config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("simple_mode_accepts_lowercase_bearer", func(t *testing.T) { + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("Authorization", "bearer "+apiKey.Key) + router.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) }) @@ -99,7 +164,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, } - subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil) + subscriptionService := service.NewSubscriptionService(nil, subscriptionRepo, nil, nil, cfg) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) w := httptest.NewRecorder() @@ -235,6 +300,198 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { require.Equal(t, http.StatusOK, w.Code) } +func TestAPIKeyAuthIPRestrictionDoesNotTrustSpoofedForwardHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + 
Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "test-key", + Status: service.StatusActive, + User: user, + IPWhitelist: []string{"1.2.3.4"}, + } + + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := gin.New() + require.NoError(t, router.SetTrustedProxies(nil)) + router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) + router.GET("/t", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{"ok": true}) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.RemoteAddr = "9.9.9.9:12345" + req.Header.Set("x-api-key", apiKey.Key) + req.Header.Set("X-Forwarded-For", "1.2.3.4") + req.Header.Set("X-Real-IP", "1.2.3.4") + req.Header.Set("CF-Connecting-IP", "1.2.3.4") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusForbidden, w.Code) + require.Contains(t, w.Body.String(), "ACCESS_DENIED") +} + +func TestAPIKeyAuthTouchesLastUsedOnSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 7, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 100, + UserID: user.ID, + Key: "touch-ok", + Status: service.StatusActive, + User: user, + } + + var touchedID int64 + var touchedAt time.Time + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchedID = id + touchedAt = usedAt + return nil + }, 
+ } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, apiKey.ID, touchedID) + require.False(t, touchedAt.IsZero(), "expected touch timestamp") +} + +func TestAPIKeyAuthTouchLastUsedFailureDoesNotBlock(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 8, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 101, + UserID: user.ID, + Key: "touch-fail", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return errors.New("db unavailable") + }, + } + + cfg := &config.Config{RunMode: config.RunModeSimple} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code, "touch failure should not block request") + require.Equal(t, 1, touchCalls) +} + +func TestAPIKeyAuthTouchesLastUsedInStandardMode(t *testing.T) { + gin.SetMode(gin.TestMode) + + user := &service.User{ + ID: 9, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 
102, + UserID: user.ID, + Key: "touch-standard", + Status: service.StatusActive, + User: user, + } + + touchCalls := 0 + apiKeyRepo := &stubApiKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + touchCalls++ + return nil + }, + } + + cfg := &config.Config{RunMode: config.RunModeStandard} + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) + router := newAuthTestRouter(apiKeyService, nil, cfg) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set("x-api-key", apiKey.Key) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 1, touchCalls) +} + func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService *service.SubscriptionService, cfg *config.Config) *gin.Engine { router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, cfg))) @@ -245,7 +502,8 @@ func newAuthTestRouter(apiKeyService *service.APIKeyService, subscriptionService } type stubApiKeyRepo struct { - getByKey func(ctx context.Context, key string) (*service.APIKey, error) + getByKey func(ctx context.Context, key string) (*service.APIKey, error) + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } func (r *stubApiKeyRepo) Create(ctx context.Context, key *service.APIKey) error { @@ -323,6 +581,13 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + if r.updateLastUsed != nil { + return r.updateLastUsed(ctx, id, usedAt) + } + return nil +} + type stubUserSubscriptionRepo struct { getActive func(ctx 
context.Context, userID, groupID int64) (*service.UserSubscription, error) updateStatus func(ctx context.Context, subscriptionID int64, status string) error diff --git a/backend/internal/server/middleware/client_request_id.go b/backend/internal/server/middleware/client_request_id.go index d22b6cc5..6838d6af 100644 --- a/backend/internal/server/middleware/client_request_id.go +++ b/backend/internal/server/middleware/client_request_id.go @@ -2,10 +2,13 @@ package middleware import ( "context" + "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "go.uber.org/zap" ) // ClientRequestID ensures every request has a unique client_request_id in request.Context(). @@ -24,7 +27,10 @@ func ClientRequestID() gin.HandlerFunc { } id := uuid.New().String() - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id)) + ctx := context.WithValue(c.Request.Context(), ctxkey.ClientRequestID, id) + requestLogger := logger.FromContext(ctx).With(zap.String("client_request_id", strings.TrimSpace(id))) + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) c.Next() } } diff --git a/backend/internal/server/middleware/cors.go b/backend/internal/server/middleware/cors.go index f1dd51af..03d5d025 100644 --- a/backend/internal/server/middleware/cors.go +++ b/backend/internal/server/middleware/cors.go @@ -50,6 +50,19 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { } allowedSet[origin] = struct{}{} } + allowHeaders := []string{ + "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", + "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key", + } + // OpenAI Node SDK 会发送 x-stainless-* 请求头,需在 CORS 中显式放行。 + openAIProperties := []string{ + "lang", "package-version", "os", "arch", "retry-count", "runtime", + "runtime-version", "async", "helper-method", 
"poll-helper", "custom-poll-interval", "timeout", + } + for _, prop := range openAIProperties { + allowHeaders = append(allowHeaders, "x-stainless-"+prop) + } + allowHeadersValue := strings.Join(allowHeaders, ", ") return func(c *gin.Context) { origin := strings.TrimSpace(c.GetHeader("Origin")) @@ -68,19 +81,11 @@ func CORS(cfg config.CORSConfig) gin.HandlerFunc { if allowCredentials { c.Writer.Header().Set("Access-Control-Allow-Credentials", "true") } + c.Writer.Header().Set("Access-Control-Allow-Headers", allowHeadersValue) + c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") + c.Writer.Header().Set("Access-Control-Expose-Headers", "ETag") + c.Writer.Header().Set("Access-Control-Max-Age", "86400") } - - allowHeaders := []string{"Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization", "accept", "origin", "Cache-Control", "X-Requested-With", "X-API-Key"} - - // openai node sdk - openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval", "timeout"} - for _, prop := range openAIProperties { - allowHeaders = append(allowHeaders, "x-stainless-"+prop) - } - - c.Writer.Header().Set("Access-Control-Allow-Headers", strings.Join(allowHeaders, ", ")) - c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE, PATCH") - // 处理预检请求 if c.Request.Method == http.MethodOptions { if originAllowed { diff --git a/backend/internal/server/middleware/cors_test.go b/backend/internal/server/middleware/cors_test.go new file mode 100644 index 00000000..6d0bea36 --- /dev/null +++ b/backend/internal/server/middleware/cors_test.go @@ -0,0 +1,308 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" +) + +func init() { + // 
cors_test 与 security_headers_test 在同一个包,但 init 是幂等的 + gin.SetMode(gin.TestMode) +} + +// --- Task 8.2: 验证 CORS 条件化头部 --- + +func TestCORS_DisallowedOrigin_NoAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + origin string + }{ + { + name: "preflight_disallowed_origin", + method: http.MethodOptions, + origin: "https://evil.example.com", + }, + { + name: "get_disallowed_origin", + method: http.MethodGet, + origin: "https://evil.example.com", + }, + { + name: "post_disallowed_origin", + method: http.MethodPost, + origin: "https://attacker.example.com", + }, + { + name: "preflight_no_origin", + method: http.MethodOptions, + origin: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + if tt.origin != "" { + c.Request.Header.Set("Origin", tt.origin) + } + + middleware(c) + + // 不应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers"), + "不允许的 origin 不应收到 Allow-Headers") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Methods"), + "不允许的 origin 不应收到 Allow-Methods") + assert.Empty(t, w.Header().Get("Access-Control-Max-Age"), + "不允许的 origin 不应收到 Max-Age") + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"), + "不允许的 origin 不应收到 Allow-Origin") + }) + } +} + +func TestCORS_AllowedOrigin_HasAllowHeaders(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + tests := []struct { + name string + method string + }{ + {name: "preflight_OPTIONS", method: http.MethodOptions}, + {name: "normal_GET", method: http.MethodGet}, + {name: "normal_POST", method: http.MethodPost}, + } + + for _, 
tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(tt.method, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + // 应设置 Allow-Headers、Allow-Methods 和 Max-Age + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "允许的 origin 应收到 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "允许的 origin 应收到 Allow-Methods") + assert.Equal(t, "86400", w.Header().Get("Access-Control-Max-Age"), + "允许的 origin 应收到 Max-Age=86400") + assert.Equal(t, "https://allowed.example.com", w.Header().Get("Access-Control-Allow-Origin"), + "允许的 origin 应收到 Allow-Origin") + }) + } +} + +func TestCORS_PreflightDisallowedOrigin_ReturnsForbidden(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Equal(t, http.StatusForbidden, w.Code, + "不允许的 origin 的 preflight 请求应返回 403") +} + +func TestCORS_PreflightAllowedOrigin_ReturnsNoContent(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodOptions, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, http.StatusNoContent, w.Code, + "允许的 origin 的 preflight 请求应返回 204") +} + +func TestCORS_WildcardOrigin_AllowsAny(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any-origin.example.com") + + middleware(c) + + assert.Equal(t, "*", w.Header().Get("Access-Control-Allow-Origin"), + "通配符配置应返回 *") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"), + "通配符 origin 应设置 Allow-Headers") + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Methods"), + "通配符 origin 应设置 Allow-Methods") +} + +func TestCORS_AllowCredentials_SetCorrectly(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + t.Run("allowed_origin_gets_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + middleware(c) + + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"), + "允许的 origin 且开启 credentials 应设置 Allow-Credentials") + }) + + t.Run("disallowed_origin_no_credentials", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Credentials"), + "不允许的 origin 不应收到 Allow-Credentials") + }) +} + +func TestCORS_WildcardWithCredentials_DisablesCredentials(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://any.example.com") + + middleware(c) + + // 通配符 + credentials 不兼容,credentials 应被禁用 + assert.Empty(t, 
w.Header().Get("Access-Control-Allow-Credentials"), + "通配符 origin 应禁用 Allow-Credentials") +} + +func TestCORS_MultipleAllowedOrigins(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{ + "https://app1.example.com", + "https://app2.example.com", + }, + AllowCredentials: false, + } + middleware := CORS(cfg) + + t.Run("first_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app1.example.com") + + middleware(c) + + assert.Equal(t, "https://app1.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("second_origin_allowed", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app2.example.com") + + middleware(c) + + assert.Equal(t, "https://app2.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("unlisted_origin_rejected", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://app3.example.com") + + middleware(c) + + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) +} + +func TestCORS_VaryHeader_SetForSpecificOrigin(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://allowed.example.com"}, + AllowCredentials: false, + } + middleware := CORS(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.Request.Header.Set("Origin", "https://allowed.example.com") + + 
middleware(c) + + assert.Contains(t, w.Header().Values("Vary"), "Origin", + "非通配符允许的 origin 应设置 Vary: Origin") +} + +func TestNormalizeOrigins(t *testing.T) { + tests := []struct { + name string + input []string + expect []string + }{ + {name: "nil_input", input: nil, expect: nil}, + {name: "empty_input", input: []string{}, expect: nil}, + {name: "trims_whitespace", input: []string{" https://a.com ", " https://b.com"}, expect: []string{"https://a.com", "https://b.com"}}, + {name: "removes_empty_strings", input: []string{"", " ", "https://a.com"}, expect: []string{"https://a.com"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := normalizeOrigins(tt.input) + assert.Equal(t, tt.expect, result) + }) + } +} diff --git a/backend/internal/server/middleware/jwt_auth.go b/backend/internal/server/middleware/jwt_auth.go index 9a89aab7..4aceb355 100644 --- a/backend/internal/server/middleware/jwt_auth.go +++ b/backend/internal/server/middleware/jwt_auth.go @@ -26,12 +26,12 @@ func jwtAuth(authService *service.AuthService, userService *service.UserService) // 验证Bearer scheme parts := strings.SplitN(authHeader, " ", 2) - if len(parts) != 2 || parts[0] != "Bearer" { + if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") { AbortWithError(c, 401, "INVALID_AUTH_HEADER", "Authorization header format must be 'Bearer {token}'") return } - tokenString := parts[1] + tokenString := strings.TrimSpace(parts[1]) if tokenString == "" { AbortWithError(c, 401, "EMPTY_TOKEN", "Token cannot be empty") return diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go new file mode 100644 index 00000000..bc320958 --- /dev/null +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -0,0 +1,256 @@ +//go:build unit + +package middleware + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
"github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// stubJWTUserRepo 实现 UserRepository 的最小子集,仅支持 GetByID。 +type stubJWTUserRepo struct { + service.UserRepository + users map[int64]*service.User +} + +func (r *stubJWTUserRepo) GetByID(_ context.Context, id int64) (*service.User, error) { + u, ok := r.users[id] + if !ok { + return nil, errors.New("user not found") + } + return u, nil +} + +// newJWTTestEnv 创建 JWT 认证中间件测试环境。 +// 返回 gin.Engine(已注册 JWT 中间件)和 AuthService(用于生成 Token)。 +func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthService) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.JWT.Secret = "test-jwt-secret-32bytes-long!!!" + cfg.JWT.AccessTokenExpireMinutes = 60 + + userRepo := &stubJWTUserRepo{users: users} + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) + userSvc := service.NewUserService(userRepo, nil, nil) + mw := NewJWTAuthMiddleware(authSvc, userSvc) + + r := gin.New() + r.Use(gin.HandlerFunc(mw)) + r.GET("/protected", func(c *gin.Context) { + subject, _ := GetAuthSubjectFromContext(c) + role, _ := GetUserRoleFromContext(c) + c.JSON(http.StatusOK, gin.H{ + "user_id": subject.UserID, + "role": role, + }) + }) + return r, authSvc +} + +func TestJWTAuth_ValidToken(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) + + var body map[string]any + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, float64(1), 
body["user_id"]) + require.Equal(t, "user", body["role"]) +} + +func TestJWTAuth_ValidToken_LowercaseBearer(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + Concurrency: 5, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusOK, w.Code) +} + +func TestJWTAuth_MissingAuthorizationHeader(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "UNAUTHORIZED", body.Code) +} + +func TestJWTAuth_InvalidHeaderFormat(t *testing.T) { + tests := []struct { + name string + header string + }{ + {"无Bearer前缀", "Token abc123"}, + {"缺少空格分隔", "Bearerabc123"}, + {"仅有单词", "abc123"}, + } + router, _ := newJWTTestEnv(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", tt.header) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_AUTH_HEADER", body.Code) + }) + } +} + +func TestJWTAuth_EmptyToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer ") + router.ServeHTTP(w, req) + + require.Equal(t, 
http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "EMPTY_TOKEN", body.Code) +} + +func TestJWTAuth_TamperedToken(t *testing.T) { + router, _ := newJWTTestEnv(nil) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.invalid_signature") + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "INVALID_TOKEN", body.Code) +} + +func TestJWTAuth_UserNotFound(t *testing.T) { + // 使用 user ID=1 的 token,但 repo 中没有该用户 + fakeUser := &service.User{ + ID: 999, + Email: "ghost@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + // 创建环境时不注入此用户,这样 GetByID 会失败 + router, authSvc := newJWTTestEnv(map[int64]*service.User{}) + + token, err := authSvc.GenerateToken(fakeUser) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_NOT_FOUND", body.Code) +} + +func TestJWTAuth_UserInactive(t *testing.T) { + user := &service.User{ + ID: 1, + Email: "disabled@example.com", + Role: "user", + Status: service.StatusDisabled, + TokenVersion: 1, + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: user}) + + token, err := authSvc.GenerateToken(user) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body 
ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "USER_INACTIVE", body.Code) +} + +func TestJWTAuth_TokenVersionMismatch(t *testing.T) { + // Token 生成时 TokenVersion=1,但数据库中用户已更新为 TokenVersion=2(密码修改) + userForToken := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 1, + } + userInDB := &service.User{ + ID: 1, + Email: "test@example.com", + Role: "user", + Status: service.StatusActive, + TokenVersion: 2, // 密码修改后版本递增 + } + router, authSvc := newJWTTestEnv(map[int64]*service.User{1: userInDB}) + + token, err := authSvc.GenerateToken(userForToken) + require.NoError(t, err) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/protected", nil) + req.Header.Set("Authorization", "Bearer "+token) + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusUnauthorized, w.Code) + var body ErrorResponse + require.NoError(t, json.Unmarshal(w.Body.Bytes(), &body)) + require.Equal(t, "TOKEN_REVOKED", body.Code) +} diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index 842efda9..b14a3a21 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -1,10 +1,12 @@ package middleware import ( - "log" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" + "go.uber.org/zap" ) // Logger 请求日志中间件 @@ -13,44 +15,52 @@ func Logger() gin.HandlerFunc { // 开始时间 startTime := time.Now() - // 处理请求 - c.Next() - - // 结束时间 - endTime := time.Now() - - // 执行时间 - latency := endTime.Sub(startTime) - - // 请求方法 - method := c.Request.Method - // 请求路径 path := c.Request.URL.Path - // 状态码 + // 处理请求 + c.Next() + + // 跳过健康检查等高频探针路径的日志 + if path == "/health" || path == "/setup/status" { + return + } + + endTime := time.Now() + latency := endTime.Sub(startTime) + + method := 
c.Request.Method statusCode := c.Writer.Status() - - // 客户端IP clientIP := c.ClientIP() - - // 协议版本 protocol := c.Request.Proto + accountID, hasAccountID := c.Request.Context().Value(ctxkey.AccountID).(int64) + platform, _ := c.Request.Context().Value(ctxkey.Platform).(string) + model, _ := c.Request.Context().Value(ctxkey.Model).(string) - // 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径 - log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s", - endTime.Format("2006/01/02 - 15:04:05"), - statusCode, - latency, - clientIP, - protocol, - method, - path, - ) + fields := []zap.Field{ + zap.String("component", "http.access"), + zap.Int("status_code", statusCode), + zap.Int64("latency_ms", latency.Milliseconds()), + zap.String("client_ip", clientIP), + zap.String("protocol", protocol), + zap.String("method", method), + zap.String("path", path), + } + if hasAccountID && accountID > 0 { + fields = append(fields, zap.Int64("account_id", accountID)) + } + if platform != "" { + fields = append(fields, zap.String("platform", platform)) + } + if model != "" { + fields = append(fields, zap.String("model", model)) + } + + l := logger.FromContext(c.Request.Context()).With(fields...) 
+ l.Info("http request completed", zap.Time("completed_at", endTime)) - // 如果有错误,额外记录错误信息 if len(c.Errors) > 0 { - log.Printf("[GIN] Errors: %v", c.Errors.String()) + l.Warn("http request contains gin errors", zap.String("errors", c.Errors.String())) } } } diff --git a/backend/internal/server/middleware/misc_coverage_test.go b/backend/internal/server/middleware/misc_coverage_test.go new file mode 100644 index 00000000..c0adfc4d --- /dev/null +++ b/backend/internal/server/middleware/misc_coverage_test.go @@ -0,0 +1,126 @@ +//go:build unit + +package middleware + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestClientRequestID_GeneratesWhenMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + v := c.Request.Context().Value(ctxkey.ClientRequestID) + require.NotNil(t, v) + id, ok := v.(string) + require.True(t, ok) + require.NotEmpty(t, id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestClientRequestID_PreservesExisting(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ClientRequestID()) + r.GET("/t", func(c *gin.Context) { + id, ok := c.Request.Context().Value(ctxkey.ClientRequestID).(string) + require.True(t, ok) + require.Equal(t, "keep", id) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req = req.WithContext(context.WithValue(req.Context(), ctxkey.ClientRequestID, "keep")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestRequestBodyLimit_LimitsBody(t *testing.T) { + gin.SetMode(gin.TestMode) 
+ + r := gin.New() + r.Use(RequestBodyLimit(4)) + r.POST("/t", func(c *gin.Context) { + _, err := io.ReadAll(c.Request.Body) + require.Error(t, err) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/t", bytes.NewBufferString("12345")) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestForcePlatform_SetsContextAndGinValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(ForcePlatform("anthropic")) + r.GET("/t", func(c *gin.Context) { + require.True(t, HasForcePlatform(c)) + v, ok := GetForcePlatformFromContext(c) + require.True(t, ok) + require.Equal(t, "anthropic", v) + + ctxV := c.Request.Context().Value(ctxkey.ForcePlatform) + require.Equal(t, "anthropic", ctxV) + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + require.Equal(t, http.StatusOK, w.Code) +} + +func TestAuthSubjectHelpers_RoundTrip(t *testing.T) { + c := &gin.Context{} + c.Set(string(ContextKeyUser), AuthSubject{UserID: 1, Concurrency: 2}) + c.Set(string(ContextKeyUserRole), "admin") + + sub, ok := GetAuthSubjectFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), sub.UserID) + require.Equal(t, 2, sub.Concurrency) + + role, ok := GetUserRoleFromContext(c) + require.True(t, ok) + require.Equal(t, "admin", role) +} + +func TestAPIKeyAndSubscriptionFromContext(t *testing.T) { + c := &gin.Context{} + + key := &service.APIKey{ID: 1} + c.Set(string(ContextKeyAPIKey), key) + gotKey, ok := GetAPIKeyFromContext(c) + require.True(t, ok) + require.Equal(t, int64(1), gotKey.ID) + + sub := &service.UserSubscription{ID: 2} + c.Set(string(ContextKeySubscription), sub) + gotSub, ok := GetSubscriptionFromContext(c) + require.True(t, ok) + require.Equal(t, int64(2), gotSub.ID) +} diff --git a/backend/internal/server/middleware/recovery_test.go b/backend/internal/server/middleware/recovery_test.go index 
439f44cb..33e71d51 100644 --- a/backend/internal/server/middleware/recovery_test.go +++ b/backend/internal/server/middleware/recovery_test.go @@ -3,6 +3,7 @@ package middleware import ( + "bytes" "encoding/json" "net/http" "net/http/httptest" @@ -14,6 +15,34 @@ import ( "github.com/stretchr/testify/require" ) +func TestRecovery_PanicLogContainsInfo(t *testing.T) { + gin.SetMode(gin.TestMode) + + // 临时替换 DefaultErrorWriter 以捕获日志输出 + var buf bytes.Buffer + originalWriter := gin.DefaultErrorWriter + gin.DefaultErrorWriter = &buf + t.Cleanup(func() { + gin.DefaultErrorWriter = originalWriter + }) + + r := gin.New() + r.Use(Recovery()) + r.GET("/panic", func(c *gin.Context) { + panic("custom panic message for test") + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/panic", nil) + r.ServeHTTP(w, req) + + require.Equal(t, http.StatusInternalServerError, w.Code) + + logOutput := buf.String() + require.Contains(t, logOutput, "custom panic message for test", "日志应包含 panic 信息") + require.Contains(t, logOutput, "recovery_test.go", "日志应包含堆栈跟踪文件名") +} + func TestRecovery(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/backend/internal/server/middleware/request_access_logger_test.go b/backend/internal/server/middleware/request_access_logger_test.go new file mode 100644 index 00000000..fec3ed22 --- /dev/null +++ b/backend/internal/server/middleware/request_access_logger_test.go @@ -0,0 +1,228 @@ +package middleware + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +type testLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *testLogSink) WriteLogEvent(event *logger.LogEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.events = append(s.events, event) +} + +func (s *testLogSink) list() []*logger.LogEvent { + s.mu.Lock() + defer s.mu.Unlock() 
+ out := make([]*logger.LogEvent, len(s.events)) + copy(out, s.events) + return out +} + +func initMiddlewareTestLogger(t *testing.T) *testLogSink { + return initMiddlewareTestLoggerWithLevel(t, "debug") +} + +func initMiddlewareTestLoggerWithLevel(t *testing.T, level string) *testLogSink { + t.Helper() + level = strings.TrimSpace(level) + if level == "" { + level = "debug" + } + if err := logger.Init(logger.InitOptions{ + Level: level, + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: false, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + sink := &testLogSink{} + logger.SetSink(sink) + t.Cleanup(func() { + logger.SetSink(nil) + }) + return sink +} + +func TestRequestLogger_GenerateAndPropagateRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, ok := c.Request.Context().Value(ctxkey.RequestID).(string) + if !ok || reqID == "" { + t.Fatalf("request_id missing in context") + } + if got := c.Writer.Header().Get(requestIDHeader); got != reqID { + t.Fatalf("response header request_id mismatch, header=%q ctx=%q", got, reqID) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if w.Header().Get(requestIDHeader) == "" { + t.Fatalf("X-Request-ID should be set") + } +} + +func TestRequestLogger_KeepIncomingRequestID(t *testing.T) { + gin.SetMode(gin.TestMode) + r := gin.New() + r.Use(RequestLogger()) + r.GET("/t", func(c *gin.Context) { + reqID, _ := c.Request.Context().Value(ctxkey.RequestID).(string) + if reqID != "rid-fixed" { + t.Fatalf("request_id=%q, want rid-fixed", reqID) + } + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/t", nil) + req.Header.Set(requestIDHeader, 
"rid-fixed") + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if got := w.Header().Get(requestIDHeader); got != "rid-fixed" { + t.Fatalf("header=%q, want rid-fixed", got) + } +} + +func TestLogger_AccessLogIncludesCoreFields(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.Use(func(c *gin.Context) { + ctx := c.Request.Context() + ctx = context.WithValue(ctx, ctxkey.AccountID, int64(101)) + ctx = context.WithValue(ctx, ctxkey.Platform, "openai") + ctx = context.WithValue(ctx, ctxkey.Model, "gpt-5") + c.Request = c.Request.WithContext(ctx) + c.Next() + }) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + if len(events) == 0 { + t.Fatalf("expected at least one log event") + } + found := false + for _, event := range events { + if event == nil || event.Message != "http request completed" { + continue + } + found = true + switch v := event.Fields["status_code"].(type) { + case int: + if v != http.StatusCreated { + t.Fatalf("status_code field mismatch: %v", v) + } + case int64: + if v != int64(http.StatusCreated) { + t.Fatalf("status_code field mismatch: %v", v) + } + default: + t.Fatalf("status_code type mismatch: %T", v) + } + switch v := event.Fields["account_id"].(type) { + case int64: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + case int: + if v != 101 { + t.Fatalf("account_id field mismatch: %v", v) + } + default: + t.Fatalf("account_id type mismatch: %T", v) + } + if event.Fields["platform"] != "openai" || event.Fields["model"] != "gpt-5" { + t.Fatalf("platform/model mismatch: %+v", event.Fields) + } + } + if !found { + t.Fatalf("access log event not found") + } +} + +func 
TestLogger_HealthPathSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLogger(t) + + r := gin.New() + r.Use(Logger()) + r.GET("/health", func(c *gin.Context) { + c.Status(http.StatusOK) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/health", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d", w.Code) + } + if len(sink.list()) != 0 { + t.Fatalf("health endpoint should not write access log") + } +} + +func TestLogger_AccessLogDroppedWhenLevelWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + sink := initMiddlewareTestLoggerWithLevel(t, "warn") + + r := gin.New() + r.Use(RequestLogger()) + r.Use(Logger()) + r.GET("/api/test", func(c *gin.Context) { + c.Status(http.StatusCreated) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/test", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusCreated { + t.Fatalf("status=%d", w.Code) + } + + events := sink.list() + for _, event := range events { + if event != nil && event.Message == "http request completed" { + t.Fatalf("access log should not be indexed when level=warn: %+v", event) + } + } +} diff --git a/backend/internal/server/middleware/request_logger.go b/backend/internal/server/middleware/request_logger.go new file mode 100644 index 00000000..0fb2feca --- /dev/null +++ b/backend/internal/server/middleware/request_logger.go @@ -0,0 +1,45 @@ +package middleware + +import ( + "context" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const requestIDHeader = "X-Request-ID" + +// RequestLogger 在请求入口注入 request-scoped logger。 +func RequestLogger() gin.HandlerFunc { + return func(c *gin.Context) { + if c.Request == nil { + c.Next() + return + } + + requestID := strings.TrimSpace(c.GetHeader(requestIDHeader)) + if requestID == "" { + requestID = 
uuid.NewString() + } + c.Header(requestIDHeader, requestID) + + ctx := context.WithValue(c.Request.Context(), ctxkey.RequestID, requestID) + clientRequestID, _ := ctx.Value(ctxkey.ClientRequestID).(string) + + requestLogger := logger.With( + zap.String("component", "http"), + zap.String("request_id", requestID), + zap.String("client_request_id", strings.TrimSpace(clientRequestID)), + zap.String("path", c.Request.URL.Path), + zap.String("method", c.Request.Method), + ) + + ctx = logger.IntoContext(ctx, requestLogger) + c.Request = c.Request.WithContext(ctx) + c.Next() + } +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 9ce7f449..67b19c09 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -3,6 +3,8 @@ package middleware import ( "crypto/rand" "encoding/base64" + "fmt" + "log" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -18,11 +20,14 @@ const ( CloudflareInsightsDomain = "https://static.cloudflareinsights.com" ) -// GenerateNonce generates a cryptographically secure random nonce -func GenerateNonce() string { +// GenerateNonce generates a cryptographically secure random nonce. 
+// 返回 error 以确保调用方在 crypto/rand 失败时能正确降级。 +func GenerateNonce() (string, error) { b := make([]byte, 16) - _, _ = rand.Read(b) - return base64.StdEncoding.EncodeToString(b) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate CSP nonce: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil } // GetNonceFromContext retrieves the CSP nonce from gin context @@ -52,12 +57,17 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { if cfg.Enabled { // Generate nonce for this request - nonce := GenerateNonce() - c.Set(CSPNonceKey, nonce) - - // Replace nonce placeholder in policy - finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") - c.Header("Content-Security-Policy", finalPolicy) + nonce, err := GenerateNonce() + if err != nil { + // crypto/rand 失败时降级为无 nonce 的 CSP 策略 + log.Printf("[SecurityHeaders] %v — 降级为无 nonce 的 CSP", err) + finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'unsafe-inline'") + c.Header("Content-Security-Policy", finalPolicy) + } else { + c.Set(CSPNonceKey, nonce) + finalPolicy := strings.ReplaceAll(policy, NonceTemplate, "'nonce-"+nonce+"'") + c.Header("Content-Security-Policy", finalPolicy) + } } c.Next() } diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index dc7a87d8..43462b82 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -19,7 +19,8 @@ func init() { func TestGenerateNonce(t *testing.T) { t.Run("generates_valid_base64_string", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // Should be valid base64 decoded, err := base64.StdEncoding.DecodeString(nonce) @@ -32,14 +33,16 @@ func TestGenerateNonce(t *testing.T) { t.Run("generates_unique_nonces", func(t *testing.T) { nonces := make(map[string]bool) for i := 0; i < 100; i++ { - nonce := 
GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) assert.False(t, nonces[nonce], "nonce should be unique") nonces[nonce] = true } }) t.Run("nonce_has_expected_length", func(t *testing.T) { - nonce := GenerateNonce() + nonce, err := GenerateNonce() + require.NoError(t, err) // 16 bytes -> 24 chars in base64 (with padding) assert.Len(t, nonce, 24) }) @@ -344,7 +347,7 @@ func TestAddToDirective(t *testing.T) { // Benchmark tests func BenchmarkGenerateNonce(b *testing.B) { for i := 0; i < b.N; i++ { - GenerateNonce() + _, _ = GenerateNonce() } } diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index cf9015e4..fb91bc0e 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -29,6 +29,7 @@ func SetupRouter( redisClient *redis.Client, ) *gin.Engine { // 应用中间件 + r.Use(middleware2.RequestLogger()) r.Use(middleware2.Logger()) r.Use(middleware2.CORS(cfg.CORS)) r.Use(middleware2.SecurityHeaders(cfg.Security.CSP)) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4509b4bc..4b4d97c3 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -34,6 +34,8 @@ func RegisterAdminRoutes( // OpenAI OAuth registerOpenAIOAuthRoutes(admin, h) + // Sora OAuth(实现复用 OpenAI OAuth 服务,入口独立) + registerSoraOAuthRoutes(admin, h) // Gemini OAuth registerGeminiOAuthRoutes(admin, h) @@ -101,6 +103,9 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { { runtime.GET("/alert", h.Admin.Ops.GetAlertRuntimeSettings) runtime.PUT("/alert", h.Admin.Ops.UpdateAlertRuntimeSettings) + runtime.GET("/logging", h.Admin.Ops.GetRuntimeLogConfig) + runtime.PUT("/logging", h.Admin.Ops.UpdateRuntimeLogConfig) + runtime.POST("/logging/reset", h.Admin.Ops.ResetRuntimeLogConfig) } // Advanced settings (DB-backed) @@ -144,12 +149,18 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // Request 
drilldown (success + error) ops.GET("/requests", h.Admin.Ops.ListRequestDetails) + // Indexed system logs + ops.GET("/system-logs", h.Admin.Ops.ListSystemLogs) + ops.POST("/system-logs/cleanup", h.Admin.Ops.CleanupSystemLogs) + ops.GET("/system-logs/health", h.Admin.Ops.GetSystemLogIngestionHealth) + // Dashboard (vNext - raw path for MVP) ops.GET("/dashboard/overview", h.Admin.Ops.GetDashboardOverview) ops.GET("/dashboard/throughput-trend", h.Admin.Ops.GetDashboardThroughputTrend) ops.GET("/dashboard/latency-histogram", h.Admin.Ops.GetDashboardLatencyHistogram) ops.GET("/dashboard/error-trend", h.Admin.Ops.GetDashboardErrorTrend) ops.GET("/dashboard/error-distribution", h.Admin.Ops.GetDashboardErrorDistribution) + ops.GET("/dashboard/openai-token-stats", h.Admin.Ops.GetDashboardOpenAITokenStats) } } @@ -267,6 +278,19 @@ func registerOpenAIOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerSoraOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + sora := admin.Group("/sora") + { + sora.POST("/generate-auth-url", h.Admin.OpenAIOAuth.GenerateAuthURL) + sora.POST("/exchange-code", h.Admin.OpenAIOAuth.ExchangeCode) + sora.POST("/refresh-token", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/st2at", h.Admin.OpenAIOAuth.ExchangeSoraSessionToken) + sora.POST("/rt2at", h.Admin.OpenAIOAuth.RefreshToken) + sora.POST("/accounts/:id/refresh", h.Admin.OpenAIOAuth.RefreshAccountToken) + sora.POST("/create-from-oauth", h.Admin.OpenAIOAuth.CreateAccountFromOAuth) + } +} + func registerGeminiOAuthRoutes(admin *gin.RouterGroup, h *handler.Handlers) { gemini := admin.Group("/gemini") { @@ -297,6 +321,7 @@ func registerProxyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { proxies.PUT("/:id", h.Admin.Proxy.Update) proxies.DELETE("/:id", h.Admin.Proxy.Delete) proxies.POST("/:id/test", h.Admin.Proxy.Test) + proxies.POST("/:id/quality-check", h.Admin.Proxy.CheckQuality) proxies.GET("/:id/stats", h.Admin.Proxy.GetStats) proxies.GET("/:id/accounts", 
h.Admin.Proxy.GetProxyAccounts) proxies.POST("/batch-delete", h.Admin.Proxy.BatchDelete) diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 26d79605..c168820c 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -24,10 +24,19 @@ func RegisterAuthRoutes( // 公开接口 auth := v1.Group("/auth") { - auth.POST("/register", h.Auth.Register) - auth.POST("/login", h.Auth.Login) - auth.POST("/login/2fa", h.Auth.Login2FA) - auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // 注册/登录/2FA/验证码发送均属于高风险入口,增加服务端兜底限流(Redis 故障时 fail-close) + auth.POST("/register", rateLimiter.LimitWithOptions("auth-register", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Register) + auth.POST("/login", rateLimiter.LimitWithOptions("auth-login", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login) + auth.POST("/login/2fa", rateLimiter.LimitWithOptions("auth-login-2fa", 20, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.Login2FA) + auth.POST("/send-verify-code", rateLimiter.LimitWithOptions("auth-send-verify-code", 5, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.SendVerifyCode) // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, diff --git a/backend/internal/server/routes/auth_rate_limit_integration_test.go b/backend/internal/server/routes/auth_rate_limit_integration_test.go new file mode 100644 index 00000000..8a0ef860 --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_integration_test.go @@ -0,0 +1,111 @@ +//go:build integration + +package routes + +import ( + "context" + "fmt" + "net/http" + "net/http/httptest" 
+ "os" + "path/filepath" + "strconv" + "strings" + "testing" + + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" + tcredis "github.com/testcontainers/testcontainers-go/modules/redis" +) + +const authRouteRedisImageTag = "redis:8.4-alpine" + +func TestAuthRegisterRateLimitThresholdHitReturns429(t *testing.T) { + ctx := context.Background() + rdb := startAuthRouteRedis(t, ctx) + + router := newAuthRoutesTestRouter(rdb) + const path = "/api/v1/auth/register" + + for i := 1; i <= 6; i++ { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "198.51.100.10:23456" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if i <= 5 { + require.Equal(t, http.StatusBadRequest, w.Code, "第 %d 次请求应先进入业务校验", i) + continue + } + require.Equal(t, http.StatusTooManyRequests, w.Code, "第 6 次请求应命中限流") + require.Contains(t, w.Body.String(), "rate limit exceeded") + } +} + +func startAuthRouteRedis(t *testing.T, ctx context.Context) *redis.Client { + t.Helper() + ensureAuthRouteDockerAvailable(t) + + redisContainer, err := tcredis.Run(ctx, authRouteRedisImageTag) + require.NoError(t, err) + t.Cleanup(func() { + _ = redisContainer.Terminate(ctx) + }) + + redisHost, err := redisContainer.Host(ctx) + require.NoError(t, err) + redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp") + require.NoError(t, err) + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()), + DB: 0, + }) + require.NoError(t, rdb.Ping(ctx).Err()) + t.Cleanup(func() { + _ = rdb.Close() + }) + return rdb +} + +func ensureAuthRouteDockerAvailable(t *testing.T) { + t.Helper() + if authRouteDockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过认证限流集成测试") +} + +func authRouteDockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + 
filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(authRouteUserHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func authRouteUserHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git a/backend/internal/server/routes/auth_rate_limit_test.go b/backend/internal/server/routes/auth_rate_limit_test.go new file mode 100644 index 00000000..5ce8497c --- /dev/null +++ b/backend/internal/server/routes/auth_rate_limit_test.go @@ -0,0 +1,67 @@ +package routes + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/handler" + servermiddleware "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/gin-gonic/gin" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/require" +) + +func newAuthRoutesTestRouter(redisClient *redis.Client) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + v1 := router.Group("/api/v1") + + RegisterAuthRoutes( + v1, + &handler.Handlers{ + Auth: &handler.AuthHandler{}, + Setting: &handler.SettingHandler{}, + }, + servermiddleware.JWTAuthMiddleware(func(c *gin.Context) { + c.Next() + }), + redisClient, + ) + + return router +} + +func TestAuthRoutesRateLimitFailCloseWhenRedisUnavailable(t *testing.T) { + rdb := redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:1", + DialTimeout: 50 * time.Millisecond, + ReadTimeout: 50 * time.Millisecond, + WriteTimeout: 50 * time.Millisecond, + }) + t.Cleanup(func() { + _ = rdb.Close() + }) + + router := newAuthRoutesTestRouter(rdb) + paths := []string{ + "/api/v1/auth/register", + "/api/v1/auth/login", + 
"/api/v1/auth/login/2fa", + "/api/v1/auth/send-verify-code", + } + + for _, path := range paths { + req := httptest.NewRequest(http.MethodPost, path, strings.NewReader(`{}`)) + req.Header.Set("Content-Type", "application/json") + req.RemoteAddr = "203.0.113.10:12345" + + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + require.Equal(t, http.StatusTooManyRequests, w.Code, "path=%s", path) + require.Contains(t, w.Body.String(), "rate limit exceeded", "path=%s", path) + } +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index bf019ce3..930c8b9e 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,6 +1,8 @@ package routes import ( + "net/http" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -20,6 +22,11 @@ func RegisterGatewayRoutes( cfg *config.Config, ) { bodyLimit := middleware.RequestBodyLimit(cfg.Gateway.MaxBodySize) + soraMaxBodySize := cfg.Gateway.SoraMaxBodySize + if soraMaxBodySize <= 0 { + soraMaxBodySize = cfg.Gateway.MaxBodySize + } + soraBodyLimit := middleware.RequestBodyLimit(soraMaxBodySize) clientRequestID := middleware.ClientRequestID() opsErrorLogger := handler.OpsErrorLoggerMiddleware(opsService) @@ -36,6 +43,15 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 + gateway.POST("/chat/completions", func(c *gin.Context) { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "Unsupported legacy protocol: /v1/chat/completions is not supported. 
Please use /v1/responses.", + }, + }) + }) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -82,4 +98,25 @@ func RegisterGatewayRoutes( antigravityV1Beta.GET("/models/:model", h.Gateway.GeminiV1BetaGetModel) antigravityV1Beta.POST("/models/*modelAction", h.Gateway.GeminiV1BetaModels) } + + // Sora 专用路由(强制使用 sora 平台) + soraV1 := r.Group("/sora/v1") + soraV1.Use(soraBodyLimit) + soraV1.Use(clientRequestID) + soraV1.Use(opsErrorLogger) + soraV1.Use(middleware.ForcePlatform(service.PlatformSora)) + soraV1.Use(gin.HandlerFunc(apiKeyAuth)) + { + soraV1.POST("/chat/completions", h.SoraGateway.ChatCompletions) + soraV1.GET("/models", h.Gateway.Models) + } + + // Sora 媒体代理(可选 API Key 验证) + if cfg.Gateway.SoraMediaRequireAPIKey { + r.GET("/sora/media/*filepath", gin.HandlerFunc(apiKeyAuth), h.SoraGateway.MediaProxy) + } else { + r.GET("/sora/media/*filepath", h.SoraGateway.MediaProxy) + } + // Sora 媒体代理(签名 URL,无需 API Key) + r.GET("/sora/media-signed/*filepath", h.SoraGateway.MediaProxySigned) } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index fa3ce738..51ab84dd 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -696,6 +696,51 @@ func (a *Account) IsMixedSchedulingEnabled() bool { return false } +// IsOpenAIPassthroughEnabled 返回 OpenAI 账号是否启用“自动透传(仅替换认证)”。 +// +// 新字段:accounts.extra.openai_passthrough。 +// 兼容字段:accounts.extra.openai_oauth_passthrough(历史 OAuth 开关)。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsOpenAIPassthroughEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if enabled, ok := a.Extra["openai_passthrough"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_oauth_passthrough"].(bool); ok { + return enabled + } + return false +} + +// IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 +func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { + return a != nil && a.IsOpenAIOAuth() && 
a.IsOpenAIPassthroughEnabled() +} + +// IsAnthropicAPIKeyPassthroughEnabled 返回 Anthropic API Key 账号是否启用“自动透传(仅替换认证)”。 +// 字段:accounts.extra.anthropic_passthrough。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsAnthropicAPIKeyPassthroughEnabled() bool { + if a == nil || a.Platform != PlatformAnthropic || a.Type != AccountTypeAPIKey || a.Extra == nil { + return false + } + enabled, ok := a.Extra["anthropic_passthrough"].(bool) + return ok && enabled +} + +// IsCodexCLIOnlyEnabled 返回 OpenAI OAuth 账号是否启用“仅允许 Codex 官方客户端”。 +// 字段:accounts.extra.codex_cli_only。 +// 字段缺失或类型不正确时,按 false(关闭)处理。 +func (a *Account) IsCodexCLIOnlyEnabled() bool { + if a == nil || !a.IsOpenAIOAuth() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["codex_cli_only"].(bool) + return ok && enabled +} + // WindowCostSchedulability 窗口费用调度状态 type WindowCostSchedulability int diff --git a/backend/internal/service/account_anthropic_passthrough_test.go b/backend/internal/service/account_anthropic_passthrough_test.go new file mode 100644 index 00000000..e66407a3 --- /dev/null +++ b/backend/internal/service/account_anthropic_passthrough_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) { + t.Run("Anthropic API Key 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.True(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("Anthropic API Key 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": false, + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("字段类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: 
map[string]any{ + "anthropic_passthrough": "true", + }, + } + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + }) + + t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) { + oauth := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled()) + + openai := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled()) + }) +} diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go new file mode 100644 index 00000000..59f8cd8c --- /dev/null +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -0,0 +1,136 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) { + t.Run("新字段开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("兼容旧字段", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_passthrough": true, + }, + } + require.True(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("非OpenAI账号始终关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) + + t.Run("空额外配置默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + require.False(t, account.IsOpenAIPassthroughEnabled()) + }) +} + +func 
TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) { + t.Run("仅OAuth类型允许返回开启", func(t *testing.T) { + oauthAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled()) + + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_passthrough": true, + }, + } + require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled()) + }) +} + +func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { + t.Run("OpenAI OAuth 开启", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.True(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("OpenAI OAuth 关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": false, + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("字段缺失默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("类型非法默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": "true", + }, + } + require.False(t, account.IsCodexCLIOnlyEnabled()) + }) + + t.Run("非 OAuth 账号始终关闭", func(t *testing.T) { + apiKeyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled()) + + otherPlatform := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "codex_cli_only": true, + }, + } + require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) + }) 
+} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index f192fba4..b301049f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -25,6 +25,9 @@ type AccountRepository interface { // GetByCRSAccountID finds an account previously synced from CRS. // Returns (nil, nil) if not found. GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) + // FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora') + // 用于查找通过 linked_openai_account_id 关联的 Sora 账号 + FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index a420d46b..a466b68a 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -54,6 +54,10 @@ func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID st panic("unexpected GetByCRSAccountID call") } +func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + panic("unexpected FindByExtraField call") +} + func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { panic("unexpected ListCRSAccountIDs call") } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 899a4498..a507efb4 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -12,13 +12,17 @@ import ( "io" "log" "net/http" + "net/url" "regexp" "strings" + "sync" + "time" "github.com/Wei-Shaw/sub2api/internal/config" 
"github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -31,6 +35,11 @@ var sseDataPrefix = regexp.MustCompile(`^data:\s*`) const ( testClaudeAPIURL = "https://api.anthropic.com/v1/messages" chatgptCodexAPIURL = "https://chatgpt.com/backend-api/codex/responses" + soraMeAPIURL = "https://sora.chatgpt.com/backend/me" // Sora 用户信息接口,用于测试连接 + soraBillingAPIURL = "https://sora.chatgpt.com/backend/billing/subscriptions" + soraInviteMineURL = "https://sora.chatgpt.com/backend/project_y/invite/mine" + soraBootstrapURL = "https://sora.chatgpt.com/backend/m/bootstrap" + soraRemainingURL = "https://sora.chatgpt.com/backend/nf/check" ) // TestEvent represents a SSE event for account testing @@ -38,6 +47,9 @@ type TestEvent struct { Type string `json:"type"` Text string `json:"text,omitempty"` Model string `json:"model,omitempty"` + Status string `json:"status,omitempty"` + Code string `json:"code,omitempty"` + Data any `json:"data,omitempty"` Success bool `json:"success,omitempty"` Error string `json:"error,omitempty"` } @@ -49,8 +61,13 @@ type AccountTestService struct { antigravityGatewayService *AntigravityGatewayService httpUpstream HTTPUpstream cfg *config.Config + soraTestGuardMu sync.Mutex + soraTestLastRun map[int64]time.Time + soraTestCooldown time.Duration } +const defaultSoraTestCooldown = 10 * time.Second + // NewAccountTestService creates a new AccountTestService func NewAccountTestService( accountRepo AccountRepository, @@ -65,6 +82,8 @@ func NewAccountTestService( antigravityGatewayService: antigravityGatewayService, httpUpstream: httpUpstream, cfg: cfg, + soraTestLastRun: make(map[int64]time.Time), + soraTestCooldown: defaultSoraTestCooldown, } } @@ -163,6 +182,10 @@ func (s *AccountTestService) 
TestAccountConnection(c *gin.Context, accountID int return s.testAntigravityAccountConnection(c, account, modelID) } + if account.Platform == PlatformSora { + return s.testSoraAccountConnection(c, account) + } + return s.testClaudeAccountConnection(c, account, modelID) } @@ -462,6 +485,604 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account return s.processGeminiStream(c, resp.Body) } +type soraProbeStep struct { + Name string `json:"name"` + Status string `json:"status"` + HTTPStatus int `json:"http_status,omitempty"` + ErrorCode string `json:"error_code,omitempty"` + Message string `json:"message,omitempty"` +} + +type soraProbeSummary struct { + Status string `json:"status"` + Steps []soraProbeStep `json:"steps"` +} + +type soraProbeRecorder struct { + steps []soraProbeStep +} + +func (r *soraProbeRecorder) addStep(name, status string, httpStatus int, errorCode, message string) { + r.steps = append(r.steps, soraProbeStep{ + Name: name, + Status: status, + HTTPStatus: httpStatus, + ErrorCode: strings.TrimSpace(errorCode), + Message: strings.TrimSpace(message), + }) +} + +func (r *soraProbeRecorder) finalize() soraProbeSummary { + meSuccess := false + partial := false + for _, step := range r.steps { + if step.Name == "me" { + meSuccess = strings.EqualFold(step.Status, "success") + continue + } + if strings.EqualFold(step.Status, "failed") { + partial = true + } + } + + status := "success" + if !meSuccess { + status = "failed" + } else if partial { + status = "partial_success" + } + + return soraProbeSummary{ + Status: status, + Steps: append([]soraProbeStep(nil), r.steps...), + } +} + +func (s *AccountTestService) emitSoraProbeSummary(c *gin.Context, rec *soraProbeRecorder) { + if rec == nil { + return + } + summary := rec.finalize() + code := "" + for _, step := range summary.Steps { + if strings.EqualFold(step.Status, "failed") && strings.TrimSpace(step.ErrorCode) != "" { + code = step.ErrorCode + break + } + } + s.sendEvent(c, 
TestEvent{ + Type: "sora_test_result", + Status: summary.Status, + Code: code, + Data: summary, + }) +} + +func (s *AccountTestService) acquireSoraTestPermit(accountID int64) (time.Duration, bool) { + if accountID <= 0 { + return 0, true + } + s.soraTestGuardMu.Lock() + defer s.soraTestGuardMu.Unlock() + + if s.soraTestLastRun == nil { + s.soraTestLastRun = make(map[int64]time.Time) + } + cooldown := s.soraTestCooldown + if cooldown <= 0 { + cooldown = defaultSoraTestCooldown + } + + now := time.Now() + if lastRun, ok := s.soraTestLastRun[accountID]; ok { + elapsed := now.Sub(lastRun) + if elapsed < cooldown { + return cooldown - elapsed, false + } + } + s.soraTestLastRun[accountID] = now + return 0, true +} + +func ceilSeconds(d time.Duration) int { + if d <= 0 { + return 1 + } + sec := int(d / time.Second) + if d%time.Second != 0 { + sec++ + } + if sec < 1 { + sec = 1 + } + return sec +} + +// testSoraAccountConnection 测试 Sora 账号的连接 +// 调用 /backend/me 接口验证 access_token 有效性(不需要 Sentinel Token) +func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + recorder := &soraProbeRecorder{} + + authToken := account.GetCredential("access_token") + if authToken == "" { + recorder.addStep("me", "failed", http.StatusUnauthorized, "missing_access_token", "No access token available") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "No access token available") + } + + // Set SSE headers + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + recorder.addStep("rate_limit", "failed", http.StatusTooManyRequests, "test_rate_limited", msg) + s.emitSoraProbeSummary(c, recorder) + return 
s.sendErrorAndEnd(c, msg) + } + + // Send test_start event + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora"}) + + req, err := http.NewRequestWithContext(ctx, "GET", soraMeAPIURL, nil) + if err != nil { + recorder.addStep("me", "failed", 0, "request_build_failed", err.Error()) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Failed to create request") + } + + // 使用 Sora 客户端标准请求头 + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + // Get proxy URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + enableSoraTLSFingerprint := s.shouldEnableSoraTLSFingerprint() + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if err != nil { + recorder.addStep("me", "failed", 0, "network_error", err.Error()) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(resp.Body) + + if resp.StatusCode != http.StatusOK { + if isCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + recorder.addStep("me", "failed", resp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.emitSoraProbeSummary(c, recorder) + s.logSoraCloudflareChallenge(account, proxyURL, soraMeAPIURL, resp.Header, body) + return s.sendErrorAndEnd(c, formatCloudflareChallengeMessage(fmt.Sprintf("Sora request blocked by Cloudflare challenge (HTTP %d). 
Please switch to a clean proxy/network and retry.", resp.StatusCode), resp.Header, body)) + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(body) + switch { + case resp.StatusCode == http.StatusUnauthorized && strings.EqualFold(upstreamCode, "token_invalidated"): + recorder.addStep("me", "failed", resp.StatusCode, "token_invalidated", "Sora token invalidated") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora token 已失效(token_invalidated),请重新授权账号") + case strings.EqualFold(upstreamCode, "unsupported_country_code"): + recorder.addStep("me", "failed", resp.StatusCode, "unsupported_country_code", "Sora is unavailable in current egress region") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, "Sora 在当前网络出口地区不可用(unsupported_country_code),请切换到支持地区后重试") + case strings.TrimSpace(upstreamMessage) != "": + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, upstreamMessage) + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, upstreamMessage)) + default: + recorder.addStep("me", "failed", resp.StatusCode, upstreamCode, "Sora me endpoint failed") + s.emitSoraProbeSummary(c, recorder) + return s.sendErrorAndEnd(c, fmt.Sprintf("Sora API returned %d: %s", resp.StatusCode, truncateSoraErrorBody(body, 512))) + } + } + recorder.addStep("me", "success", resp.StatusCode, "", "me endpoint ok") + + // 解析 /me 响应,提取用户信息 + var meResp map[string]any + if err := json.Unmarshal(body, &meResp); err != nil { + // 能收到 200 就说明 token 有效 + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora connection OK (token valid)"}) + } else { + // 尝试提取用户名或邮箱信息 + info := "Sora connection OK" + if name, ok := meResp["name"].(string); ok && name != "" { + info = fmt.Sprintf("Sora connection OK - User: %s", name) + } else if email, ok := meResp["email"].(string); ok && email != "" { + info = fmt.Sprintf("Sora connection OK - Email: %s", email) + } + 
s.sendEvent(c, TestEvent{Type: "content", Text: info}) + } + + // 追加轻量能力检查:订阅信息查询(失败仅告警,不中断连接测试) + subReq, err := http.NewRequestWithContext(ctx, "GET", soraBillingAPIURL, nil) + if err == nil { + subReq.Header.Set("Authorization", "Bearer "+authToken) + subReq.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + subReq.Header.Set("Accept", "application/json") + subReq.Header.Set("Accept-Language", "en-US,en;q=0.9") + subReq.Header.Set("Origin", "https://sora.chatgpt.com") + subReq.Header.Set("Referer", "https://sora.chatgpt.com/") + + subResp, subErr := s.httpUpstream.DoWithTLS(subReq, proxyURL, account.ID, account.Concurrency, enableSoraTLSFingerprint) + if subErr != nil { + recorder.addStep("subscription", "failed", 0, "network_error", subErr.Error()) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check skipped: %s", subErr.Error())}) + } else { + subBody, _ := io.ReadAll(subResp.Body) + _ = subResp.Body.Close() + if subResp.StatusCode == http.StatusOK { + recorder.addStep("subscription", "success", subResp.StatusCode, "", "subscription endpoint ok") + if summary := parseSoraSubscriptionSummary(subBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Subscription check OK"}) + } + } else { + if isCloudflareChallengeResponse(subResp.StatusCode, subResp.Header, subBody) { + recorder.addStep("subscription", "failed", subResp.StatusCode, "cf_challenge", "Cloudflare challenge detected") + s.logSoraCloudflareChallenge(account, proxyURL, soraBillingAPIURL, subResp.Header, subBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Subscription check blocked by Cloudflare challenge (HTTP %d)", subResp.StatusCode), subResp.Header, subBody)}) + } else { + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(subBody) + recorder.addStep("subscription", 
"failed", subResp.StatusCode, upstreamCode, upstreamMessage) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Subscription check returned %d", subResp.StatusCode)}) + } + } + } + } + + // 追加 Sora2 能力探测(对齐 sora2api 的测试思路):邀请码 + 剩余额度。 + s.testSora2Capabilities(c, ctx, account, authToken, proxyURL, enableSoraTLSFingerprint, recorder) + + s.emitSoraProbeSummary(c, recorder) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil +} + +func (s *AccountTestService) testSora2Capabilities( + c *gin.Context, + ctx context.Context, + account *Account, + authToken string, + proxyURL string, + enableTLSFingerprint bool, + recorder *soraProbeRecorder, +) { + inviteStatus, inviteHeader, inviteBody, err := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check skipped: %s", err.Error())}) + return + } + + if inviteStatus == http.StatusUnauthorized { + bootstrapStatus, _, _, bootstrapErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraBootstrapURL, + proxyURL, + enableTLSFingerprint, + ) + if bootstrapErr == nil && bootstrapStatus == http.StatusOK { + if recorder != nil { + recorder.addStep("sora2_bootstrap", "success", bootstrapStatus, "", "bootstrap endpoint ok") + } + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 bootstrap OK, retry invite check"}) + inviteStatus, inviteHeader, inviteBody, err = s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraInviteMineURL, + proxyURL, + enableTLSFingerprint, + ) + if err != nil { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", 0, "network_error", err.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite retry failed: %s", err.Error())}) + return + } + 
} else if recorder != nil { + code := "" + msg := "" + if bootstrapErr != nil { + code = "network_error" + msg = bootstrapErr.Error() + } + recorder.addStep("sora2_bootstrap", "failed", bootstrapStatus, code, msg) + } + } + + if inviteStatus != http.StatusOK { + if isCloudflareChallengeResponse(inviteStatus, inviteHeader, inviteBody) { + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraInviteMineURL, inviteHeader, inviteBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 invite check blocked by Cloudflare challenge (HTTP %d)", inviteStatus), inviteHeader, inviteBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(inviteBody) + if recorder != nil { + recorder.addStep("sora2_invite", "failed", inviteStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 invite check returned %d", inviteStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_invite", "success", inviteStatus, "", "invite endpoint ok") + } + + if summary := parseSoraInviteSummary(inviteBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 invite check OK"}) + } + + remainingStatus, remainingHeader, remainingBody, remainingErr := s.fetchSoraTestEndpoint( + ctx, + account, + authToken, + soraRemainingURL, + proxyURL, + enableTLSFingerprint, + ) + if remainingErr != nil { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", 0, "network_error", remainingErr.Error()) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check skipped: %s", remainingErr.Error())}) + return + } + if remainingStatus != http.StatusOK { + if 
isCloudflareChallengeResponse(remainingStatus, remainingHeader, remainingBody) { + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, "cf_challenge", "Cloudflare challenge detected") + } + s.logSoraCloudflareChallenge(account, proxyURL, soraRemainingURL, remainingHeader, remainingBody) + s.sendEvent(c, TestEvent{Type: "content", Text: formatCloudflareChallengeMessage(fmt.Sprintf("Sora2 remaining check blocked by Cloudflare challenge (HTTP %d)", remainingStatus), remainingHeader, remainingBody)}) + return + } + upstreamCode, upstreamMessage := soraerror.ExtractUpstreamErrorCodeAndMessage(remainingBody) + if recorder != nil { + recorder.addStep("sora2_remaining", "failed", remainingStatus, upstreamCode, upstreamMessage) + } + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("Sora2 remaining check returned %d", remainingStatus)}) + return + } + if recorder != nil { + recorder.addStep("sora2_remaining", "success", remainingStatus, "", "remaining endpoint ok") + } + if summary := parseSoraRemainingSummary(remainingBody); summary != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: summary}) + } else { + s.sendEvent(c, TestEvent{Type: "content", Text: "Sora2 remaining check OK"}) + } +} + +func (s *AccountTestService) fetchSoraTestEndpoint( + ctx context.Context, + account *Account, + authToken string, + url string, + proxyURL string, + enableTLSFingerprint bool, +) (int, http.Header, []byte, error) { + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return 0, nil, nil, err + } + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + req.Header.Set("Accept", "application/json") + req.Header.Set("Accept-Language", "en-US,en;q=0.9") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + + resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, 
account.ID, account.Concurrency, enableTLSFingerprint) + if err != nil { + return 0, nil, nil, err + } + defer func() { _ = resp.Body.Close() }() + + body, readErr := io.ReadAll(resp.Body) + if readErr != nil { + return resp.StatusCode, resp.Header, nil, readErr + } + return resp.StatusCode, resp.Header, body, nil +} + +func parseSoraSubscriptionSummary(body []byte) string { + var subResp struct { + Data []struct { + Plan struct { + ID string `json:"id"` + Title string `json:"title"` + } `json:"plan"` + EndTS string `json:"end_ts"` + } `json:"data"` + } + if err := json.Unmarshal(body, &subResp); err != nil { + return "" + } + if len(subResp.Data) == 0 { + return "" + } + + first := subResp.Data[0] + parts := make([]string, 0, 3) + if first.Plan.Title != "" { + parts = append(parts, first.Plan.Title) + } + if first.Plan.ID != "" { + parts = append(parts, first.Plan.ID) + } + if first.EndTS != "" { + parts = append(parts, "end="+first.EndTS) + } + if len(parts) == 0 { + return "" + } + return "Subscription: " + strings.Join(parts, " | ") +} + +func parseSoraInviteSummary(body []byte) string { + var inviteResp struct { + InviteCode string `json:"invite_code"` + RedeemedCount int64 `json:"redeemed_count"` + TotalCount int64 `json:"total_count"` + } + if err := json.Unmarshal(body, &inviteResp); err != nil { + return "" + } + + parts := []string{"Sora2: supported"} + if inviteResp.InviteCode != "" { + parts = append(parts, "invite="+inviteResp.InviteCode) + } + if inviteResp.TotalCount > 0 { + parts = append(parts, fmt.Sprintf("used=%d/%d", inviteResp.RedeemedCount, inviteResp.TotalCount)) + } + return strings.Join(parts, " | ") +} + +func parseSoraRemainingSummary(body []byte) string { + var remainingResp struct { + RateLimitAndCreditBalance struct { + EstimatedNumVideosRemaining int64 `json:"estimated_num_videos_remaining"` + RateLimitReached bool `json:"rate_limit_reached"` + AccessResetsInSeconds int64 `json:"access_resets_in_seconds"` + } 
`json:"rate_limit_and_credit_balance"` + } + if err := json.Unmarshal(body, &remainingResp); err != nil { + return "" + } + info := remainingResp.RateLimitAndCreditBalance + parts := []string{fmt.Sprintf("Sora2 remaining: %d", info.EstimatedNumVideosRemaining)} + if info.RateLimitReached { + parts = append(parts, "rate_limited=true") + } + if info.AccessResetsInSeconds > 0 { + parts = append(parts, fmt.Sprintf("reset_in=%ds", info.AccessResetsInSeconds)) + } + return strings.Join(parts, " | ") +} + +func (s *AccountTestService) shouldEnableSoraTLSFingerprint() bool { + if s == nil || s.cfg == nil { + return true + } + return !s.cfg.Sora.Client.DisableTLSFingerprint +} + +func isCloudflareChallengeResponse(statusCode int, headers http.Header, body []byte) bool { + return soraerror.IsCloudflareChallengeResponse(statusCode, headers, body) +} + +func formatCloudflareChallengeMessage(base string, headers http.Header, body []byte) string { + return soraerror.FormatCloudflareChallengeMessage(base, headers, body) +} + +func extractCloudflareRayID(headers http.Header, body []byte) string { + return soraerror.ExtractCloudflareRayID(headers, body) +} + +func extractSoraEgressIPHint(headers http.Header) string { + if headers == nil { + return "unknown" + } + candidates := []string{ + "x-openai-public-ip", + "x-envoy-external-address", + "cf-connecting-ip", + "x-forwarded-for", + } + for _, key := range candidates { + if value := strings.TrimSpace(headers.Get(key)); value != "" { + return value + } + } + return "unknown" +} + +func sanitizeProxyURLForLog(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + u, err := url.Parse(raw) + if err != nil { + return "" + } + if u.User != nil { + u.User = nil + } + return u.String() +} + +func endpointPathForLog(endpoint string) string { + parsed, err := url.Parse(strings.TrimSpace(endpoint)) + if err != nil || parsed.Path == "" { + return endpoint + } + return parsed.Path +} + +func (s 
*AccountTestService) logSoraCloudflareChallenge(account *Account, proxyURL, endpoint string, headers http.Header, body []byte) { + accountID := int64(0) + platform := "" + proxyID := "none" + if account != nil { + accountID = account.ID + platform = account.Platform + if account.ProxyID != nil { + proxyID = fmt.Sprintf("%d", *account.ProxyID) + } + } + cfRay := extractCloudflareRayID(headers, body) + if cfRay == "" { + cfRay = "unknown" + } + log.Printf( + "[SoraCFChallenge] account_id=%d platform=%s endpoint=%s path=%s proxy_id=%s proxy_url=%s cf_ray=%s egress_ip_hint=%s", + accountID, + platform, + endpoint, + endpointPathForLog(endpoint), + proxyID, + sanitizeProxyURLForLog(proxyURL), + cfRay, + extractSoraEgressIPHint(headers), + ) +} + +func truncateSoraErrorBody(body []byte, max int) string { + return soraerror.TruncateBody(body, max) +} + // testAntigravityAccountConnection tests an Antigravity account's connection // 支持 Claude 和 Gemini 两种协议,使用非流式请求 func (s *AccountTestService) testAntigravityAccountConnection(c *gin.Context, account *Account, modelID string) error { diff --git a/backend/internal/service/account_test_service_sora_test.go b/backend/internal/service/account_test_service_sora_test.go new file mode 100644 index 00000000..3dfac786 --- /dev/null +++ b/backend/internal/service/account_test_service_sora_test.go @@ -0,0 +1,319 @@ +package service + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type queuedHTTPUpstream struct { + responses []*http.Response + requests []*http.Request + tlsFlags []bool +} + +func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, fmt.Errorf("unexpected Do call") +} + +func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, 
error) { + u.requests = append(u.requests, req) + u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint) + if len(u.responses) == 0 { + return nil, fmt.Errorf("no mocked response") + } + resp := u.responses[0] + u.responses = u.responses[1:] + return resp, nil +} + +func newJSONResponse(status int, body string) *http.Response { + return &http.Response{ + StatusCode: status, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func newJSONResponseWithHeader(status int, body, key, value string) *http.Response { + resp := newJSONResponse(status, body) + resp.Header.Set(key, value) + return resp +} + +func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + return c, rec +} + +func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`), + newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`), + newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + TLSFingerprint: config.TLSFingerprintConfig{ + Enabled: true, + }, + }, + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + DisableTLSFingerprint: false, + }, + }, + }, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": 
"test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String()) + require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String()) + require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String()) + require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String()) + require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization")) + require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization")) + require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags) + + body := rec.Body.String() + require.Contains(t, body, `"type":"test_start"`) + require.Contains(t, body, "Sora connection OK - Email: demo@example.com") + require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z") + require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50") + require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`), + newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + 
+ c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + require.Len(t, upstream.requests, 4) + body := rec.Body.String() + require.Contains(t, body, "Sora connection OK - User: demo-user") + require.Contains(t, body, "Subscription check returned 403") + require.Contains(t, body, "Sora2 invite check returned 401") + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"partial_success"`) + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusForbidden, `Just a moment...`, "cf-ray", "9cff2d62d83bb98d"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d") + body := rec.Body.String() + require.Contains(t, body, `"type":"error"`) + require.Contains(t, body, "Cloudflare challenge") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") +} + +func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponseWithHeader(http.StatusTooManyRequests, `Just a moment...`, "cf-mitigated", "challenge"), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec 
:= newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "Cloudflare challenge") + require.Contains(t, err.Error(), "HTTP 429") + body := rec.Body.String() + require.Contains(t, body, "Cloudflare challenge") +} + +func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.Error(t, err) + require.Contains(t, err.Error(), "token_invalidated") + body := rec.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"status":"failed"`) + require.Contains(t, body, "token_invalidated") + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`), + }, + } + svc := &AccountTestService{ + httpUpstream: upstream, + soraTestCooldown: time.Hour, + } + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c1, _ := newSoraTestContext() + err := svc.testSoraAccountConnection(c1, account) + require.NoError(t, err) + + c2, rec2 := newSoraTestContext() + err = svc.testSoraAccountConnection(c2, account) + require.Error(t, err) + require.Contains(t, err.Error(), "测试过于频繁") + body := 
rec2.Body.String() + require.Contains(t, body, `"type":"sora_test_result"`) + require.Contains(t, body, `"code":"test_rate_limited"`) + require.Contains(t, body, `"status":"failed"`) + require.NotContains(t, body, `"type":"test_complete","success":true`) +} + +func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) { + upstream := &queuedHTTPUpstream{ + responses: []*http.Response{ + newJSONResponse(http.StatusOK, `{"name":"demo-user"}`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + newJSONResponse(http.StatusForbidden, `Just a moment...`), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "test_token", + }, + } + + c, rec := newSoraTestContext() + err := svc.testSoraAccountConnection(c, account) + + require.NoError(t, err) + body := rec.Body.String() + require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)") + require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d") + require.Contains(t, body, `"type":"test_complete","success":true`) +} + +func TestSanitizeProxyURLForLog(t *testing.T) { + require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080")) + require.Equal(t, "", sanitizeProxyURLForLog("")) + require.Equal(t, "", sanitizeProxyURLForLog("://invalid")) +} + +func TestExtractSoraEgressIPHint(t *testing.T) { + h := make(http.Header) + h.Set("x-openai-public-ip", "203.0.113.10") + require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h)) + + h2 := make(http.Header) + h2.Set("x-envoy-external-address", "198.51.100.9") + require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2)) + + require.Equal(t, "unknown", extractSoraEgressIPHint(nil)) + 
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{})) +} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 304c5781..7698223e 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -36,8 +36,8 @@ type UsageLogRepository interface { GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) - GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) - GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) + GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) + GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 1f6e91e5..8614f24a 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -4,11 +4,15 @@ import ( "context" "errors" "fmt" - "log" + "io" + "net/http" "strings" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + 
"github.com/Wei-Shaw/sub2api/internal/util/soraerror" ) // AdminService interface defines admin management operations @@ -65,6 +69,7 @@ type AdminService interface { GetProxyAccounts(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) TestProxy(ctx context.Context, id int64) (*ProxyTestResult, error) + CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) // Redeem code management ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]RedeemCode, int64, error) @@ -111,11 +116,16 @@ type CreateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -140,11 +150,16 @@ type UpdateGroupInput struct { WeeklyLimitUSD *float64 // 周限额 (USD) MonthlyLimitUSD *float64 // 月限额 (USD) // 图片生成计费配置(仅 antigravity 平台使用) - ImagePrice1K *float64 - ImagePrice2K *float64 - ImagePrice4K *float64 - ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 - FallbackGroupID *int64 // 降级分组 ID + ImagePrice1K *float64 + ImagePrice2K *float64 + ImagePrice4K *float64 + // Sora 按次计费配置 + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端 + FallbackGroupID *int64 // 降级分组 ID // 无效请求兜底分组 ID(仅 
anthropic 平台使用) FallbackGroupIDOnInvalidRequest *int64 // 模型路由配置(仅 anthropic 平台使用) @@ -278,6 +293,32 @@ type ProxyTestResult struct { CountryCode string `json:"country_code,omitempty"` } +type ProxyQualityCheckResult struct { + ProxyID int64 `json:"proxy_id"` + Score int `json:"score"` + Grade string `json:"grade"` + Summary string `json:"summary"` + ExitIP string `json:"exit_ip,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + BaseLatencyMs int64 `json:"base_latency_ms,omitempty"` + PassedCount int `json:"passed_count"` + WarnCount int `json:"warn_count"` + FailedCount int `json:"failed_count"` + ChallengeCount int `json:"challenge_count"` + CheckedAt int64 `json:"checked_at"` + Items []ProxyQualityCheckItem `json:"items"` +} + +type ProxyQualityCheckItem struct { + Target string `json:"target"` + Status string `json:"status"` // pass/warn/fail/challenge + HTTPStatus int `json:"http_status,omitempty"` + LatencyMs int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + CFRay string `json:"cf_ray,omitempty"` +} + // ProxyExitInfo represents proxy exit information from ip-api.com type ProxyExitInfo struct { IP string @@ -292,11 +333,64 @@ type ProxyExitInfoProber interface { ProbeProxy(ctx context.Context, proxyURL string) (*ProxyExitInfo, int64, error) } +type proxyQualityTarget struct { + Target string + URL string + Method string + AllowedStatuses map[int]struct{} +} + +var proxyQualityTargets = []proxyQualityTarget{ + { + Target: "openai", + URL: "https://api.openai.com/v1/models", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, + { + Target: "anthropic", + URL: "https://api.anthropic.com/v1/messages", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + http.StatusMethodNotAllowed: {}, + http.StatusNotFound: {}, + http.StatusBadRequest: {}, + }, + }, + { + Target: "gemini", + 
URL: "https://generativelanguage.googleapis.com/$discovery/rest?version=v1beta", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + }, + { + Target: "sora", + URL: "https://sora.chatgpt.com/backend/me", + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + }, +} + +const ( + proxyQualityRequestTimeout = 15 * time.Second + proxyQualityResponseHeaderTimeout = 10 * time.Second + proxyQualityMaxBodyBytes = int64(8 * 1024) + proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" +) + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository groupRepo GroupRepository accountRepo AccountRepository + soraAccountRepo SoraAccountRepository // Sora 账号扩展表仓储 proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository @@ -312,6 +406,7 @@ func NewAdminService( userRepo UserRepository, groupRepo GroupRepository, accountRepo AccountRepository, + soraAccountRepo SoraAccountRepository, proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, @@ -325,6 +420,7 @@ func NewAdminService( userRepo: userRepo, groupRepo: groupRepo, accountRepo: accountRepo, + soraAccountRepo: soraAccountRepo, proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, @@ -348,7 +444,7 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi for i := range users { rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) continue } users[i].GroupRates = rates @@ -366,7 +462,7 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) if 
s.userGroupRateRepo != nil { rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) if err != nil { - log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", id, err) } else { user.GroupRates = rates } @@ -444,7 +540,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda // 同步用户专属分组倍率 if input.GroupRates != nil && s.userGroupRateRepo != nil { if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { - log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + logger.LegacyPrintf("service.admin", "failed to sync user group rates: user_id=%d err=%v", user.ID, err) } } @@ -458,7 +554,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if concurrencyDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } adjustmentRecord := &RedeemCode{ @@ -471,7 +567,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda now := time.Now() adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create concurrency adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create concurrency adjustment redeem code: %v", err) } } @@ -488,7 +584,7 @@ func (s *adminServiceImpl) DeleteUser(ctx context.Context, id int64) error { return errors.New("cannot delete admin user") } if err := s.userRepo.Delete(ctx, id); err != nil { - log.Printf("delete user failed: user_id=%d err=%v", id, err) + logger.LegacyPrintf("service.admin", "delete user failed: user_id=%d err=%v", id, err) return err } if s.authCacheInvalidator != nil { @@ -531,7 +627,7 @@ func (s 
*adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCacheService.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + logger.LegacyPrintf("service.admin", "invalidate user balance cache failed: user_id=%d err=%v", userID, err) } }() } @@ -539,7 +635,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, if balanceDiff != 0 { code, err := GenerateRedeemCode() if err != nil { - log.Printf("failed to generate adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to generate adjustment redeem code: %v", err) return user, nil } @@ -555,7 +651,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64, adjustmentRecord.UsedAt = &now if err := s.redeemCodeRepo.Create(ctx, adjustmentRecord); err != nil { - log.Printf("failed to create balance adjustment redeem code: %v", err) + logger.LegacyPrintf("service.admin", "failed to create balance adjustment redeem code: %v", err) } } @@ -639,6 +735,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn imagePrice1K := normalizePrice(input.ImagePrice1K) imagePrice2K := normalizePrice(input.ImagePrice2K) imagePrice4K := normalizePrice(input.ImagePrice4K) + soraImagePrice360 := normalizePrice(input.SoraImagePrice360) + soraImagePrice540 := normalizePrice(input.SoraImagePrice540) + soraVideoPrice := normalizePrice(input.SoraVideoPricePerRequest) + soraVideoPriceHD := normalizePrice(input.SoraVideoPricePerRequestHD) // 校验降级分组 if input.FallbackGroupID != nil { @@ -709,6 +809,10 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ImagePrice1K: imagePrice1K, ImagePrice2K: imagePrice2K, ImagePrice4K: imagePrice4K, + SoraImagePrice360: soraImagePrice360, + SoraImagePrice540: 
soraImagePrice540, + SoraVideoPricePerRequest: soraVideoPrice, + SoraVideoPricePerRequestHD: soraVideoPriceHD, ClaudeCodeOnly: input.ClaudeCodeOnly, FallbackGroupID: input.FallbackGroupID, FallbackGroupIDOnInvalidRequest: fallbackOnInvalidRequest, @@ -865,6 +969,18 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.ImagePrice4K != nil { group.ImagePrice4K = normalizePrice(input.ImagePrice4K) } + if input.SoraImagePrice360 != nil { + group.SoraImagePrice360 = normalizePrice(input.SoraImagePrice360) + } + if input.SoraImagePrice540 != nil { + group.SoraImagePrice540 = normalizePrice(input.SoraImagePrice540) + } + if input.SoraVideoPricePerRequest != nil { + group.SoraVideoPricePerRequest = normalizePrice(input.SoraVideoPricePerRequest) + } + if input.SoraVideoPricePerRequestHD != nil { + group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -993,7 +1109,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { defer cancel() for _, userID := range affectedUserIDs { if err := s.billingCacheService.InvalidateSubscription(cacheCtx, userID, groupID); err != nil { - log.Printf("invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) + logger.LegacyPrintf("service.admin", "invalidate subscription cache failed: user_id=%d group_id=%d err=%v", userID, groupID, err) } } }() @@ -1103,6 +1219,18 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, err } + // 如果是 Sora 平台账号,自动创建 sora_accounts 扩展表记录 + if account.Platform == PlatformSora && s.soraAccountRepo != nil { + soraUpdates := map[string]any{ + "access_token": account.GetCredential("access_token"), + "refresh_token": account.GetCredential("refresh_token"), + } + if err := s.soraAccountRepo.Upsert(ctx, account.ID, soraUpdates); err != nil { + // 只记录警告日志,不阻塞账号创建 + 
logger.LegacyPrintf("service.admin", "[AdminService] 创建 sora_accounts 记录失败: account_id=%d err=%v", account.ID, err) + } + } + // 绑定分组 if len(groupIDs) > 0 { if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil { @@ -1200,7 +1328,11 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U } // 重新查询以确保返回完整数据(包括正确的 Proxy 关联对象) - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // BulkUpdateAccounts updates multiple accounts in one request. @@ -1216,16 +1348,21 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp return result, nil } - // Preload account platforms for mixed channel risk checks if group bindings are requested. + needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + + // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if input.GroupIDs != nil && !input.SkipMixedChannelCheck { + if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { - return nil, err - } - for _, account := range accounts { - if account != nil { - platformByID[account.ID] = account.Platform + if needMixedChannelCheck { + return nil, err + } + } else { + for _, account := range accounts { + if account != nil { + platformByID[account.ID] = account.Platform + } } } } @@ -1318,7 +1455,10 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } func (s *adminServiceImpl) DeleteAccount(ctx context.Context, id int64) error { - return s.accountRepo.Delete(ctx, id) + if err := s.accountRepo.Delete(ctx, id); err != nil { + return err + } + return nil } func (s *adminServiceImpl) RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error) { @@ -1351,7 +1491,11 @@ func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, if err := 
s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil { return nil, err } - return s.accountRepo.GetByID(ctx, id) + updated, err := s.accountRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + return updated, nil } // Proxy management implementations @@ -1629,6 +1773,270 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR }, nil } +func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*ProxyQualityCheckResult, error) { + proxy, err := s.proxyRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + + result := &ProxyQualityCheckResult{ + ProxyID: id, + Score: 100, + Grade: "A", + CheckedAt: time.Now().Unix(), + Items: make([]ProxyQualityCheckItem, 0, len(proxyQualityTargets)+1), + } + + proxyURL := proxy.URL() + if s.proxyProber == nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + Message: "代理探测服务未配置", + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + exitInfo, latencyMs, err := s.proxyProber.ProbeProxy(ctx, proxyURL) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "fail", + LatencyMs: latencyMs, + Message: err.Error(), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, nil) + return result, nil + } + + result.ExitIP = exitInfo.IP + result.Country = exitInfo.Country + result.CountryCode = exitInfo.CountryCode + result.BaseLatencyMs = latencyMs + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "base_connectivity", + Status: "pass", + LatencyMs: latencyMs, + Message: "代理出口连通正常", + }) + result.PassedCount++ + + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: proxyURL, + Timeout: proxyQualityRequestTimeout, + ResponseHeaderTimeout: 
proxyQualityResponseHeaderTimeout, + ProxyStrict: true, + }) + if err != nil { + result.Items = append(result.Items, ProxyQualityCheckItem{ + Target: "http_client", + Status: "fail", + Message: fmt.Sprintf("创建检测客户端失败: %v", err), + }) + result.FailedCount++ + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil + } + + for _, target := range proxyQualityTargets { + item := runProxyQualityTarget(ctx, client, target) + result.Items = append(result.Items, item) + switch item.Status { + case "pass": + result.PassedCount++ + case "warn": + result.WarnCount++ + case "challenge": + result.ChallengeCount++ + default: + result.FailedCount++ + } + } + + finalizeProxyQualityResult(result) + s.saveProxyQualitySnapshot(ctx, id, result, exitInfo) + return result, nil +} + +func runProxyQualityTarget(ctx context.Context, client *http.Client, target proxyQualityTarget) ProxyQualityCheckItem { + item := ProxyQualityCheckItem{ + Target: target.Target, + } + + req, err := http.NewRequestWithContext(ctx, target.Method, target.URL, nil) + if err != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("构建请求失败: %v", err) + return item + } + req.Header.Set("Accept", "application/json,text/html,*/*") + req.Header.Set("User-Agent", proxyQualityClientUserAgent) + + start := time.Now() + resp, err := client.Do(req) + if err != nil { + item.Status = "fail" + item.LatencyMs = time.Since(start).Milliseconds() + item.Message = fmt.Sprintf("请求失败: %v", err) + return item + } + defer func() { _ = resp.Body.Close() }() + item.LatencyMs = time.Since(start).Milliseconds() + item.HTTPStatus = resp.StatusCode + + body, readErr := io.ReadAll(io.LimitReader(resp.Body, proxyQualityMaxBodyBytes+1)) + if readErr != nil { + item.Status = "fail" + item.Message = fmt.Sprintf("读取响应失败: %v", readErr) + return item + } + if int64(len(body)) > proxyQualityMaxBodyBytes { + body = body[:proxyQualityMaxBodyBytes] + } + + if target.Target == "sora" && 
soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, body) { + item.Status = "challenge" + item.CFRay = soraerror.ExtractCloudflareRayID(resp.Header, body) + item.Message = "Sora 命中 Cloudflare challenge" + return item + } + + if _, ok := target.AllowedStatuses[resp.StatusCode]; ok { + if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices { + item.Status = "pass" + item.Message = fmt.Sprintf("HTTP %d", resp.StatusCode) + } else { + item.Status = "warn" + item.Message = fmt.Sprintf("HTTP %d(目标可达,但鉴权或方法受限)", resp.StatusCode) + } + return item + } + + if resp.StatusCode == http.StatusTooManyRequests { + item.Status = "warn" + item.Message = "目标返回 429,可能存在频控" + return item + } + + item.Status = "fail" + item.Message = fmt.Sprintf("非预期状态码: %d", resp.StatusCode) + return item +} + +func finalizeProxyQualityResult(result *ProxyQualityCheckResult) { + if result == nil { + return + } + score := 100 - result.WarnCount*10 - result.FailedCount*22 - result.ChallengeCount*30 + if score < 0 { + score = 0 + } + result.Score = score + result.Grade = proxyQualityGrade(score) + result.Summary = fmt.Sprintf( + "通过 %d 项,告警 %d 项,失败 %d 项,挑战 %d 项", + result.PassedCount, + result.WarnCount, + result.FailedCount, + result.ChallengeCount, + ) +} + +func proxyQualityGrade(score int) string { + switch { + case score >= 90: + return "A" + case score >= 75: + return "B" + case score >= 60: + return "C" + case score >= 40: + return "D" + default: + return "F" + } +} + +func proxyQualityOverallStatus(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + if result.ChallengeCount > 0 { + return "challenge" + } + if result.FailedCount > 0 { + return "failed" + } + if result.WarnCount > 0 { + return "warn" + } + if result.PassedCount > 0 { + return "healthy" + } + return "failed" +} + +func proxyQualityFirstCFRay(result *ProxyQualityCheckResult) string { + if result == nil { + return "" + } + for _, item := range result.Items { + 
if item.CFRay != "" { + return item.CFRay + } + } + return "" +} + +func proxyQualityBaseConnectivityPass(result *ProxyQualityCheckResult) bool { + if result == nil { + return false + } + for _, item := range result.Items { + if item.Target == "base_connectivity" { + return item.Status == "pass" + } + } + return false +} + +func (s *adminServiceImpl) saveProxyQualitySnapshot(ctx context.Context, proxyID int64, result *ProxyQualityCheckResult, exitInfo *ProxyExitInfo) { + if result == nil { + return + } + score := result.Score + checkedAt := result.CheckedAt + info := &ProxyLatencyInfo{ + Success: proxyQualityBaseConnectivityPass(result), + Message: result.Summary, + QualityStatus: proxyQualityOverallStatus(result), + QualityScore: &score, + QualityGrade: result.Grade, + QualitySummary: result.Summary, + QualityCheckedAt: &checkedAt, + QualityCFRay: proxyQualityFirstCFRay(result), + UpdatedAt: time.Now(), + } + if result.BaseLatencyMs > 0 { + latency := result.BaseLatencyMs + info.LatencyMs = &latency + } + if exitInfo != nil { + info.IPAddress = exitInfo.IP + info.Country = exitInfo.Country + info.CountryCode = exitInfo.CountryCode + info.Region = exitInfo.Region + info.City = exitInfo.City + } + s.saveProxyLatency(ctx, proxyID, info) +} + func (s *adminServiceImpl) probeProxyLatency(ctx context.Context, proxy *Proxy) { if s.proxyProber == nil || proxy == nil { return @@ -1718,7 +2126,7 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, ids) if err != nil { - log.Printf("Warning: load proxy latency cache failed: %v", err) + logger.LegacyPrintf("service.admin", "Warning: load proxy latency cache failed: %v", err) return } @@ -1739,6 +2147,11 @@ func (s *adminServiceImpl) attachProxyLatency(ctx context.Context, proxies []Pro proxies[i].CountryCode = info.CountryCode proxies[i].Region = info.Region proxies[i].City = info.City + proxies[i].QualityStatus = info.QualityStatus + 
proxies[i].QualityScore = info.QualityScore + proxies[i].QualityGrade = info.QualityGrade + proxies[i].QualitySummary = info.QualitySummary + proxies[i].QualityChecked = info.QualityCheckedAt } } @@ -1746,8 +2159,28 @@ func (s *adminServiceImpl) saveProxyLatency(ctx context.Context, proxyID int64, if s.proxyLatencyCache == nil || info == nil { return } - if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, info); err != nil { - log.Printf("Warning: store proxy latency cache failed: %v", err) + + merged := *info + if latencies, err := s.proxyLatencyCache.GetProxyLatencies(ctx, []int64{proxyID}); err == nil { + if existing := latencies[proxyID]; existing != nil { + if merged.QualityCheckedAt == nil && + merged.QualityScore == nil && + merged.QualityGrade == "" && + merged.QualityStatus == "" && + merged.QualitySummary == "" && + merged.QualityCFRay == "" { + merged.QualityStatus = existing.QualityStatus + merged.QualityScore = existing.QualityScore + merged.QualityGrade = existing.QualityGrade + merged.QualitySummary = existing.QualitySummary + merged.QualityCheckedAt = existing.QualityCheckedAt + merged.QualityCFRay = existing.QualityCFRay + } + } + } + + if err := s.proxyLatencyCache.SetProxyLatency(ctx, proxyID, &merged); err != nil { + logger.LegacyPrintf("service.admin", "Warning: store proxy latency cache failed: %v", err) } } diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 662b95fb..0dccacbb 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,13 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + getByIDsAccounts []*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 } func (s 
*accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -32,6 +39,26 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i return nil } +func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { + s.getByIDsCalled = true + s.getByIDsIDs = append([]int64{}, ids...) + if s.getByIDsErr != nil { + return nil, s.getByIDsErr + } + return s.getByIDsAccounts, nil +} + +func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) { + s.getByIDCalled = append(s.getByIDCalled, id) + if err, ok := s.getByIDErrByID[id]; ok { + return nil, err + } + if account, ok := s.getByIDAccounts[id]; ok { + return account, nil + } + return nil, errors.New("account not found") +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} diff --git a/backend/internal/service/admin_service_proxy_quality_test.go b/backend/internal/service/admin_service_proxy_quality_test.go new file mode 100644 index 00000000..5a43cd9c --- /dev/null +++ b/backend/internal/service/admin_service_proxy_quality_test.go @@ -0,0 +1,95 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) { + result := &ProxyQualityCheckResult{ + PassedCount: 2, + WarnCount: 1, + FailedCount: 1, + ChallengeCount: 1, + } + + finalizeProxyQualityResult(result) + + require.Equal(t, 38, result.Score) + require.Equal(t, "F", result.Grade) + require.Contains(t, result.Summary, "通过 2 项") + require.Contains(t, result.Summary, "告警 1 项") + require.Contains(t, result.Summary, "失败 1 项") + require.Contains(t, result.Summary, "挑战 1 项") +} + +func TestRunProxyQualityTarget_SoraChallenge(t 
*testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Header().Set("cf-ray", "test-ray-123") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte("Just a moment...")) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "sora", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "challenge", item.Status) + require.Equal(t, http.StatusForbidden, item.HTTPStatus) + require.Equal(t, "test-ray-123", item.CFRay) +} + +func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"models":[]}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "gemini", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusOK: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "pass", item.Status) + require.Equal(t, http.StatusOK, item.HTTPStatus) +} + +func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + })) + defer server.Close() + + target := proxyQualityTarget{ + Target: "openai", + URL: server.URL, + Method: http.MethodGet, + AllowedStatuses: map[int]struct{}{ + http.StatusUnauthorized: {}, + }, + } + + item := runProxyQualityTarget(context.Background(), server.Client(), target) + require.Equal(t, "warn", item.Status) + require.Equal(t, http.StatusUnauthorized, item.HTTPStatus) + 
require.Contains(t, item.Message, "目标可达") +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 1d87f4b1..cf87b282 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -22,8 +22,10 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" + "github.com/tidwall/gjson" ) const ( @@ -184,7 +186,7 @@ type smartRetryResult struct { func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParams, resp *http.Response, respBody []byte, baseURL string, urlIdx int, availableURLs []string) *smartRetryResult { // "Resource has been exhausted" 是 URL 级别限流,切换 URL(仅 429) if resp.StatusCode == http.StatusTooManyRequests && isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) return &smartRetryResult{action: smartRetryActionContinueURL} } @@ -204,13 +206,13 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam if rateLimitDuration <= 0 { rateLimitDuration = antigravityDefaultRateLimitDuration } - log.Printf("%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d oauth_long_delay model=%s account=%d upstream_retry_delay=%v body=%s (model rate limit, switch account)", p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration, truncateForLog(respBody, 200)) resetAt := time.Now().Add(rateLimitDuration) if !setModelRateLimitByModelName(p.ctx, 
p.accountRepo, p.account.ID, modelName, p.prefix, resp.StatusCode, resetAt, false) { p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited account=%d (no model mapping)", p.prefix, resp.StatusCode, p.account.ID) } else { s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -273,7 +275,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 智能重试:创建新请求 retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { - log.Printf("%s status=smart_retry_request_build_failed error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=smart_retry_request_build_failed error=%v", p.prefix, err) p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -356,7 +358,7 @@ func (s *AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam // 单账号 503 退避重试模式:智能重试耗尽后不设限流、不切换账号, // 直接返回 503 让 Handler 层的单账号退避循环做最终处理。 if resp.StatusCode == http.StatusServiceUnavailable && isSingleAccountRetry(p.ctx) { - log.Printf("%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d smart_retry_exhausted_single_account attempts=%d model=%s account=%d body=%s (return 503 directly)", p.prefix, resp.StatusCode, antigravitySmartRetryMaxAttempts, modelName, p.account.ID, truncateForLog(retryBody, 200)) return &smartRetryResult{ action: smartRetryActionBreakWithResp, @@ -374,9 +376,9 @@ func (s 
*AntigravityGatewayService) handleSmartRetry(p antigravityRetryLoopParam resetAt := time.Now().Add(rateLimitDuration) if p.accountRepo != nil && modelName != "" { if err := p.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, modelName, resetAt); err != nil { - log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", p.prefix, resp.StatusCode, modelName, err) } else { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", p.prefix, resp.StatusCode, modelName, p.account.ID, rateLimitDuration) s.updateAccountModelRateLimitInCache(p.ctx, p.account, modelName, resetAt) } @@ -431,7 +433,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( waitDuration = antigravitySmartRetryMinWait } - log.Printf("%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_in_place model=%s account=%d upstream_retry_delay=%v (retrying in-place instead of rate-limiting)", p.prefix, resp.StatusCode, modelName, p.account.ID, waitDuration) var lastRetryResp *http.Response @@ -443,21 +445,21 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( if totalWaited+waitDuration > antigravitySingleAccountSmartRetryTotalMaxWait { remaining := antigravitySingleAccountSmartRetryTotalMaxWait - totalWaited if remaining <= 0 { - log.Printf("%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: total_wait_exceeded total=%v max=%v, giving up", p.prefix, 
totalWaited, antigravitySingleAccountSmartRetryTotalMaxWait) break } waitDuration = remaining } - log.Printf("%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry attempt=%d/%d delay=%v total_waited=%v model=%s account=%d", p.prefix, resp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, waitDuration, totalWaited, modelName, p.account.ID) timer := time.NewTimer(waitDuration) select { case <-p.ctx.Done(): timer.Stop() - log.Printf("%s status=context_canceled_during_single_account_retry", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_single_account_retry", p.prefix) return &smartRetryResult{action: smartRetryActionBreakWithResp, err: p.ctx.Err()} case <-timer.C: } @@ -466,13 +468,13 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( // 创建新请求 retryReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body) if err != nil { - log.Printf("%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: request_build_failed error=%v", p.prefix, err) break } retryResp, retryErr := p.httpUpstream.Do(retryReq, p.proxyURL, p.account.ID, p.account.Concurrency) if retryErr == nil && retryResp != nil && retryResp.StatusCode != http.StatusTooManyRequests && retryResp.StatusCode != http.StatusServiceUnavailable { - log.Printf("%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_success attempt=%d/%d total_waited=%v", p.prefix, retryResp.StatusCode, attempt, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited) // 关闭之前的响应 if lastRetryResp != nil { @@ -483,7 +485,7 @@ func (s *AntigravityGatewayService) 
handleSingleAccountRetryInPlace( // 网络错误时继续重试 if retryErr != nil || retryResp == nil { - log.Printf("%s single_account_503_retry: network_error attempt=%d/%d error=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s single_account_503_retry: network_error attempt=%d/%d error=%v", p.prefix, attempt, antigravitySingleAccountSmartRetryMaxAttempts, retryErr) continue } @@ -517,7 +519,7 @@ func (s *AntigravityGatewayService) handleSingleAccountRetryInPlace( if retryBody == nil { retryBody = respBody } - log.Printf("%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d single_account_503_retry_exhausted attempts=%d total_waited=%v model=%s account=%d body=%s (return 503 directly)", p.prefix, resp.StatusCode, antigravitySingleAccountSmartRetryMaxAttempts, totalWaited, modelName, p.account.ID, truncateForLog(retryBody, 200)) return &smartRetryResult{ @@ -540,10 +542,10 @@ func (s *AntigravityGatewayService) antigravityRetryLoop(p antigravityRetryLoopP // 如果上游确实还不可用,handleSmartRetry → handleSingleAccountRetryInPlace // 会在 Service 层原地等待+重试,不需要在预检查这里等。 if isSingleAccountRetry(p.ctx) { - log.Printf("%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: single_account_retry skipping rate_limit remaining=%v model=%s account=%d (will retry in-place if 503)", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) } else { - log.Printf("%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", + logger.LegacyPrintf("service.antigravity_gateway", "%s pre_check: rate_limit_switch remaining=%v model=%s account=%d", p.prefix, remaining.Truncate(time.Millisecond), p.requestedModel, p.account.ID) return nil, &AntigravityAccountSwitchError{ OriginalAccountID: p.account.ID, @@ 
-580,7 +582,7 @@ urlFallbackLoop: for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { select { case <-p.ctx.Done(): - log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled error=%v", p.prefix, p.ctx.Err()) return nil, p.ctx.Err() default: } @@ -610,18 +612,18 @@ urlFallbackLoop: Message: safeErr, }) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1]) continue urlFallbackLoop } if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue } - log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=request_failed retries_exhausted error=%v", p.prefix, err) setOpsUpstreamError(p.c, 0, safeErr, "") return nil, fmt.Errorf("upstream request failed after retries: %w", err) } @@ -678,9 +680,9 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, 
resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue @@ -688,7 +690,7 @@ urlFallbackLoop: // 重试用尽,标记账户限流 p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.requestedModel, p.groupID, p.sessionHash, p.isStickySession) - log.Printf("%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d rate_limited base_url=%s body=%s", p.prefix, resp.StatusCode, baseURL, truncateForLog(respBody, 200)) resp = &http.Response{ StatusCode: resp.StatusCode, Header: resp.Header.Clone(), @@ -712,9 +714,9 @@ urlFallbackLoop: Message: upstreamMsg, Detail: getUpstreamDetail(respBody), }) - log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) if !sleepAntigravityBackoffWithContext(p.ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", p.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=context_canceled_during_backoff", p.prefix) return nil, p.ctx.Err() } continue @@ -1012,14 +1014,14 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account } // 调试日志:Test 请求信息 - log.Printf("[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), req.URL.String()) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] account=%s request_size=%d url=%s", account.Name, len(requestBody), 
req.URL.String()) // 发送请求 resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { lastErr = fmt.Errorf("请求失败: %w", err) if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } return nil, lastErr @@ -1034,7 +1036,7 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account // 检查是否需要 URL 降级 if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -1243,16 +1245,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin } // unwrapV1InternalResponse 解包 v1internal 响应 +// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销 func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) { - var outer map[string]any - if err := json.Unmarshal(body, &outer); err != nil { - return nil, err + result := gjson.GetBytes(body, "response") + if result.Exists() { + return []byte(result.Raw), nil } - - if resp, ok := outer["response"]; ok { - return json.Marshal(resp) - } - return body, nil } @@ -1420,7 +1418,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, continue } - log.Printf("Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: detected signature-related 400, retrying once (%s)", account.ID, stage.name) retryGeminiBody, 
txErr := antigravity.TransformClaudeToGeminiWithOptions(&retryClaudeReq, projectID, mappedModel, s.getClaudeTransformOptions(ctx)) if txErr != nil { @@ -1453,7 +1451,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Kind: "signature_retry_request_error", Message: sanitizeUpstreamErrorMessage(retryErr.Error()), }) - log.Printf("Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) + logger.LegacyPrintf("service.antigravity_gateway", "Antigravity account %d: signature retry request failed (%s): %v", account.ID, stage.name, retryErr) continue } @@ -1472,7 +1470,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, if retryResp.Request != nil && retryResp.Request.URL != nil { retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host } - log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) } kind := "signature_retry" if strings.TrimSpace(stage.name) != "" { @@ -1525,7 +1523,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, upstreamDetail := s.getUpstreamErrorDetail(respBody) logBody, maxBytes := s.getLogConfig() if logBody { - log.Printf("%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=400 prompt_too_long=true upstream_message=%q request_id=%s body=%s", prefix, upstreamMsg, resp.Header.Get("x-request-id"), truncateForLog(respBody, maxBytes)) } appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -1600,7 +1598,7 @@ func (s 
*AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 客户端要求流式,直接透传转换 streamRes, err := s.handleClaudeStreamingResponse(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1610,7 +1608,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, // 客户端要求非流式,收集流式响应后转换返回 streamRes, err := s.handleClaudeStreamToNonStreaming(c, resp, startTime, originalModel) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -1963,7 +1961,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Usage: ClaudeUsage{}, Model: originalModel, Stream: false, - Duration: time.Since(time.Now()), + Duration: time.Since(startTime), FirstTokenMs: nil, }, nil default: @@ -2002,9 +2000,9 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 清理 Schema if cleanedBody, err := cleanGeminiRequest(injectedBody); err == nil { injectedBody = cleanedBody - log.Printf("[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Cleaned request schema in forwarded request for account %s", account.Name) } else { - log.Printf("[Antigravity] Failed to clean schema: %v", err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Failed to clean schema: %v", err) } // 包装请求 @@ -2066,7 +2064,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co isModelNotFoundError(resp.StatusCode, respBody) { fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity) if 
fallbackModel != "" && fallbackModel != mappedModel { - log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name) fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, injectedBody) if err == nil { @@ -2149,7 +2147,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500)) c.Data(resp.StatusCode, contentType, unwrappedForOps) return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode) } @@ -2168,7 +2166,7 @@ handleSuccess: // 客户端要求流式,直接透传 streamRes, err := s.handleGeminiStreamingResponse(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -2178,7 +2176,7 @@ handleSuccess: // 客户端要求非流式,收集流式响应后返回 streamRes, err := s.handleGeminiStreamToNonStreaming(c, resp, startTime) if err != nil { - log.Printf("%s status=stream_collect_error error=%v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=stream_collect_error error=%v", prefix, err) return nil, err } usage = streamRes.usage @@ -2297,13 +2295,13 @@ func setModelRateLimitByModelName(ctx context.Context, repo AccountRepository, a } // 直接使用官方模型 ID 作为 key,不再转换为 scope if err := repo.SetModelRateLimit(ctx, accountID, modelName, resetAt); err != nil { - 
log.Printf("%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_failed model=%s error=%v", prefix, statusCode, modelName, err) return false } if afterSmartRetry { - log.Printf("%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited_after_smart_retry model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } else { - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", prefix, statusCode, modelName, accountID, time.Until(resetAt).Truncate(time.Second)) } return true } @@ -2411,7 +2409,7 @@ func parseAntigravitySmartRetryInfo(body []byte) *antigravitySmartRetryInfo { // 例如: "0.5s", "10s", "4m50s", "1h30m", "200ms" 等 dur, err := time.ParseDuration(delay) if err != nil { - log.Printf("[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] failed to parse retryDelay: %s error=%v", delay, err) continue } retryDelay = dur @@ -2532,7 +2530,7 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit // RATE_LIMIT_EXCEEDED: < antigravityRateLimitThreshold: 等待后重试 if info.RetryDelay < antigravityRateLimitThreshold { - log.Printf("%s status=%d model_rate_limit_wait model=%s wait=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limit_wait model=%s wait=%v", p.prefix, p.statusCode, info.ModelName, info.RetryDelay) return 
&handleModelRateLimitResult{ Handled: true, @@ -2557,12 +2555,12 @@ func (s *AntigravityGatewayService) handleModelRateLimit(p *handleModelRateLimit // setModelRateLimitAndClearSession 设置模型限流并清除粘性会话 func (s *AntigravityGatewayService) setModelRateLimitAndClearSession(p *handleModelRateLimitParams, info *antigravitySmartRetryInfo) { resetAt := time.Now().Add(info.RetryDelay) - log.Printf("%s status=%d model_rate_limited model=%s account=%d reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d model_rate_limited model=%s account=%d reset_in=%v", p.prefix, p.statusCode, info.ModelName, p.account.ID, info.RetryDelay) // 设置模型限流状态(数据库) if err := s.accountRepo.SetModelRateLimit(p.ctx, p.account.ID, info.ModelName, resetAt); err != nil { - log.Printf("%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s model_rate_limit_failed model=%s error=%v", p.prefix, info.ModelName, err) } // 立即更新 Redis 快照中账号的限流状态,避免并发请求重复选中 @@ -2598,7 +2596,7 @@ func (s *AntigravityGatewayService) updateAccountModelRateLimitInCache(ctx conte // 更新 Redis 快照 if err := s.schedulerSnapshot.UpdateAccountInCache(ctx, account); err != nil { - log.Printf("[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] cache_update_failed account=%d model=%s err=%v", account.ID, modelKey, err) } } @@ -2637,7 +2635,7 @@ func (s *AntigravityGatewayService) handleUpstreamError( // 429:尝试解析模型级限流,解析失败时兜底为账号级限流 if statusCode == 429 { if logBody, maxBytes := s.getLogConfig(); logBody { - log.Printf("[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity-Debug] 429 response body: %s", truncateString(string(body), maxBytes)) } resetAt := ParseGeminiRateLimitResetTime(body) @@ -2648,9 +2646,9 @@ func (s 
*AntigravityGatewayService) handleUpstreamError( if modelKey != "" { ra := s.resolveResetTime(resetAt, defaultDur) if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, modelKey, ra); err != nil { - log.Printf("%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limit_set_failed model=%s error=%v", prefix, modelKey, err) } else { - log.Printf("%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 model_rate_limited model=%s account=%d reset_at=%v reset_in=%v", prefix, modelKey, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) s.updateAccountModelRateLimitInCache(ctx, account, modelKey, ra) } @@ -2659,10 +2657,10 @@ func (s *AntigravityGatewayService) handleUpstreamError( // 无法解析模型 key,兜底为账号级限流 ra := s.resolveResetTime(resetAt, defaultDur) - log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limited account=%d reset_at=%v reset_in=%v (fallback)", prefix, account.ID, ra.Format("15:04:05"), time.Until(ra).Truncate(time.Second)) if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil { - log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) } return nil } @@ -2672,7 +2670,7 @@ func (s *AntigravityGatewayService) handleUpstreamError( } shouldDisable := s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, headers, body) if shouldDisable { - log.Printf("%s status=%d marked_error", prefix, statusCode) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d marked_error", prefix, statusCode) } return nil } @@ -2746,18 
+2744,18 @@ func (cw *antigravityClientWriter) Disconnected() bool { return cw.disconnected func (cw *antigravityClientWriter) markDisconnected() { cw.disconnected = true - log.Printf("Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) + logger.LegacyPrintf("service.antigravity_gateway", "Client disconnected during streaming (%s), continuing to drain upstream for billing", cw.prefix) } // handleStreamReadError 处理上游读取错误的通用逻辑。 // 返回 (clientDisconnect, handled):handled=true 表示错误已处理,调用方应返回已收集的 usage。 func handleStreamReadError(err error, clientDisconnected bool, prefix string) (disconnect bool, handled bool) { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming (%s), returning collected usage", prefix) + logger.LegacyPrintf("service.antigravity_gateway", "Context canceled during streaming (%s), returning collected usage", prefix) return true, true } if clientDisconnected { - log.Printf("Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "Upstream read error after client disconnect (%s): %v, returning collected usage", prefix, err) return true, true } return false, false @@ -2786,7 +2784,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2807,7 +2806,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer 
close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2818,7 +2818,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -2860,7 +2860,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -2884,19 +2884,19 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } // 解析 usage + if u := extractGeminiUsage(inner); u != nil { + usage = u + } var parsed map[string]any if json.Unmarshal(inner, &parsed) == nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } // Check for MALFORMED_FUNCTION_CALL if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward stream") if content, ok := cand["content"]; ok { if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -2921,10 +2921,10 @@ func (s 
*AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity gemini), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity gemini), returning collected usage") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout (antigravity)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -2939,7 +2939,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) usage := &ClaudeUsage{} var firstTokenMs *int @@ -2967,7 +2968,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -2978,7 +2980,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -3005,7 +3007,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long 
(antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -3042,7 +3044,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont last = parsed // 提取 usage - if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(inner); u != nil { usage = u } @@ -3050,10 +3052,10 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 { if cand, ok := candidates[0].(map[string]any); ok { if fr, ok := cand["finishReason"].(string); ok && fr == "MALFORMED_FUNCTION_CALL" { - log.Printf("[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] MALFORMED_FUNCTION_CALL detected in forward non-stream collect") if content, ok := cand["content"]; ok { if b, err := json.Marshal(content); err == nil { - log.Printf("[Antigravity] Malformed content: %s", string(b)) + logger.LegacyPrintf("service.antigravity_gateway", "[Antigravity] Malformed content: %s", string(b)) } } } @@ -3080,7 +3082,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -3091,7 +3093,7 @@ returnResponse: // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response (gemini non-stream), triggering failover") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream 
response (gemini non-stream), triggering failover") return nil, &UpstreamFailoverError{ StatusCode: http.StatusBadGateway, ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), @@ -3311,7 +3313,7 @@ func (s *AntigravityGatewayService) writeMappedClaudeError(c *gin.Context, accou // 记录上游错误详情便于排障(可选:由配置控制;不回显到客户端) if logBody { - log.Printf("[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] upstream_error status=%d body=%s", upstreamStatus, truncateForLog(body, maxBytes)) } // 检查错误透传规则 @@ -3402,7 +3404,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) var firstTokenMs *int var last map[string]any @@ -3428,7 +3431,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3439,7 +3443,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) // 上游数据间隔超时保护(防止上游挂起长期占用连接) @@ -3466,7 +3470,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont } if ev.err != nil { if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line 
too long (antigravity claude non-stream): max_size=%d error=%v", maxLineSize, ev.err) } return nil, ev.err } @@ -3515,7 +3519,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamToNonStreaming(c *gin.Cont if time.Since(lastRead) < streamInterval { continue } - log.Printf("Stream data interval timeout (antigravity claude non-stream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity claude non-stream)") return nil, fmt.Errorf("stream data interval timeout") } } @@ -3526,7 +3530,7 @@ returnResponse: // 处理空响应情况 — 触发同账号重试 + failover 切换账号 if last == nil && lastWithParts == nil { - log.Printf("[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] warning: empty stream response (claude non-stream), triggering failover") return nil, &UpstreamFailoverError{ StatusCode: http.StatusBadGateway, ResponseBody: []byte(`{"error":"empty stream response from upstream"}`), @@ -3548,7 +3552,7 @@ returnResponse: // 转换 Gemini 响应为 Claude 格式 claudeResp, agUsage, err := antigravity.TransformGeminiToClaude(geminiBody, originalModel) if err != nil { - log.Printf("[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) + logger.LegacyPrintf("service.antigravity_gateway", "[antigravity-Forward] transform_error error=%v body=%s", err, string(geminiBody)) return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } @@ -3586,7 +3590,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if s.settingService.cfg != nil && s.settingService.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.settingService.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) // 辅助函数:转换 antigravity.ClaudeUsage 到 service.ClaudeUsage convertUsage 
:= func(agUsage *antigravity.ClaudeUsage) *ClaudeUsage { @@ -3618,7 +3623,8 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -3629,7 +3635,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -3681,7 +3687,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: disconnect}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "SSE line too long (antigravity): max_size=%d error=%v", maxLineSize, ev.err) sendErrorEvent("response_too_large") return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, ev.err } @@ -3705,10 +3711,10 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity claude), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity claude), returning collected usage") return &antigravityStreamResult{usage: finishUsage(), firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout (antigravity)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity)") sendErrorEvent("stream_timeout") return 
&antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") } @@ -3908,7 +3914,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. // 发送请求 resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { - log.Printf("%s upstream request failed: %v", prefix, err) + logger.LegacyPrintf("service.antigravity_gateway", "%s upstream request failed: %v", prefix, err) return nil, fmt.Errorf("upstream request failed: %w", err) } defer func() { _ = resp.Body.Close() }() @@ -3966,7 +3972,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin. // 构建计费结果 duration := time.Since(startTime) - log.Printf("%s status=success duration_ms=%d", prefix, duration.Milliseconds()) + logger.LegacyPrintf("service.antigravity_gateway", "%s status=success duration_ms=%d", prefix, duration.Milliseconds()) return &ForwardResult{ Model: billingModel, @@ -4052,7 +4058,7 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp if disconnect, handled := handleStreamReadError(ev.err, cw.Disconnected(), "antigravity upstream"); handled { return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: disconnect} } - log.Printf("Stream read error (antigravity upstream): %v", ev.err) + logger.LegacyPrintf("service.antigravity_gateway", "Stream read error (antigravity upstream): %v", ev.err) return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} } @@ -4076,10 +4082,10 @@ func (s *AntigravityGatewayService) streamUpstreamResponse(c *gin.Context, resp continue } if cw.Disconnected() { - log.Printf("Upstream timeout after client disconnect (antigravity upstream), returning collected usage") + logger.LegacyPrintf("service.antigravity_gateway", "Upstream timeout after client disconnect (antigravity upstream), returning collected usage") return &antigravityStreamResult{usage: usage, 
firstTokenMs: firstTokenMs, clientDisconnect: true} } - log.Printf("Stream data interval timeout (antigravity upstream)") + logger.LegacyPrintf("service.antigravity_gateway", "Stream data interval timeout (antigravity upstream)") return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs} } } diff --git a/backend/internal/service/antigravity_gateway_service_test.go b/backend/internal/service/antigravity_gateway_service_test.go index b312e5ca..abe7b75d 100644 --- a/backend/internal/service/antigravity_gateway_service_test.go +++ b/backend/internal/service/antigravity_gateway_service_test.go @@ -9,6 +9,7 @@ import ( "io" "net/http" "net/http/httptest" + "strings" "testing" "time" @@ -417,6 +418,44 @@ func TestAntigravityGatewayService_ForwardGemini_StickySessionForceCacheBilling( require.True(t, failoverErr.ForceCacheBilling, "ForceCacheBilling should be true for sticky session switch") } +// TestStreamUpstreamResponse_UsageAndFirstToken +// 验证:usage 字段可被累积/覆盖更新,并且能记录首 token 时间 +func TestStreamUpstreamResponse_UsageAndFirstToken(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newAntigravityTestService(&config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + fmt.Fprintln(pw, `data: {"usage":{"input_tokens":1,"output_tokens":2,"cache_read_input_tokens":3,"cache_creation_input_tokens":4}}`) + fmt.Fprintln(pw, `data: {"usage":{"output_tokens":5}}`) + }() + + start := time.Now().Add(-10 * time.Millisecond) + result := svc.streamUpstreamResponse(c, resp, start) + _ = pr.Close() + + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + // 第二次事件覆盖 output_tokens + 
require.Equal(t, 5, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) + require.Equal(t, 4, result.usage.CacheCreationInputTokens) + require.NotNil(t, result.firstTokenMs) + + // 确保有透传输出 + require.Contains(t, rec.Body.String(), "data:") +} + // --- 流式 happy path 测试 --- // TestStreamUpstreamResponse_NormalComplete @@ -920,3 +959,144 @@ func TestAntigravityClientWriter(t *testing.T) { require.True(t, cw.Disconnected()) }) } + +// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景 +func TestUnwrapV1InternalResponse(t *testing.T) { + svc := &AntigravityGatewayService{} + + // 构造 >50KB 的大型 JSON + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + + tests := []struct { + name string + input []byte + expected string + wantErr bool + }{ + { + name: "正常 response 包装", + input: []byte(`{"response":{"id":"123","content":"hello"}}`), + expected: `{"id":"123","content":"hello"}`, + }, + { + name: "无 response 透传", + input: []byte(`{"id":"456"}`), + expected: `{"id":"456"}`, + }, + { + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, + }, + { + name: "response 为 null", + input: []byte(`{"response":null}`), + expected: `null`, + }, + { + name: "response 为基础类型 string", + input: []byte(`{"response":"hello"}`), + expected: `"hello"`, + }, + { + name: "非法 JSON", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON >50KB", + input: largeInput, + expected: largeExpected, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := svc.unwrapV1InternalResponse(tt.input) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.expected, 
strings.TrimSpace(string(got))) + }) + } +} + +// --- unwrapV1InternalResponse benchmark 对照组 --- + +// unwrapV1InternalResponseOld 旧实现:Unmarshal+Marshal 双重开销(仅用于 benchmark 对照) +func unwrapV1InternalResponseOld(body []byte) ([]byte, error) { + var outer map[string]any + if err := json.Unmarshal(body, &outer); err != nil { + return nil, err + } + if resp, ok := outer["response"]; ok { + return json.Marshal(resp) + } + return body, nil +} + +func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) { + body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`) + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = unwrapV1InternalResponseOld(body) + } +} + +func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) { + body := generateLargeUnwrapJSON(10 * 1024) // ~10KB + svc := &AntigravityGatewayService{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = svc.unwrapV1InternalResponse(body) + } +} + +// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON +func generateLargeUnwrapJSON(minSize int) []byte { + parts := make([]map[string]string, 0) + current := 0 + for current < minSize { + text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1) + parts = append(parts, map[string]string{"text": text}) + current += len(text) + 20 // 估算 JSON 编码开销 + } + inner := map[string]any{ + "candidates": []map[string]any{ 
+ {"content": map[string]any{"parts": parts}}, + }, + "usageMetadata": map[string]any{ + "promptTokenCount": 100, + "candidatesTokenCount": 50, + }, + } + outer := map[string]any{"response": inner} + b, _ := json.Marshal(outer) + return b +} diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go index 0befa7d9..6a486ebc 100644 --- a/backend/internal/service/antigravity_rate_limit_test.go +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -15,6 +15,12 @@ import ( "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ HTTPUpstream = (*stubAntigravityUpstream)(nil) +var _ HTTPUpstream = (*recordingOKUpstream)(nil) +var _ AccountRepository = (*stubAntigravityAccountRepo)(nil) +var _ SchedulerCache = (*stubSchedulerCache)(nil) + type stubAntigravityUpstream struct { firstBase string secondBase string diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index d66059dd..fe1b3a5d 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -19,6 +19,7 @@ type APIKey struct { Status string IPWhitelist []string IPBlacklist []string + LastUsedAt *time.Time CreatedAt time.Time UpdatedAt time.Time User *User diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index d15b5817..4240be23 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -44,6 +44,10 @@ type APIKeyAuthGroupSnapshot struct { ImagePrice1K *float64 `json:"image_price_1k,omitempty"` ImagePrice2K *float64 `json:"image_price_2k,omitempty"` ImagePrice4K *float64 `json:"image_price_4k,omitempty"` + SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"` + SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"` + SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` + SoraVideoPricePerRequestHD 
*float64 `json:"sora_video_price_per_request_hd,omitempty"` ClaudeCodeOnly bool `json:"claude_code_only"` FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index f5bba7d0..77a75674 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -6,8 +6,7 @@ import ( "encoding/hex" "errors" "fmt" - "math/rand" - "sync" + "math/rand/v2" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -23,12 +22,6 @@ type apiKeyAuthCacheConfig struct { singleflight bool } -var ( - jitterRandMu sync.Mutex - // 认证缓存抖动使用独立随机源,避免全局 Seed - jitterRand = rand.New(rand.NewSource(time.Now().UnixNano())) -) - func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig { if cfg == nil { return apiKeyAuthCacheConfig{} @@ -56,6 +49,8 @@ func (c apiKeyAuthCacheConfig) negativeEnabled() bool { return c.negativeTTL > 0 } +// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。 +// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。 func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { if ttl <= 0 { return ttl @@ -68,9 +63,7 @@ func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration { percent = 100 } delta := float64(percent) / 100 - jitterRandMu.Lock() - randVal := jitterRand.Float64() - jitterRandMu.Unlock() + randVal := rand.Float64() factor := 1 - delta + randVal*(2*delta) if factor <= 0 { return ttl @@ -238,6 +231,10 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { ImagePrice1K: apiKey.Group.ImagePrice1K, ImagePrice2K: apiKey.Group.ImagePrice2K, ImagePrice4K: apiKey.Group.ImagePrice4K, + SoraImagePrice360: apiKey.Group.SoraImagePrice360, + SoraImagePrice540: apiKey.Group.SoraImagePrice540, + SoraVideoPricePerRequest: 
apiKey.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly, FallbackGroupID: apiKey.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest, @@ -288,6 +285,10 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho ImagePrice1K: snapshot.Group.ImagePrice1K, ImagePrice2K: snapshot.Group.ImagePrice2K, ImagePrice4K: snapshot.Group.ImagePrice4K, + SoraImagePrice360: snapshot.Group.SoraImagePrice360, + SoraImagePrice540: snapshot.Group.SoraImagePrice540, + SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest, + SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD, ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly, FallbackGroupID: snapshot.Group.FallbackGroupID, FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest, diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index cb1dd60a..c5e1cfab 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -5,6 +5,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "strconv" + "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -32,6 +34,9 @@ var ( const ( apiKeyMaxErrorsPerHour = 20 + apiKeyLastUsedMinTouch = 30 * time.Second + // DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。 + apiKeyLastUsedFailBackoff = 5 * time.Second ) type APIKeyRepository interface { @@ -58,6 +63,7 @@ type APIKeyRepository interface { // Quota methods IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) + UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error } // APIKeyCache defines cache operations for API key service @@ -125,6 +131,8 @@ type APIKeyService struct { authCacheL1 *ristretto.Cache authCfg apiKeyAuthCacheConfig authGroup singleflight.Group + lastUsedTouchL1 sync.Map // keyID -> 
nextAllowedAt(time.Time) + lastUsedTouchSF singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -527,6 +535,7 @@ func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) erro if err := s.apiKeyRepo.Delete(ctx, id); err != nil { return fmt.Errorf("delete api key: %w", err) } + s.lastUsedTouchL1.Delete(id) return nil } @@ -558,6 +567,38 @@ func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, * return apiKey, user, nil } +// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。 +// 该操作为尽力而为,不应阻塞主请求链路。 +func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error { + if keyID <= 0 { + return nil + } + + now := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) { + return nil + } + } + + _, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) { + latest := time.Now() + if v, ok := s.lastUsedTouchL1.Load(keyID); ok { + if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) { + return nil, nil + } + } + + if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil { + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff)) + return nil, fmt.Errorf("touch api key last used: %w", err) + } + s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch)) + return nil, nil + }) + return err +} + // IncrementUsage 增加API Key使用次数(可选:用于统计) func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error { // 使用Redis计数器 diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 14ecbf39..2357813b 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -103,6 +103,10 @@ func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount panic("unexpected IncrementQuotaUsed call") } +func (s 
*authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + panic("unexpected UpdateLastUsed call") +} + type authCacheStub struct { getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) setAuthKeys []string diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index d4d12144..79757808 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -24,10 +24,13 @@ import ( // - deleteErr: 模拟 Delete 返回的错误 // - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证 type apiKeyRepoStub struct { - apiKey *APIKey // GetKeyAndOwnerID 的返回值 - getByIDErr error // GetKeyAndOwnerID 的错误返回值 - deleteErr error // Delete 的错误返回值 - deletedIDs []int64 // 记录已删除的 API Key ID 列表 + apiKey *APIKey // GetKeyAndOwnerID 的返回值 + getByIDErr error // GetKeyAndOwnerID 的错误返回值 + deleteErr error // Delete 的错误返回值 + deletedIDs []int64 // 记录已删除的 API Key ID 列表 + updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error + touchedIDs []int64 + touchedUsedAts []time.Time } // 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题 @@ -122,6 +125,15 @@ func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amoun panic("unexpected IncrementQuotaUsed call") } +func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + s.touchedIDs = append(s.touchedIDs, id) + s.touchedUsedAts = append(s.touchedUsedAts, usedAt) + if s.updateLastUsed != nil { + return s.updateLastUsed(ctx, id, usedAt) + } + return nil +} + // apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。 // 用于验证删除操作时缓存清理逻辑是否被正确调用。 // @@ -214,12 +226,15 @@ func TestApiKeyService_Delete_Success(t *testing.T) { } cache := &apiKeyCacheStub{} svc := &APIKeyService{apiKeyRepo: repo, cache: cache} + svc.lastUsedTouchL1.Store(int64(42), time.Now()) err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7 
require.NoError(t, err) require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除 require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除 require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys) + _, exists := svc.lastUsedTouchL1.Load(int64(42)) + require.False(t, exists, "delete should clear touch debounce cache") } // TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。 diff --git a/backend/internal/service/api_key_service_touch_last_used_test.go b/backend/internal/service/api_key_service_touch_last_used_test.go new file mode 100644 index 00000000..b49bf9ce --- /dev/null +++ b/backend/internal/service/api_key_service_touch_last_used_test.go @@ -0,0 +1,160 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("should not be called") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 0)) + require.NoError(t, svc.TouchLastUsed(context.Background(), -1)) + require.Empty(t, repo.touchedIDs) +} + +func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + err := svc.TouchLastUsed(context.Background(), 123) + require.NoError(t, err) + require.Equal(t, []int64{123}, repo.touchedIDs) + require.Len(t, repo.touchedUsedAts, 1) + require.False(t, repo.touchedUsedAts[0].IsZero()) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "successful touch should update debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: 
repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository") +} + +func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) { + repo := &apiKeyRepoStub{} + svc := &APIKeyService{apiKeyRepo: repo} + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + + // 强制将 debounce 时间回拨到窗口之外,触发第二次写库。 + svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second)) + + require.NoError(t, svc.TouchLastUsed(context.Background(), 123)) + require.Len(t, repo.touchedIDs, 2) + require.Equal(t, int64(123), repo.touchedIDs[0]) + require.Equal(t, int64(123), repo.touchedIDs[1]) +} + +func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + err := svc.TouchLastUsed(context.Background(), 123) + require.Error(t, err) + require.ErrorContains(t, err, "touch api key last used") + require.Equal(t, []int64{123}, repo.touchedIDs) + + cached, ok := svc.lastUsedTouchL1.Load(int64(123)) + require.True(t, ok, "failed touch should still update retry debounce cache") + _, isTime := cached.(time.Time) + require.True(t, isTime) +} + +func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) { + repo := &apiKeyRepoStub{ + updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error { + return errors.New("db write failed") + }, + } + svc := &APIKeyService{apiKeyRepo: repo} + + firstErr := svc.TouchLastUsed(context.Background(), 456) + require.Error(t, firstErr) + require.ErrorContains(t, firstErr, "touch api key last used") + + secondErr := svc.TouchLastUsed(context.Background(), 456) + require.NoError(t, secondErr, 
"failed touch should be debounced and skip immediate retry") + require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again") +} + +type touchSingleflightRepo struct { + *apiKeyRepoStub + mu sync.Mutex + calls int + blockCh chan struct{} +} + +func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { + r.mu.Lock() + r.calls++ + r.mu.Unlock() + <-r.blockCh + return nil +} + +func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) { + repo := &touchSingleflightRepo{ + apiKeyRepoStub: &apiKeyRepoStub{}, + blockCh: make(chan struct{}), + } + svc := &APIKeyService{apiKeyRepo: repo} + + const workers = 20 + startCh := make(chan struct{}) + errCh := make(chan error, workers) + var wg sync.WaitGroup + + for i := 0; i < workers; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startCh + errCh <- svc.TouchLastUsed(context.Background(), 321) + }() + } + + close(startCh) + + require.Eventually(t, func() bool { + repo.mu.Lock() + defer repo.mu.Unlock() + return repo.calls >= 1 + }, time.Second, 10*time.Millisecond) + + close(repo.blockCh) + wg.Wait() + close(errCh) + + for err := range errCh { + require.NoError(t, err) + } + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次") +} diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fb8aaf9c..73f59dd0 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -7,13 +7,13 @@ import ( "encoding/hex" "errors" "fmt" - "log" "net/mail" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" @@ -118,12 +118,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 验证邀请码 
redeemCode, err := s.redeemRepo.GetByCode(ctx, invitationCode) if err != nil { - log.Printf("[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) + logger.LegacyPrintf("service.auth", "[Auth] Invalid invitation code: %s, error: %v", invitationCode, err) return "", nil, ErrInvitationCodeInvalid } // 检查类型和状态 if redeemCode.Type != RedeemTypeInvitation || redeemCode.Status != StatusUnused { - log.Printf("[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) + logger.LegacyPrintf("service.auth", "[Auth] Invitation code invalid: type=%s, status=%s", redeemCode.Type, redeemCode.Status) return "", nil, ErrInvitationCodeInvalid } invitationRedeemCode = redeemCode @@ -134,7 +134,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 如果邮件验证已开启但邮件服务未配置,拒绝注册 // 这是一个配置错误,不应该允许绕过验证 if s.emailService == nil { - log.Println("[Auth] Email verification enabled but email service not configured, rejecting registration") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verification enabled but email service not configured, rejecting registration") return "", nil, ErrServiceUnavailable } if verifyCode == "" { @@ -149,7 +149,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return "", nil, ErrServiceUnavailable } if existsEmail { @@ -185,7 +185,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw if errors.Is(err, ErrEmailExists) { return "", nil, ErrEmailExists } - log.Printf("[Auth] Database error creating user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } @@ -193,14 +193,14 @@ func (s 
*AuthService) RegisterWithVerification(ctx context.Context, email, passw if invitationRedeemCode != nil { if err := s.redeemRepo.Use(ctx, invitationRedeemCode.ID, user.ID); err != nil { // 邀请码标记失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to mark invitation code as used for user %d: %v", user.ID, err) } } // 应用优惠码(如果提供且功能已启用) if promoCode != "" && s.promoService != nil && s.settingService != nil && s.settingService.IsPromoCodeEnabled(ctx) { if err := s.promoService.ApplyPromoCode(ctx, user.ID, promoCode); err != nil { // 优惠码应用失败不影响注册,只记录日志 - log.Printf("[Auth] Failed to apply promo code for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to apply promo code for user %d: %v", user.ID, err) } else { // 重新获取用户信息以获取更新后的余额 if updatedUser, err := s.userRepo.GetByID(ctx, user.ID); err == nil { @@ -237,7 +237,7 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return ErrServiceUnavailable } if existsEmail { @@ -260,11 +260,11 @@ func (s *AuthService) SendVerifyCode(ctx context.Context, email string) error { // SendVerifyCodeAsync 异步发送邮箱验证码并返回倒计时 func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*SendVerifyCodeResult, error) { - log.Printf("[Auth] SendVerifyCodeAsync called for email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] SendVerifyCodeAsync called for email: %s", email) // 检查是否开放注册(默认关闭) if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { - log.Println("[Auth] Registration is disabled") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Registration is disabled") return nil, 
ErrRegDisabled } @@ -275,17 +275,17 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S // 检查邮箱是否已存在 existsEmail, err := s.userRepo.ExistsByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error checking email exists: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email exists: %v", err) return nil, ErrServiceUnavailable } if existsEmail { - log.Printf("[Auth] Email already exists: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Email already exists: %s", email) return nil, ErrEmailExists } // 检查邮件队列服务是否配置 if s.emailQueueService == nil { - log.Println("[Auth] Email queue service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email queue service not configured") return nil, errors.New("email queue service not configured") } @@ -296,13 +296,13 @@ func (s *AuthService) SendVerifyCodeAsync(ctx context.Context, email string) (*S } // 异步发送 - log.Printf("[Auth] Enqueueing verify code for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Enqueueing verify code for: %s", email) if err := s.emailQueueService.EnqueueVerifyCode(email, siteName); err != nil { - log.Printf("[Auth] Failed to enqueue: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue: %v", err) return nil, fmt.Errorf("enqueue verify code: %w", err) } - log.Printf("[Auth] Verify code enqueued successfully for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Verify code enqueued successfully for: %s", email) return &SendVerifyCodeResult{ Countdown: 60, // 60秒倒计时 }, nil @@ -314,27 +314,27 @@ func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteI if required { if s.settingService == nil { - log.Println("[Auth] Turnstile required but settings service is not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but settings service is not configured") return ErrTurnstileNotConfigured } enabled := 
s.settingService.IsTurnstileEnabled(ctx) secretConfigured := s.settingService.GetTurnstileSecretKey(ctx) != "" if !enabled || !secretConfigured { - log.Printf("[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) + logger.LegacyPrintf("service.auth", "[Auth] Turnstile required but not configured (enabled=%v, secret_configured=%v)", enabled, secretConfigured) return ErrTurnstileNotConfigured } } if s.turnstileService == nil { if required { - log.Println("[Auth] Turnstile required but service not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile required but service not configured") return ErrTurnstileNotConfigured } return nil // 服务未配置则跳过验证 } if !required && s.settingService != nil && s.settingService.IsTurnstileEnabled(ctx) && s.settingService.GetTurnstileSecretKey(ctx) == "" { - log.Println("[Auth] Turnstile enabled but secret key not configured") + logger.LegacyPrintf("service.auth", "%s", "[Auth] Turnstile enabled but secret key not configured") } return s.turnstileService.VerifyToken(ctx, token, remoteIP) @@ -373,7 +373,7 @@ func (s *AuthService) Login(ctx context.Context, email, password string) (string return "", nil, ErrInvalidCredentials } // 记录数据库错误但不暴露给用户 - log.Printf("[Auth] Database error during login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during login: %v", err) return "", nil, ErrServiceUnavailable } @@ -426,7 +426,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return "", nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -457,18 +457,18 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username 
// 并发场景:GetByEmail 与 Create 之间用户被创建。 user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return "", nil, ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return "", nil, ErrServiceUnavailable } } else { user = newUser } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return "", nil, ErrServiceUnavailable } } @@ -481,7 +481,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -523,7 +523,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema randomPassword, err := randomHexString(32) if err != nil { - log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to generate random password for oauth signup: %v", err) return nil, nil, ErrServiceUnavailable } hashedPassword, err := s.HashPassword(randomPassword) @@ -552,18 +552,18 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if errors.Is(err, ErrEmailExists) { user, err = s.userRepo.GetByEmail(ctx, email) if err != nil { - log.Printf("[Auth] Database error getting user after conflict: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user after conflict: %v", err) return nil, nil, 
ErrServiceUnavailable } } else { - log.Printf("[Auth] Database error creating oauth user: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error creating oauth user: %v", err) return nil, nil, ErrServiceUnavailable } } else { user = newUser } } else { - log.Printf("[Auth] Database error during oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) return nil, nil, ErrServiceUnavailable } } @@ -575,7 +575,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema if user.Username == "" && username != "" { user.Username = username if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] Failed to update username after oauth login: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to update username after oauth login: %v", err) } } @@ -715,7 +715,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, oldTokenString string) ( if errors.Is(err, ErrUserNotFound) { return "", ErrInvalidToken } - log.Printf("[Auth] Database error refreshing token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error refreshing token: %v", err) return "", ErrServiceUnavailable } @@ -756,16 +756,16 @@ func (s *AuthService) preparePasswordReset(ctx context.Context, email, frontendB if err != nil { if errors.Is(err, ErrUserNotFound) { // Security: Log but don't reveal that user doesn't exist - log.Printf("[Auth] Password reset requested for non-existent email: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for non-existent email: %s", email) return "", "", false } - log.Printf("[Auth] Database error checking email for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error checking email for password reset: %v", err) return "", "", false } // Check if user is active if !user.IsActive() { - log.Printf("[Auth] Password reset requested for inactive user: %s", email) + 
logger.LegacyPrintf("service.auth", "[Auth] Password reset requested for inactive user: %s", email) return "", "", false } @@ -797,11 +797,11 @@ func (s *AuthService) RequestPasswordReset(ctx context.Context, email, frontendB } if err := s.emailService.SendPasswordResetEmail(ctx, email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to send password reset email to %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to send password reset email to %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email sent to: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email sent to: %s", email) return nil } @@ -821,11 +821,11 @@ func (s *AuthService) RequestPasswordResetAsync(ctx context.Context, email, fron } if err := s.emailQueueService.EnqueuePasswordReset(email, siteName, resetURL); err != nil { - log.Printf("[Auth] Failed to enqueue password reset email for %s: %v", email, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to enqueue password reset email for %s: %v", email, err) return nil // Silent success to prevent enumeration } - log.Printf("[Auth] Password reset email enqueued for: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset email enqueued for: %s", email) return nil } @@ -852,7 +852,7 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo if errors.Is(err, ErrUserNotFound) { return ErrInvalidResetToken // Token was valid but user was deleted } - log.Printf("[Auth] Database error getting user for password reset: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for password reset: %v", err) return ErrServiceUnavailable } @@ -872,17 +872,17 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo user.TokenVersion++ // Invalidate all existing tokens if err := s.userRepo.Update(ctx, user); err != nil { - log.Printf("[Auth] 
Database error updating password for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Database error updating password for user %d: %v", user.ID, err) return ErrServiceUnavailable } // Also revoke all refresh tokens for this user if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { - log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) // Don't return error - password was already changed successfully } - log.Printf("[Auth] Password reset successful for user: %s", email) + logger.LegacyPrintf("service.auth", "[Auth] Password reset successful for user: %s", email) return nil } @@ -961,13 +961,13 @@ func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, fami // 添加到用户Token集合 if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to user set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to user set: %v", err) // 不影响主流程 } // 添加到家族Token集合 if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { - log.Printf("[Auth] Failed to add token to family set: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to add token to family set: %v", err) // 不影响主流程 } @@ -994,10 +994,10 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) if err != nil { if errors.Is(err, ErrRefreshTokenNotFound) { // Token不存在,可能是已被使用(Token轮转)或已过期 - log.Printf("[Auth] Refresh token not found, possible reuse attack") + logger.LegacyPrintf("service.auth", "[Auth] Refresh token not found, possible reuse attack") return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Error getting refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Error getting refresh token: %v", err) return nil, ErrServiceUnavailable } @@ 
-1016,7 +1016,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) return nil, ErrRefreshTokenInvalid } - log.Printf("[Auth] Database error getting user for token refresh: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Database error getting user for token refresh: %v", err) return nil, ErrServiceUnavailable } @@ -1036,7 +1036,7 @@ func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) // Token轮转:立即使旧Token失效 if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { - log.Printf("[Auth] Failed to delete old refresh token: %v", err) + logger.LegacyPrintf("service.auth", "[Auth] Failed to delete old refresh token: %v", err) // 继续处理,不影响主流程 } diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index f1685be5..93659743 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -315,3 +315,69 @@ func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) { require.NotEmpty(t, newToken) }) } + +func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + + require.Equal(t, 24*3600, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + require.Equal(t, 90*60, service.GetAccessTokenExpiresIn()) +} + +func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 0 + 
+ user := &User{ + ID: 1, + Email: "test@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second) +} + +func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { + service := newAuthService(&userRepoStub{}, nil, nil) + service.cfg.JWT.ExpireHour = 24 + service.cfg.JWT.AccessTokenExpireMinutes = 90 + + user := &User{ + ID: 2, + Email: "test2@test.com", + Role: RoleUser, + Status: StatusActive, + TokenVersion: 1, + } + + token, err := service.GenerateToken(user) + require.NoError(t, err) + + claims, err := service.ValidateToken(token) + require.NoError(t, err) + require.NotNil(t, claims) + require.NotNil(t, claims.IssuedAt) + require.NotNil(t, claims.ExpiresAt) + + require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) +} diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index c09cafb9..a560930b 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -3,13 +3,13 @@ package service import ( "context" "fmt" - "log" "sync" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // 错误定义 @@ -156,13 +156,13 @@ func (s *BillingCacheService) cacheWriteWorker() { case cacheWriteUpdateSubscriptionUsage: if s.cache != nil { if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil { - log.Printf("Warning: update subscription cache failed for user 
%d group %d: %v", task.userID, task.groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err) } } case cacheWriteDeductBalance: if s.cache != nil { if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { - log.Printf("Warning: deduct balance cache failed for user %d: %v", task.userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } } @@ -216,7 +216,7 @@ func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason stri if dropped == 0 { return } - log.Printf("Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", + logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)", reason, dropped, cacheWriteDropLogInterval, @@ -274,7 +274,7 @@ func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, return } if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil { - log.Printf("Warning: set balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err) } } @@ -302,7 +302,7 @@ func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.DeductBalanceCache(ctx, userID, amount); err != nil { - log.Printf("Warning: deduct balance cache fallback failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err) } } @@ -312,7 +312,7 @@ func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID return nil } if err := s.cache.InvalidateUserBalance(ctx, userID); err != 
nil { - log.Printf("Warning: invalidate balance cache failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err) return err } return nil @@ -396,7 +396,7 @@ func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, return } if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil { - log.Printf("Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err) } } @@ -425,7 +425,7 @@ func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64 ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) defer cancel() if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil { - log.Printf("Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err) } } @@ -435,7 +435,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID return nil } if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil { - log.Printf("Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) + logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err) return err } return nil @@ -474,7 +474,7 @@ func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userI if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing balance check failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.billing_cache", 
"ALERT: billing balance check failed for user %d: %v", userID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -496,7 +496,7 @@ func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, if s.circuitBreaker != nil { s.circuitBreaker.OnFailure(err) } - log.Printf("ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err) return ErrBillingServiceUnavailable.WithCause(err) } if s.circuitBreaker != nil { @@ -585,7 +585,7 @@ func (b *billingCircuitBreaker) Allow() bool { } b.state = billingCircuitHalfOpen b.halfOpenRemaining = b.halfOpenRequests - log.Printf("ALERT: billing circuit breaker entering half-open state") + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state") fallthrough case billingCircuitHalfOpen: if b.halfOpenRemaining <= 0 { @@ -612,7 +612,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after half-open failure: %v", err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err) return default: b.failures++ @@ -620,7 +620,7 @@ func (b *billingCircuitBreaker) OnFailure(err error) { b.state = billingCircuitOpen b.openedAt = time.Now() b.halfOpenRemaining = 0 - log.Printf("ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err) } } } @@ -641,9 +641,9 @@ func (b *billingCircuitBreaker) OnSuccess() { // 只有状态真正发生变化时才记录日志 if previousState != billingCircuitClosed { - log.Printf("ALERT: billing circuit breaker closed (was %s)", 
circuitStateString(previousState)) + logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState)) } else if previousFailures > 0 { - log.Printf("INFO: billing circuit breaker failures reset from %d", previousFailures) + logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures) } } diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index 6934bc64..f100be0b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -312,7 +312,7 @@ func (s *BillingService) CalculateCostWithLongContext(model string, tokens Usage } outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier) if err != nil { - return inRangeCost, nil // 出错时返回范围内成本 + return inRangeCost, fmt.Errorf("out-range cost: %w", err) } // 合并成本 @@ -388,6 +388,14 @@ type ImagePriceConfig struct { Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值) } +// SoraPriceConfig Sora 按次计费配置 +type SoraPriceConfig struct { + ImagePrice360 *float64 + ImagePrice540 *float64 + VideoPricePerRequest *float64 + VideoPricePerRequestHD *float64 +} + // CalculateImageCost 计算图片生成费用 // model: 请求的模型名称(用于获取 LiteLLM 默认价格) // imageSize: 图片尺寸 "1K", "2K", "4K" @@ -417,6 +425,65 @@ func (s *BillingService) CalculateImageCost(model string, imageSize string, imag } } +// CalculateSoraImageCost 计算 Sora 图片按次费用 +func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + if imageCount <= 0 { + return &CostBreakdown{} + } + + unitPrice := 0.0 + if groupConfig != nil { + switch imageSize { + case "540": + if groupConfig.ImagePrice540 != nil { + unitPrice = *groupConfig.ImagePrice540 + } + default: + if groupConfig.ImagePrice360 != nil { + unitPrice = *groupConfig.ImagePrice360 + } + } + } + + totalCost := 
unitPrice * float64(imageCount) + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + +// CalculateSoraVideoCost 计算 Sora 视频按次费用 +func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown { + unitPrice := 0.0 + if groupConfig != nil { + modelLower := strings.ToLower(model) + if strings.Contains(modelLower, "sora2pro-hd") { + if groupConfig.VideoPricePerRequestHD != nil { + unitPrice = *groupConfig.VideoPricePerRequestHD + } + } + if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil { + unitPrice = *groupConfig.VideoPricePerRequest + } + } + + totalCost := unitPrice + if rateMultiplier <= 0 { + rateMultiplier = 1.0 + } + actualCost := totalCost * rateMultiplier + + return &CostBreakdown{ + TotalCost: totalCost, + ActualCost: actualCost, + } +} + // getImageUnitPrice 获取图片单价 func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 { // 优先使用分组配置的价格 diff --git a/backend/internal/service/billing_service_test.go b/backend/internal/service/billing_service_test.go new file mode 100644 index 00000000..5eb278f6 --- /dev/null +++ b/backend/internal/service/billing_service_test.go @@ -0,0 +1,437 @@ +//go:build unit + +package service + +import ( + "math" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func newTestBillingService() *BillingService { + return NewBillingService(&config.Config{}, nil) +} + +func TestCalculateCost_BasicComputation(t *testing.T) { + svc := newTestBillingService() + + // 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075 
+ expectedInput := 1000 * 3e-6 + expectedOutput := 500 * 15e-6 + require.InDelta(t, expectedInput, cost.InputCost, 1e-10) + require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10) + require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10) +} + +func TestCalculateCost_WithCacheTokens(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreationTokens: 2000, + CacheReadTokens: 3000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expectedCacheCreation := 2000 * 3.75e-6 + expectedCacheRead := 3000 * 0.3e-6 + require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10) + require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10) + + expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead + require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10) +} + +func TestCalculateCost_RateMultiplier(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + + cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0) + require.NoError(t, err) + + // TotalCost 不受倍率影响,ActualCost 翻倍 + require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10) + require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10) +} + +func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10) +} + +func 
TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 1000} + + costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0) + require.NoError(t, err) + + costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10) +} + +func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) { + svc := newTestBillingService() + + tests := []struct { + model string + expectedInput float64 + }{ + {"claude-opus-4.5-20250101", 5e-6}, + {"claude-3-opus-20240229", 15e-6}, + {"claude-sonnet-4-20250514", 3e-6}, + {"claude-3-5-sonnet-20241022", 3e-6}, + {"claude-3-5-haiku-20241022", 1e-6}, + {"claude-3-haiku-20240307", 0.25e-6}, + } + + for _, tt := range tests { + pricing, err := svc.GetModelPricing(tt.model) + require.NoError(t, err, "模型 %s", tt.model) + require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model) + } +} + +func TestGetModelPricing_CaseInsensitive(t *testing.T) { + svc := newTestBillingService() + + p1, err := svc.GetModelPricing("Claude-Sonnet-4") + require.NoError(t, err) + + p2, err := svc.GetModelPricing("claude-sonnet-4") + require.NoError(t, err) + + require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken) +} + +func TestGetModelPricing_UnknownModelFallsBackToSonnet(t *testing.T) { + svc := newTestBillingService() + + // 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格 + pricing, err := svc.GetModelPricing("claude-unknown-model") + require.NoError(t, err) + require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12) +} + +func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 50000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + // 总输入 150k < 200k 阈值,应走正常计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 
1.0, 200000, 2.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 210k + 输入 10k = 220k > 200k 阈值 + // 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入 + tokens := UsageTokens{ + InputTokens: 10000, + OutputTokens: 1000, + CacheReadTokens: 210000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + // 范围内:200k cache + 0 input + 1k output + inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 0, + OutputTokens: 1000, + CacheReadTokens: 200000, + }, 1.0) + + // 范围外:10k cache + 10k input,倍率 2.0 + outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{ + InputTokens: 10000, + CacheReadTokens: 10000, + }, 2.0) + + require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) { + svc := newTestBillingService() + + // 缓存 100k + 输入 150k = 250k > 200k 阈值 + // 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入 + tokens := UsageTokens{ + InputTokens: 150000, + OutputTokens: 1000, + CacheReadTokens: 100000, + } + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0) + require.NoError(t, err) + + require.True(t, cost.ActualCost > 0, "费用应大于 0") + + // 正常费用不含长上下文 + normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用") +} + +func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + + // threshold <= 0 应禁用长上下文计费 + cost1, err := 
svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0) + require.NoError(t, err) + + cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10) +} + +func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{InputTokens: 300000} + + // extraMultiplier <= 1 应禁用长上下文计费 + cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0) + require.NoError(t, err) + + normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateImageCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.134 + cfg := &ImagePriceConfig{Price1K: &price} + cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0) + + require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10) + require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10) +} + +func TestCalculateSoraVideoCost(t *testing.T) { + svc := newTestBillingService() + + price := 0.5 + cfg := &SoraPriceConfig{VideoPricePerRequest: &price} + cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0) + + require.InDelta(t, 0.5, cost.TotalCost, 1e-10) +} + +func TestCalculateSoraVideoCost_HDModel(t *testing.T) { + svc := newTestBillingService() + + hdPrice := 1.0 + normalPrice := 0.5 + cfg := &SoraPriceConfig{ + VideoPricePerRequest: &normalPrice, + VideoPricePerRequestHD: &hdPrice, + } + cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0) + require.InDelta(t, 1.0, cost.TotalCost, 1e-10) +} + +func TestIsModelSupported(t *testing.T) { + svc := newTestBillingService() + + require.True(t, svc.IsModelSupported("claude-sonnet-4")) + require.True(t, svc.IsModelSupported("Claude-Opus-4.5")) + require.True(t, svc.IsModelSupported("claude-3-haiku")) + require.False(t, 
svc.IsModelSupported("gpt-4o")) + require.False(t, svc.IsModelSupported("gemini-pro")) +} + +func TestCalculateCost_ZeroTokens(t *testing.T) { + svc := newTestBillingService() + + cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0) + require.NoError(t, err) + require.Equal(t, 0.0, cost.TotalCost) + require.Equal(t, 0.0, cost.ActualCost) +} + +func TestCalculateCostWithConfig(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 1.5 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500} + cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.5) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) { + cfg := &config.Config{} + cfg.Default.RateMultiplier = 0 + svc := NewBillingService(cfg, nil) + + tokens := UsageTokens{InputTokens: 1000} + cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens) + require.NoError(t, err) + + // 倍率 <=0 时默认 1.0 + expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10) +} + +func TestGetEstimatedCost(t *testing.T) { + svc := newTestBillingService() + + est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500) + require.NoError(t, err) + require.True(t, est > 0) +} + +func TestListSupportedModels(t *testing.T) { + svc := newTestBillingService() + + models := svc.ListSupportedModels() + require.NotEmpty(t, models) + require.GreaterOrEqual(t, len(models), 6) +} + +func TestGetPricingServiceStatus_NilService(t *testing.T) { + svc := newTestBillingService() + + status := svc.GetPricingServiceStatus() + require.NotNil(t, status) + require.Equal(t, "using fallback", status["last_updated"]) +} + +func TestForceUpdatePricing_NilService(t *testing.T) { + svc := newTestBillingService() + + err 
:= svc.ForceUpdatePricing() + require.Error(t, err) + require.Contains(t, err.Error(), "not initialized") +} + +func TestCalculateSoraImageCost(t *testing.T) { + svc := newTestBillingService() + + price360 := 0.05 + price540 := 0.08 + cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540} + + cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0) + require.InDelta(t, 0.10, cost.TotalCost, 1e-10) + + cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0) + require.InDelta(t, 0.08, cost540.TotalCost, 1e-10) + require.InDelta(t, 0.16, cost540.ActualCost, 1e-10) +} + +func TestCalculateSoraImageCost_ZeroCount(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateSoraVideoCost_NilConfig(t *testing.T) { + svc := newTestBillingService() + cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0) + require.Equal(t, 0.0, cost.TotalCost) +} + +func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) { + // 使用空的 fallback prices 让 GetModelPricing 失败 + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: make(map[string]*ModelPricing), + } + + tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0} + _, err := svc.CalculateCostWithLongContext("unknown-model", tokens, 1.0, 200000, 2.0) + require.Error(t, err) + require.Contains(t, err.Error(), "pricing not found") +} + +func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) { + svc := &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{ + "claude-sonnet-4": { + InputPricePerToken: 3e-6, + OutputPricePerToken: 15e-6, + SupportsCacheBreakdown: true, + CacheCreation5mPrice: 4e-6, // per token + CacheCreation1hPrice: 5e-6, // per token + }, + }, + } + + tokens := UsageTokens{ + InputTokens: 1000, + OutputTokens: 500, + CacheCreation5mTokens: 100000, + CacheCreation1hTokens: 50000, + } + cost, err := 
svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6 + expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6 + require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10) +} + +func TestCalculateCost_LargeTokenCount(t *testing.T) { + svc := newTestBillingService() + + tokens := UsageTokens{ + InputTokens: 1_000_000, + OutputTokens: 1_000_000, + } + cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0) + require.NoError(t, err) + + // Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15 + require.InDelta(t, 3.0, cost.InputCost, 1e-6) + require.InDelta(t, 15.0, cost.OutputCost, 1e-6) + require.False(t, math.IsNaN(cost.TotalCost)) + require.False(t, math.IsInf(cost.TotalCost, 0)) +} diff --git a/backend/internal/service/claude_code_detection_test.go b/backend/internal/service/claude_code_detection_test.go new file mode 100644 index 00000000..ff7ad7f4 --- /dev/null +++ b/backend/internal/service/claude_code_detection_test.go @@ -0,0 +1,282 @@ +//go:build unit + +package service + +import ( + "context" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func newTestValidator() *ClaudeCodeValidator { + return NewClaudeCodeValidator() +} + +// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体 +func validClaudeCodeBody() map[string]any { + return map[string]any{ + "model": "claude-sonnet-4-20250514", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc", + }, + } +} + +func TestValidate_ClaudeCLIUserAgent(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + ua string + want bool + }{ + {"标准版本号", 
"claude-cli/1.0.0", true}, + {"多位版本号", "claude-cli/12.34.56", true}, + {"大写开头", "Claude-CLI/1.0.0", true}, + {"非 claude-cli", "curl/7.64.1", false}, + {"空 User-Agent", "", false}, + {"部分匹配", "not-claude-cli/1.0.0", false}, + {"缺少版本号", "claude-cli/", false}, + {"版本格式不对", "claude-cli/1.0", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua) + }) + } +} + +func TestValidate_NonMessagesPath_UAOnly(t *testing.T) { + v := newTestValidator() + + // 非 messages 路径只检查 UA + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + + result := v.Validate(req, nil) + require.True(t, result, "非 messages 路径只需 UA 匹配") +} + +func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("GET", "/v1/models", nil) + req.Header.Set("User-Agent", "curl/7.64.1") + + result := v.Validate(req, nil) + require.False(t, result, "UA 不匹配时应返回 false") +} + +func TestValidate_MessagesPath_FullValid(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, validClaudeCodeBody()) + require.True(t, result, "完整有效请求应通过") +} + +func TestValidate_MessagesPath_MissingHeaders(t *testing.T) { + v := newTestValidator() + body := validClaudeCodeBody() + + tests := []struct { + name string + missingHeader string + }{ + {"缺少 X-App", "X-App"}, + {"缺少 anthropic-beta", "anthropic-beta"}, + {"缺少 anthropic-version", "anthropic-version"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + 
req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + req.Header.Del(tt.missingHeader) + + result := v.Validate(req, body) + require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader) + }) + } +} + +func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + metadata map[string]any + }{ + {"缺少 metadata", nil}, + {"缺少 user_id", map[string]any{"other": "value"}}, + {"空 user_id", map[string]any{"user_id": ""}}, + {"格式错误", map[string]any{"user_id": "invalid-format"}}, + {"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "You are Claude Code, Anthropic's official CLI for Claude.", + }, + }, + } + if tt.metadata != nil { + body["metadata"] = tt.metadata + } + + result := v.Validate(req, body) + require.False(t, result, "metadata.user_id: %v", tt.metadata) + }) + } +} + +func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{ + "type": "text", + "text": "Generate JSON data for testing database migrations.", + }, + }, + "metadata": map[string]any{ + "user_id": "user_" + 
"a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc", + }, + } + + result := v.Validate(req, body) + require.False(t, result, "无关系统提示词应返回 false") +} + +func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + // 不设置 X-App 等头,通过 context 标记为 haiku 探测请求 + ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + req = req.WithContext(ctx) + + // 即使 body 不包含 system prompt,也应通过 + result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1}) + require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证") +} + +func TestSystemPromptSimilarity(t *testing.T) { + v := newTestValidator() + + tests := []struct { + name string + prompt string + want bool + }{ + {"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true}, + {"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true}, + {"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true}, + {"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true}, + {"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true}, + {"无关文本", "Write me a poem about cats", false}, + {"空文本", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body := map[string]any{ + "model": "claude-sonnet-4", + "system": []any{ + map[string]any{"type": "text", "text": tt.prompt}, + }, + } + result := v.IncludesClaudeCodeSystemPrompt(body) + require.Equal(t, tt.want, result, "提示词: %q", tt.prompt) + }) + } +} + +func TestDiceCoefficient(t *testing.T) { + tests := []struct { + name string + a string + b string + want float64 + tol float64 + }{ + {"相同字符串", 
"hello", "hello", 1.0, 0.001}, + {"完全不同", "abc", "xyz", 0.0, 0.001}, + {"空字符串", "", "hello", 0.0, 0.001}, + {"单字符", "a", "b", 0.0, 0.001}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := diceCoefficient(tt.a, tt.b) + require.InDelta(t, tt.want, result, tt.tol) + }) + } +} + +func TestIsClaudeCodeClient_Context(t *testing.T) { + ctx := context.Background() + + // 默认应为 false + require.False(t, IsClaudeCodeClient(ctx)) + + // 设置为 true + ctx = SetClaudeCodeClient(ctx, true) + require.True(t, IsClaudeCodeClient(ctx)) + + // 设置为 false + ctx = SetClaudeCodeClient(ctx, false) + require.False(t, IsClaudeCodeClient(ctx)) +} + +func TestValidate_NilBody_MessagesPath(t *testing.T) { + v := newTestValidator() + + req := httptest.NewRequest("POST", "/v1/messages", nil) + req.Header.Set("User-Agent", "claude-cli/1.0.0") + req.Header.Set("X-App", "claude-code") + req.Header.Set("anthropic-beta", "beta") + req.Header.Set("anthropic-version", "2023-06-01") + + result := v.Validate(req, nil) + require.False(t, result, "nil body 的 messages 请求应返回 false") +} diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index d5cb2025..32b6d97c 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -5,8 +5,9 @@ import ( "crypto/rand" "encoding/hex" "fmt" - "log" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ConcurrencyCache 定义并发控制的缓存接口 @@ -124,7 +125,7 @@ func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID i bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil { - log.Printf("Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, 
requestID, err) } }, }, nil @@ -163,7 +164,7 @@ func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil { - log.Printf("Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) + logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err) } }, }, nil @@ -191,7 +192,7 @@ func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int6 result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait) if err != nil { // On error, allow the request to proceed (fail open) - log.Printf("Warning: increment wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err) return true, nil } return result, nil @@ -209,7 +210,7 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6 defer cancel() if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil { - log.Printf("Warning: decrement wait count failed for user %d: %v", userID, err) + logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err) } } @@ -221,7 +222,7 @@ func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, acco result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait) if err != nil { - log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err) return true, nil } return result, nil @@ -237,7 +238,7 @@ func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, acco defer cancel() if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err 
!= nil { - log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err) + logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err) } } @@ -293,7 +294,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor accounts, err := accountRepo.ListSchedulable(listCtx) cancel() if err != nil { - log.Printf("Warning: list schedulable accounts failed: %v", err) + logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err) return } for _, account := range accounts { @@ -301,7 +302,7 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID) accountCancel() if err != nil { - log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err) } } } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go new file mode 100644 index 00000000..33ce4cb9 --- /dev/null +++ b/backend/internal/service/concurrency_service_test.go @@ -0,0 +1,280 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 +type stubConcurrencyCacheForTest struct { + acquireResult bool + acquireErr error + releaseErr error + concurrency int + concurrencyErr error + waitAllowed bool + waitErr error + waitCount int + waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error + usersLoadBatch map[int64]*UserLoadInfo + usersLoadErr error + cleanupErr error + + // 记录调用 + releasedAccountIDs []int64 + releasedRequestIDs []string +} + +var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil) + +func (c *stubConcurrencyCacheForTest) 
AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error { + c.releasedAccountIDs = append(c.releasedAccountIDs, accountID) + c.releasedRequestIDs = append(c.releasedRequestIDs, requestID) + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) { + return c.waitCount, c.waitCountErr +} +func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) { + return c.acquireResult, c.acquireErr +} +func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error { + return c.releaseErr +} +func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) { + return c.concurrency, c.concurrencyErr +} +func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) { + return c.waitAllowed, c.waitErr +} +func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error { + return nil +} +func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) { + return c.loadBatch, c.loadBatchErr +} +func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { + return c.usersLoadBatch, 
c.usersLoadErr +} +func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { + return c.cleanupErr +} + +func TestAcquireAccountSlot_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_Failure(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: false} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.NoError(t, err) + require.False(t, result.Acquired) + require.Nil(t, result.ReleaseFunc) +} + +func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + for _, maxConcurrency := range []int{0, -1} { + result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency) + require.NoError(t, err) + require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency) + require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数") + } +} + +func TestAcquireAccountSlot_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 1, 5) + require.Error(t, err) + require.Nil(t, result) +} + +func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + result, err := svc.AcquireAccountSlot(context.Background(), 42, 5) + require.NoError(t, err) + require.True(t, result.Acquired) + + // 调用 ReleaseFunc 应释放槽位 + result.ReleaseFunc() + + require.Len(t, cache.releasedAccountIDs, 1) + require.Equal(t, int64(42), cache.releasedAccountIDs[0]) + 
require.Len(t, cache.releasedRequestIDs, 1) + require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空") +} + +func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{acquireResult: true} + svc := NewConcurrencyService(cache) + + // 用户槽位获取应独立于账户槽位 + result, err := svc.AcquireUserSlot(context.Background(), 100, 3) + require.NoError(t, err) + require.True(t, result.Acquired) + require.NotNil(t, result.ReleaseFunc) +} + +func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { + svc := NewConcurrencyService(&stubConcurrencyCacheForTest{}) + + result, err := svc.AcquireUserSlot(context.Background(), 1, 0) + require.NoError(t, err) + require.True(t, result.Acquired) +} + +func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { + expected := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, + 2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100}, + } + cache := &stubConcurrencyCacheForTest{loadBatch: expected} + svc := NewConcurrencyService(cache) + + accounts := []AccountWithConcurrency{ + {ID: 1, MaxConcurrency: 5}, + {ID: 2, MaxConcurrency: 5}, + } + result, err := svc.GetAccountsLoadBatch(context.Background(), accounts) + require.NoError(t, err) + require.Equal(t, expected, result) +} + +func TestGetAccountsLoadBatch_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + result, err := svc.GetAccountsLoadBatch(context.Background(), nil) + require.NoError(t, err) + require.Empty(t, result) +} + +func TestIncrementWaitCount_Success(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) +} + +func TestIncrementWaitCount_QueueFull(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := 
NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed) +} + +func TestIncrementWaitCount_FailOpen(t *testing.T) { + // Redis 错误时应 fail-open(允许请求通过) + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed, "nil cache 应 fail-open") +} + +func TestCalculateMaxWait(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {10, 30}, // 10 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} + +func TestGetAccountWaitingCount(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitCount: 5} + svc := NewConcurrencyService(cache) + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 5, count) +} + +func TestGetAccountWaitingCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + count, err := svc.GetAccountWaitingCount(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, 0, count) +} + +func TestGetAccountConcurrencyBatch(t *testing.T) { + cache := &stubConcurrencyCacheForTest{concurrency: 3} + svc := NewConcurrencyService(cache) + + result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3}) + require.NoError(t, err) + require.Len(t, result, 3) + for _, id := range []int64{1, 2, 
3} { + require.Equal(t, 3, result[id]) + } +} + +func TestIncrementAccountWaitCount_FailOpen(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")} + svc := NewConcurrencyService(cache) + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应传播") + require.True(t, allowed, "Redis 错误时应 fail-open") +} + +func TestIncrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.True(t, allowed) +} diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index 10c68868..a67f8532 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ b/backend/internal/service/dashboard_aggregation_service.go @@ -3,11 +3,12 @@ package service import ( "context" "errors" - "log" + "log/slog" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -65,7 +66,7 @@ func (s *DashboardAggregationService) Start() { return } if !s.cfg.Enabled { - log.Printf("[DashboardAggregation] 聚合作业已禁用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用") return } @@ -81,9 +82,9 @@ func (s *DashboardAggregationService) Start() { s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() { s.runScheduledAggregation() }) - log.Printf("[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds) if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 
回填已禁用,如需补齐保留窗口以外历史数据请手动回填") } } @@ -93,7 +94,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro return errors.New("聚合服务未初始化") } if !s.cfg.BackfillEnabled { - log.Printf("[DashboardAggregation] 回填被拒绝: backfill_enabled=false") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false") return ErrDashboardBackfillDisabled } if !end.After(start) { @@ -110,7 +111,7 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, end); err != nil { - log.Printf("[DashboardAggregation] 回填失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err) } }() return nil @@ -141,12 +142,12 @@ func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time return } if !errors.Is(err, errDashboardAggregationRunning) { - log.Printf("[DashboardAggregation] 重新计算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err) return } time.Sleep(5 * time.Second) } - log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") }() return nil } @@ -162,7 +163,7 @@ func (s *DashboardAggregationService) recomputeRecentDays() { ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) defer cancel() if err := s.backfillRange(ctx, start, now); err != nil { - log.Printf("[DashboardAggregation] 启动重算失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err) return } } @@ -177,7 +178,7 @@ func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, if err := s.repo.RecomputeRange(ctx, start, end); err != nil { return err } - 
log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", start.UTC().Format(time.RFC3339), end.UTC().Format(time.RFC3339), time.Since(jobStart).String(), @@ -198,7 +199,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { now := time.Now().UTC() last, err := s.repo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[DashboardAggregation] 读取水位失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err) last = time.Unix(0, 0).UTC() } @@ -216,19 +217,19 @@ func (s *DashboardAggregationService) runScheduledAggregation() { } if err := s.aggregateRange(ctx, start, now); err != nil { - log.Printf("[DashboardAggregation] 聚合失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err) return } updateErr := s.repo.UpdateAggregationWatermark(ctx, now) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - log.Printf("[DashboardAggregation] 聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", - start.Format(time.RFC3339), - now.Format(time.RFC3339), - time.Since(jobStart).String(), - updateErr == nil, + slog.Debug("[DashboardAggregation] 聚合完成", + "start", start.Format(time.RFC3339), + "end", now.Format(time.RFC3339), + "duration", time.Since(jobStart).String(), + "watermark_updated", updateErr == nil, ) s.maybeCleanupRetention(ctx, now) @@ -261,9 +262,9 @@ func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC) if updateErr != nil { - log.Printf("[DashboardAggregation] 更新水位失败: %v", updateErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr) } - 
log.Printf("[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)", startUTC.Format(time.RFC3339), endUTC.Format(time.RFC3339), time.Since(jobStart).String(), @@ -279,7 +280,7 @@ func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, return nil } if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil { - log.Printf("[DashboardAggregation] 分区检查失败: %v", err) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err) } return s.repo.AggregateRange(ctx, start, end) } @@ -298,11 +299,11 @@ func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff) if aggErr != nil { - log.Printf("[DashboardAggregation] 聚合保留清理失败: %v", aggErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr) } usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff) if usageErr != nil { - log.Printf("[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) + logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr) } if aggErr == nil && usageErr == nil { s.lastRetentionCleanup.Store(now) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index cd11923e..9aab10d2 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -5,11 +5,11 @@ import ( "encoding/json" "errors" "fmt" - "log" "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" ) @@ -113,7 +113,7 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return cached, nil } 
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) { - log.Printf("[Dashboard] 仪表盘缓存读取失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err) } } @@ -188,7 +188,7 @@ func (s *DashboardService) refreshDashboardStatsAsync() { stats, err := s.fetchDashboardStats(ctx) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存异步刷新失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err) return } s.applyAggregationStatus(ctx, stats) @@ -220,12 +220,12 @@ func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *u } data, err := json.Marshal(entry) if err != nil { - log.Printf("[Dashboard] 仪表盘缓存序列化失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err) return } if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil { - log.Printf("[Dashboard] 仪表盘缓存写入失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err) } } @@ -237,10 +237,10 @@ func (s *DashboardService) evictDashboardStatsCache(reason error) { defer cancel() if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil { - log.Printf("[Dashboard] 仪表盘缓存清理失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err) } if reason != nil { - log.Printf("[Dashboard] 仪表盘缓存异常,已清理: %v", reason) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason) } } @@ -271,7 +271,7 @@ func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.T } updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx) if err != nil { - log.Printf("[Dashboard] 读取聚合水位失败: %v", err) + logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err) return time.Unix(0, 0).UTC() } if updatedAt.IsZero() { @@ -319,16 +319,16 @@ func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, end return trend, nil } -func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, 
userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error) { - stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs) +func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) { + stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch user usage stats: %w", err) } return stats, nil } -func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { - stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs) +func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) { + stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime) if err != nil { return nil, fmt.Errorf("get batch api key usage stats: %w", err) } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 0295c23b..ceae443f 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -24,6 +24,7 @@ const ( PlatformOpenAI = domain.PlatformOpenAI PlatformGemini = domain.PlatformGemini PlatformAntigravity = domain.PlatformAntigravity + PlatformSora = domain.PlatformSora ) // Account type constants @@ -160,6 +161,9 @@ const ( // SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation). SettingKeyOpsAdvancedSettings = "ops_advanced_settings" + // SettingKeyOpsRuntimeLogConfig stores JSON config for runtime log settings. 
+ SettingKeyOpsRuntimeLogConfig = "ops_runtime_log_config" + // ========================= // Stream Timeout Handling // ========================= diff --git a/backend/internal/service/email_queue_service.go b/backend/internal/service/email_queue_service.go index 6c975c69..d8f0a518 100644 --- a/backend/internal/service/email_queue_service.go +++ b/backend/internal/service/email_queue_service.go @@ -3,9 +3,10 @@ package service import ( "context" "fmt" - "log" "sync" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // Task type constants @@ -56,7 +57,7 @@ func (s *EmailQueueService) start() { s.wg.Add(1) go s.worker(i) } - log.Printf("[EmailQueue] Started %d workers", s.workers) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Started %d workers", s.workers) } // worker 工作协程 @@ -68,7 +69,7 @@ func (s *EmailQueueService) worker(id int) { case task := <-s.taskChan: s.processTask(id, task) case <-s.stopChan: - log.Printf("[EmailQueue] Worker %d stopping", id) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d stopping", id) return } } @@ -82,18 +83,18 @@ func (s *EmailQueueService) processTask(workerID int, task EmailTask) { switch task.TaskType { case TaskTypeVerifyCode: if err := s.emailService.SendVerifyCode(ctx, task.Email, task.SiteName); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send verify code to %s: %v", workerID, task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent verify code to %s", workerID, task.Email) } case TaskTypePasswordReset: if err := s.emailService.SendPasswordResetEmailWithCooldown(ctx, task.Email, task.SiteName, task.ResetURL); err != nil { - log.Printf("[EmailQueue] Worker %d failed to send password reset to %s: %v", 
workerID, task.Email, err) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d failed to send password reset to %s: %v", workerID, task.Email, err) } else { - log.Printf("[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d sent password reset to %s", workerID, task.Email) } default: - log.Printf("[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Worker %d unknown task type: %s", workerID, task.TaskType) } } @@ -107,7 +108,7 @@ func (s *EmailQueueService) EnqueueVerifyCode(email, siteName string) error { select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued verify code task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued verify code task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -125,7 +126,7 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin select { case s.taskChan <- task: - log.Printf("[EmailQueue] Enqueued password reset task for %s", email) + logger.LegacyPrintf("service.email_queue", "[EmailQueue] Enqueued password reset task for %s", email) return nil default: return fmt.Errorf("email queue is full") @@ -136,5 +137,5 @@ func (s *EmailQueueService) EnqueuePasswordReset(email, siteName, resetURL strin func (s *EmailQueueService) Stop() { close(s.stopChan) s.wg.Wait() - log.Println("[EmailQueue] All workers stopped") + logger.LegacyPrintf("service.email_queue", "%s", "[EmailQueue] All workers stopped") } diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 0a45e57a..7032d15b 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -76,7 +76,7 @@ func 
TestOpenAIHandleErrorResponse_NoRuleKeepsDefault(t *testing.T) { } account := &Account{ID: 12, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusBadGateway, rec.Code) @@ -157,7 +157,7 @@ func TestOpenAIHandleErrorResponse_AppliesRuleFor422(t *testing.T) { } account := &Account{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - _, err := svc.handleErrorResponse(context.Background(), resp, c, account) + _, err := svc.handleErrorResponse(context.Background(), resp, c, account, nil) require.Error(t, err) assert.Equal(t, http.StatusTeapot, rec.Code) diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go index da8c9ccf..26fdf9a7 100644 --- a/backend/internal/service/error_passthrough_service.go +++ b/backend/internal/service/error_passthrough_service.go @@ -2,13 +2,13 @@ package service import ( "context" - "log" "sort" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // ErrorPassthroughRepository 定义错误透传规则的数据访问接口 @@ -72,9 +72,9 @@ func NewErrorPassthroughService( // 启动时加载规则到本地缓存 ctx := context.Background() if err := svc.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from DB on startup: %v", err) if fallbackErr := svc.refreshLocalCache(ctx); fallbackErr != nil { - log.Printf("[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to load rules from cache fallback on startup: %v", fallbackErr) } } @@ -82,7 +82,7 @@ func 
NewErrorPassthroughService( if cache != nil { cache.SubscribeUpdates(ctx, func() { if err := svc.refreshLocalCache(context.Background()); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) } }) } @@ -192,7 +192,7 @@ func (s *ErrorPassthroughService) getCachedRules() []*cachedPassthroughRule { // 如果本地缓存为空,尝试刷新 ctx := context.Background() if err := s.refreshLocalCache(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh cache: %v", err) return nil } @@ -225,7 +225,7 @@ func (s *ErrorPassthroughService) reloadRulesFromDB(ctx context.Context) error { // 更新 Redis 缓存 if s.cache != nil { if err := s.cache.Set(ctx, rules); err != nil { - log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to set cache: %v", err) } } @@ -288,13 +288,13 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 先失效缓存,避免后续刷新读到陈旧规则。 if s.cache != nil { if err := s.cache.Invalidate(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to invalidate cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to invalidate cache: %v", err) } } // 刷新本地缓存 if err := s.reloadRulesFromDB(ctx); err != nil { - log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to refresh local cache: %v", err) // 刷新失败时清空本地缓存,避免继续使用陈旧规则。 s.clearLocalCache() } @@ -302,7 +302,7 @@ func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { // 通知其他实例 if s.cache != nil { if err := s.cache.NotifyUpdate(ctx); 
err != nil { - log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + logger.LegacyPrintf("service.error_passthrough", "[ErrorPassthroughService] Failed to notify cache update: %v", err) } } } diff --git a/backend/internal/service/gateway_account_selection_test.go b/backend/internal/service/gateway_account_selection_test.go new file mode 100644 index 00000000..0a82fade --- /dev/null +++ b/backend/internal/service/gateway_account_selection_test.go @@ -0,0 +1,206 @@ +//go:build unit + +package service + +import ( + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +func testTimePtr(t time.Time) *time.Time { return &t } + +func makeAccWithLoad(id int64, priority int, loadRate int, lastUsed *time.Time, accType string) accountWithLoad { + return accountWithLoad{ + account: &Account{ + ID: id, + Priority: priority, + LastUsedAt: lastUsed, + Type: accType, + Schedulable: true, + Status: StatusActive, + }, + loadInfo: &AccountLoadInfo{ + AccountID: id, + CurrentConcurrency: 0, + LoadRate: loadRate, + }, + } +} + +// --- sortAccountsByPriorityAndLastUsed --- + +func TestSortAccountsByPriorityAndLastUsed_ByPriority(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 5, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 3, LastUsedAt: testTimePtr(now)}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, int64(2), accounts[0].ID, "优先级最低的排第一") + require.Equal(t, int64(3), accounts[1].ID) + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_SamePriorityByLastUsed(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 3, Priority: 1, LastUsedAt: nil}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + require.Equal(t, 
int64(3), accounts[0].ID, "nil LastUsedAt 排最前") + require.Equal(t, int64(2), accounts[1].ID, "更早使用的排前面") + require.Equal(t, int64(1), accounts[2].ID) +} + +func TestSortAccountsByPriorityAndLastUsed_PreferOAuth(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeOAuth}, + } + sortAccountsByPriorityAndLastUsed(accounts, true) + require.Equal(t, int64(2), accounts[0].ID, "preferOAuth 时 OAuth 账号排前面") +} + +func TestSortAccountsByPriorityAndLastUsed_StableSort(t *testing.T) { + accounts := []*Account{ + {ID: 1, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 2, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + {ID: 3, Priority: 1, LastUsedAt: nil, Type: AccountTypeAPIKey}, + } + + // sortAccountsByPriorityAndLastUsed 内部会在同组(Priority+LastUsedAt)内做随机打散, + // 因此这里不再断言“稳定排序”。我们只验证: + // 1) 元素集合不变;2) 多次运行能产生不同的顺序。 + seenFirst := map[int64]bool{} + for i := 0; i < 100; i++ { + cpy := make([]*Account, len(accounts)) + copy(cpy, accounts) + sortAccountsByPriorityAndLastUsed(cpy, false) + seenFirst[cpy[0].ID] = true + + ids := map[int64]bool{} + for _, a := range cpy { + ids[a.ID] = true + } + require.True(t, ids[1] && ids[2] && ids[3]) + } + require.GreaterOrEqual(t, len(seenFirst), 2, "同组账号应能被随机打散") +} + +func TestSortAccountsByPriorityAndLastUsed_MixedPriorityAndTime(t *testing.T) { + now := time.Now() + accounts := []*Account{ + {ID: 1, Priority: 2, LastUsedAt: nil}, + {ID: 2, Priority: 1, LastUsedAt: testTimePtr(now)}, + {ID: 3, Priority: 1, LastUsedAt: testTimePtr(now.Add(-1 * time.Hour))}, + {ID: 4, Priority: 2, LastUsedAt: testTimePtr(now.Add(-2 * time.Hour))}, + } + sortAccountsByPriorityAndLastUsed(accounts, false) + // 优先级1排前:nil < earlier + require.Equal(t, int64(3), accounts[0].ID, "优先级1 + 更早") + require.Equal(t, int64(2), accounts[1].ID, "优先级1 + 现在") + // 优先级2排后:nil < time + require.Equal(t, int64(1), accounts[2].ID, "优先级2 + nil") 
+ require.Equal(t, int64(4), accounts[3].ID, "优先级2 + 有时间") +} + +// --- filterByMinPriority --- + +func TestFilterByMinPriority_Empty(t *testing.T) { + result := filterByMinPriority(nil) + require.Nil(t, result) +} + +func TestFilterByMinPriority_SelectsMinPriority(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 5, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 20, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 2, 10, nil, AccountTypeAPIKey), + } + result := filterByMinPriority(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- filterByMinLoadRate --- + +func TestFilterByMinLoadRate_Empty(t *testing.T) { + result := filterByMinLoadRate(nil) + require.Nil(t, result) +} + +func TestFilterByMinLoadRate_SelectsMinLoadRate(t *testing.T) { + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 30, nil, AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(4, 1, 20, nil, AccountTypeAPIKey), + } + result := filterByMinLoadRate(accounts) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(3), result[1].account.ID) +} + +// --- selectByLRU --- + +func TestSelectByLRU_Empty(t *testing.T) { + result := selectByLRU(nil, false) + require.Nil(t, result) +} + +func TestSelectByLRU_Single(t *testing.T) { + accounts := []accountWithLoad{makeAccWithLoad(1, 1, 10, nil, AccountTypeAPIKey)} + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(1), result.account.ID) +} + +func TestSelectByLRU_NilLastUsedAtWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, nil, AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, 
testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(2), result.account.ID) +} + +func TestSelectByLRU_EarliestTimeWins(t *testing.T) { + now := time.Now() + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now.Add(-1*time.Hour)), AccountTypeAPIKey), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(-2*time.Hour)), AccountTypeAPIKey), + } + result := selectByLRU(accounts, false) + require.NotNil(t, result) + require.Equal(t, int64(3), result.account.ID) +} + +func TestSelectByLRU_TiePreferOAuth(t *testing.T) { + now := time.Now() + // 账号 1/2 LastUsedAt 相同,且同为最小值。 + accounts := []accountWithLoad{ + makeAccWithLoad(1, 1, 10, testTimePtr(now), AccountTypeAPIKey), + makeAccWithLoad(2, 1, 10, testTimePtr(now), AccountTypeOAuth), + makeAccWithLoad(3, 1, 10, testTimePtr(now.Add(1*time.Hour)), AccountTypeAPIKey), + } + for i := 0; i < 50; i++ { + result := selectByLRU(accounts, true) + require.NotNil(t, result) + require.Equal(t, AccountTypeOAuth, result.account.Type) + require.Equal(t, int64(2), result.account.ID) + } +} diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go new file mode 100644 index 00000000..37fd709f --- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_benchmark_test.go @@ -0,0 +1,56 @@ +package service + +import "testing" + +func BenchmarkGatewayService_ParseSSEUsage_MessageStart(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + 
usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageStart(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":123,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsage_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsage(data, usage) + } +} + +func BenchmarkGatewayService_ParseSSEUsagePassthrough_MessageDelta(b *testing.B) { + svc := &GatewayService{} + data := `{"type":"message_delta","usage":{"output_tokens":456,"cache_creation_input_tokens":30,"cache_read_input_tokens":7,"cached_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":10,"ephemeral_1h_input_tokens":20}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage := &ClaudeUsage{} + svc.parseSSEUsagePassthrough(data, usage) + } +} + +func BenchmarkParseClaudeUsageFromResponseBody(b *testing.B) { + body := []byte(`{"id":"msg_123","type":"message","usage":{"input_tokens":123,"output_tokens":456,"cache_creation_input_tokens":45,"cache_read_input_tokens":6,"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":20,"ephemeral_1h_input_tokens":25}}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = parseClaudeUsageFromResponseBody(body) + } +} diff --git 
a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go new file mode 100644 index 00000000..5183891b --- /dev/null +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -0,0 +1,773 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/claude" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +type anthropicHTTPUpstreamRecorder struct { + lastReq *http.Request + lastBody []byte + resp *http.Response + err error +} + +func newAnthropicAPIKeyAccountForTest() *Account { + return &Account{ + ID: 201, + Name: "anthropic-apikey-pass-test", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } +} + +func (u *anthropicHTTPUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *anthropicHTTPUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +type streamReadCloser struct { + payload []byte + sent bool + err error +} + +func (r *streamReadCloser) Read(p []byte) (int, error) { + if !r.sent { + 
r.sent = true + n := copy(p, r.payload) + return n, nil + } + if r.err != nil { + return 0, r.err + } + return 0, io.EOF +} + +func (r *streamReadCloser) Close() error { return nil } + +type failWriteResponseWriter struct { + gin.ResponseWriter +} + +func (w *failWriteResponseWriter) Write(data []byte) (int, error) { + return 0, errors.New("client disconnected") +} + +func (w *failWriteResponseWriter) WriteString(_ string) (int, error) { + return 0, errors.New("client disconnected") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardStreamPreservesBodyAndAuthReplacement(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Request.Header.Set("User-Agent", "claude-cli/1.0.0") + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("X-Goog-Api-Key", "inbound-goog-key") + c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("Anthropic-Beta", "interleaved-thinking-2025-05-14") + + body := []byte(`{"model":"claude-3-7-sonnet-20250219","stream":true,"system":[{"type":"text","text":"x-anthropic-billing-header keep"}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-7-sonnet-20250219", + Stream: true, + } + + upstreamSSE := strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":9,"cached_tokens":7}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":3}}`, + "", + "data: [DONE]", + "", + }, "\n") + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "x-request-id": []string{"rid-anthropic-pass"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: 
io.NopCloser(strings.NewReader(upstreamSSE)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + deferredService: &DeferredService{}, + billingCacheService: nil, + } + + account := &Account{ + ID: 101, + Name: "anthropic-apikey-pass", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-7-sonnet-20250219": "claude-3-haiku-20240307"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + result, err := svc.Forward(context.Background(), c, account, parsed) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + require.Equal(t, body, upstream.lastBody, "透传模式不应改写上游请求体") + require.Equal(t, "claude-3-7-sonnet-20250219", gjson.GetBytes(upstream.lastBody, "model").String()) + + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("x-goog-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, "2023-06-01", upstream.lastReq.Header.Get("anthropic-version")) + require.Equal(t, "interleaved-thinking-2025-05-14", upstream.lastReq.Header.Get("anthropic-beta")) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-lang"), "API Key 透传不应注入 OAuth 指纹头") + + require.Contains(t, rec.Body.String(), `"cached_tokens":7`) + require.NotContains(t, rec.Body.String(), `"cache_read_input_tokens":7`, "透传输出不应被网关改写") + require.Equal(t, 7, result.Usage.CacheReadInputTokens, "计费 usage 解析应保留 cached_tokens 兼容") + require.Empty(t, rec.Header().Get("Set-Cookie"), "响应头应经过安全过滤") + rawBody, ok := 
c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + bodyBytes, ok := rawBody.([]byte) + require.True(t, ok, "应以 []byte 形式缓存上游请求体,避免重复 string 拷贝") + require.Equal(t, body, bodyBytes) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardCountTokensPreservesBody(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages/count_tokens", nil) + c.Request.Header.Set("Authorization", "Bearer inbound-token") + c.Request.Header.Set("X-Api-Key", "inbound-api-key") + c.Request.Header.Set("Cookie", "secret=1") + + body := []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}],"thinking":{"type":"enabled"}}`) + parsed := &ParsedRequest{ + Body: body, + Model: "claude-3-5-sonnet-latest", + } + + upstreamRespBody := `{"input_tokens":42}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-count"}, + "Set-Cookie": []string{"secret=upstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamRespBody)), + }, + } + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + account := &Account{ + ID: 102, + Name: "anthropic-apikey-pass-count", + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "upstream-anthropic-key", + "base_url": "https://api.anthropic.com", + "model_mapping": map[string]any{"claude-3-5-sonnet-latest": "claude-3-opus-20240229"}, + }, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + Status: StatusActive, + Schedulable: true, + } + + err := svc.ForwardCountTokens(context.Background(), c, account, 
parsed) + require.NoError(t, err) + + require.Equal(t, body, upstream.lastBody, "count_tokens 透传模式不应改写请求体") + require.Equal(t, "claude-3-5-sonnet-latest", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "upstream-anthropic-key", upstream.lastReq.Header.Get("x-api-key")) + require.Empty(t, upstream.lastReq.Header.Get("authorization")) + require.Empty(t, upstream.lastReq.Header.Get("cookie")) + require.Equal(t, http.StatusOK, rec.Code) + require.JSONEq(t, upstreamRespBody, rec.Body.String()) + require.Empty(t, rec.Header().Get("Set-Cookie")) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_BuildRequestRejectsInvalidBaseURL(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{ + Enabled: false, + }, + }, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "k", + "base_url": "://invalid-url", + }, + } + + _, err := svc.buildUpstreamRequestAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "k") + require.Error(t, err) +} + +func TestGatewayService_AnthropicOAuth_NotAffectedByAPIKeyPassthroughToggle(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{MaxLineSize: defaultMaxLineSize}, + }, + } + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "anthropic_passthrough": true, + }, + } + + require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled()) + + req, err := svc.buildUpstreamRequest(context.Background(), c, 
account, []byte(`{"model":"claude-3-7-sonnet-20250219"}`), "oauth-token", "oauth", "claude-3-7-sonnet-20250219", true, false) + require.NoError(t, err) + require.Equal(t, "Bearer oauth-token", req.Header.Get("authorization")) + require.Contains(t, req.Header.Get("anthropic-beta"), claude.BetaOAuth, "OAuth 链路仍应按原逻辑补齐 oauth beta") +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingStillCollectsUsageAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Use a canceled context recorder to simulate client disconnect behavior. + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(strings.Join([]string{ + `data: {"type":"message_start","message":{"usage":{"input_tokens":11}}}`, + "", + `data: {"type":"message_delta","usage":{"output_tokens":5}}`, + "", + "data: [DONE]", + "", + }, "\n"))), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 11, result.usage.InputTokens) + require.Equal(t, 5, result.usage.OutputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_NonStreamingSuccess(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + body 
:= []byte(`{"model":"claude-3-5-sonnet-latest","messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`) + upstreamJSON := `{"id":"msg_1","type":"message","usage":{"input_tokens":12,"output_tokens":7,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":3},"cached_tokens":4}}` + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-nonstream"}, + }, + Body: io.NopCloser(strings.NewReader(upstreamJSON)), + }, + } + svc := &GatewayService{ + cfg: &config.Config{}, + httpUpstream: upstream, + rateLimitService: &RateLimitService{}, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), body, "claude-3-5-sonnet-latest", false, time.Now()) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 5, result.Usage.CacheCreationInputTokens) + require.Equal(t, 4, result.Usage.CacheReadInputTokens) + require.Equal(t, upstreamJSON, rec.Body.String()) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_InvalidTokenType(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + account := &Account{ + ID: 202, + Name: "anthropic-oauth", + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + } + svc := &GatewayService{} + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{}`), "claude-3-5-sonnet-latest", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "requires apikey token") +} + +func 
TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_UpstreamRequestError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + err: errors.New("dial tcp timeout"), + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + account := newAnthropicAPIKeyAccountForTest() + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, account, []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "upstream request failed") + require.Equal(t, http.StatusBadGateway, rec.Code) + rawBody, ok := c.Get(OpsUpstreamRequestBodyKey) + require.True(t, ok) + _, ok = rawBody.([]byte) + require.True(t, ok) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_ForwardDirect_EmptyResponseBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + upstream := &anthropicHTTPUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"x-request-id": []string{"rid-empty-body"}}, + Body: nil, + }, + } + svc := &GatewayService{ + cfg: &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + }, + httpUpstream: upstream, + } + + result, err := svc.forwardAnthropicAPIKeyPassthrough(context.Background(), c, newAnthropicAPIKeyAccountForTest(), []byte(`{"model":"x"}`), "x", false, time.Now()) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "empty response") +} + +func TestExtractAnthropicSSEDataLine(t *testing.T) { + 
t.Run("valid data line with spaces", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("data: {\"type\":\"message_start\"}") + require.True(t, ok) + require.Equal(t, `{"type":"message_start"}`, data) + }) + + t.Run("non data line", func(t *testing.T) { + data, ok := extractAnthropicSSEDataLine("event: message_start") + require.False(t, ok) + require.Empty(t, data) + }) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageStartFallbacks(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := `{"type":"message_start","message":{"usage":{"input_tokens":12,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":9,"cache_creation":{"ephemeral_5m_input_tokens":3,"ephemeral_1h_input_tokens":4}}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 12, usage.InputTokens) + require.Equal(t, 9, usage.CacheReadInputTokens, "应兼容 cached_tokens 字段") + require.Equal(t, 7, usage.CacheCreationInputTokens, "聚合字段为空时应从 5m/1h 明细回填") + require.Equal(t, 3, usage.CacheCreation5mTokens) + require.Equal(t, 4, usage.CacheCreation1hTokens) +} + +func TestGatewayService_ParseSSEUsagePassthrough_MessageDeltaSelectiveOverwrite(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{ + InputTokens: 10, + CacheCreation5mTokens: 2, + CacheCreation1hTokens: 6, + } + data := `{"type":"message_delta","usage":{"input_tokens":0,"output_tokens":5,"cache_creation_input_tokens":8,"cache_read_input_tokens":0,"cached_tokens":11,"cache_creation":{"ephemeral_5m_input_tokens":1,"ephemeral_1h_input_tokens":0}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 10, usage.InputTokens, "message_delta 中 0 值不应覆盖已有 input_tokens") + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 8, usage.CacheCreationInputTokens) + require.Equal(t, 11, usage.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退到 cached_tokens") + require.Equal(t, 1, usage.CacheCreation5mTokens) + require.Equal(t, 6, 
usage.CacheCreation1hTokens, "message_delta 中 0 值不应覆盖已有 1h 明细") +} + +func TestGatewayService_ParseSSEUsagePassthrough_NoopCases(t *testing.T) { + svc := &GatewayService{} + + usage := &ClaudeUsage{InputTokens: 3} + svc.parseSSEUsagePassthrough("", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("[DONE]", usage) + require.Equal(t, 3, usage.InputTokens) + + svc.parseSSEUsagePassthrough("not-json", usage) + require.Equal(t, 3, usage.InputTokens) + + // nil usage 不应 panic + svc.parseSSEUsagePassthrough(`{"type":"message_start"}`, nil) +} + +func TestGatewayService_ParseSSEUsagePassthrough_FallbackFromUsageNode(t *testing.T) { + svc := &GatewayService{} + usage := &ClaudeUsage{} + data := `{"type":"content_block_delta","usage":{"cached_tokens":6,"cache_creation":{"ephemeral_5m_input_tokens":2,"ephemeral_1h_input_tokens":1}}}` + + svc.parseSSEUsagePassthrough(data, usage) + + require.Equal(t, 6, usage.CacheReadInputTokens) + require.Equal(t, 3, usage.CacheCreationInputTokens) +} + +func TestParseClaudeUsageFromResponseBody(t *testing.T) { + t.Run("empty or missing usage", func(t *testing.T) { + got := parseClaudeUsageFromResponseBody(nil) + require.NotNil(t, got) + require.Equal(t, 0, got.InputTokens) + + got = parseClaudeUsageFromResponseBody([]byte(`{"id":"x"}`)) + require.NotNil(t, got) + require.Equal(t, 0, got.OutputTokens) + }) + + t.Run("parse all usage fields and fallback", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":21,"output_tokens":34,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"cached_tokens":13,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":8}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 21, got.InputTokens) + require.Equal(t, 34, got.OutputTokens) + require.Equal(t, 13, got.CacheReadInputTokens, "cache_read_input_tokens 为空时应回退 cached_tokens") + require.Equal(t, 13, got.CacheCreationInputTokens, "聚合字段为空时应由 5m/1h 回填") + 
require.Equal(t, 5, got.CacheCreation5mTokens) + require.Equal(t, 8, got.CacheCreation1hTokens) + }) + + t.Run("keep explicit aggregate values", func(t *testing.T) { + body := []byte(`{"usage":{"input_tokens":1,"output_tokens":2,"cache_creation_input_tokens":9,"cache_read_input_tokens":7,"cached_tokens":99,"cache_creation":{"ephemeral_5m_input_tokens":4,"ephemeral_1h_input_tokens":5}}}`) + got := parseClaudeUsageFromResponseBody(body) + require.Equal(t, 9, got.CacheCreationInputTokens, "已显式提供聚合字段时不应被明细覆盖") + require.Equal(t, 7, got.CacheReadInputTokens, "已显式提供 cache_read_input_tokens 时不应回退 cached_tokens") + }) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingErrTooLong(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: 32, + }, + }, + } + + // Scanner 初始缓冲为 64KB,构造更长单行触发 bufio.ErrTooLong。 + longLine := "data: " + strings.Repeat("x", 80*1024) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader(longLine)), + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 2}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.ErrorIs(t, err, bufio.ErrTooLong) + require.NotNil(t, result) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingDataIntervalTimeout(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: 
&RateLimitService{}, + } + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 5}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pw.Close() + _ = pr.Close() + + require.Error(t, err) + require.Contains(t, err.Error(), "stream data interval timeout") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingReadError(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 6}, time.Now(), "claude-3-7-sonnet-20250219") + require.Error(t, err) + require.Contains(t, err.Error(), "stream read error") + require.NotNil(t, result) + require.False(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingTimeoutAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 1, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + 
} + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: pr, + } + + done := make(chan struct{}) + go func() { + defer close(done) + _, _ = pw.Write([]byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":9}}}` + "\n")) + // 保持上游连接静默,触发数据间隔超时分支。 + time.Sleep(1500 * time.Millisecond) + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 7}, time.Now(), "claude-3-7-sonnet-20250219") + _ = pr.Close() + <-done + + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 9, result.usage.InputTokens) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingContextCanceled(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + svc := &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + err: context.Canceled, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 3}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) +} + +func TestGatewayService_AnthropicAPIKeyPassthrough_StreamingUpstreamReadErrorAfterClientDisconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + c.Writer = &failWriteResponseWriter{ResponseWriter: c.Writer} + + svc := &GatewayService{ + cfg: 
&config.Config{ + Gateway: config.GatewayConfig{ + MaxLineSize: defaultMaxLineSize, + }, + }, + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: &streamReadCloser{ + payload: []byte(`data: {"type":"message_start","message":{"usage":{"input_tokens":8}}}` + "\n\n"), + err: io.ErrUnexpectedEOF, + }, + } + + result, err := svc.handleStreamingResponseAnthropicAPIKeyPassthrough(context.Background(), resp, c, &Account{ID: 4}, time.Now(), "claude-3-7-sonnet-20250219") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.clientDisconnect) + require.Equal(t, 8, result.usage.InputTokens) +} diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go new file mode 100644 index 00000000..161c4ba4 --- /dev/null +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -0,0 +1,786 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + gocache "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/require" +) + +type userGroupRateRepoHotpathStub struct { + UserGroupRateRepository + + rate *float64 + err error + wait <-chan struct{} + calls atomic.Int64 +} + +func (s *userGroupRateRepoHotpathStub) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + s.calls.Add(1) + if s.wait != nil { + <-s.wait + } + if s.err != nil { + return nil, s.err + } + return s.rate, nil +} + +type usageLogWindowBatchRepoStub struct { + UsageLogRepository + + batchResult map[int64]*usagestats.AccountStats + batchErr error + batchCalls atomic.Int64 + + singleResult map[int64]*usagestats.AccountStats + singleErr error + singleCalls atomic.Int64 +} + +func (s 
*usageLogWindowBatchRepoStub) GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) { + s.batchCalls.Add(1) + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]*usagestats.AccountStats, len(accountIDs)) + for _, id := range accountIDs { + if stats, ok := s.batchResult[id]; ok { + out[id] = stats + } + } + return out, nil +} + +func (s *usageLogWindowBatchRepoStub) GetAccountWindowStats(ctx context.Context, accountID int64, startTime time.Time) (*usagestats.AccountStats, error) { + s.singleCalls.Add(1) + if s.singleErr != nil { + return nil, s.singleErr + } + if stats, ok := s.singleResult[accountID]; ok { + return stats, nil + } + return &usagestats.AccountStats{}, nil +} + +type sessionLimitCacheHotpathStub struct { + SessionLimitCache + + batchData map[int64]float64 + batchErr error + + setData map[int64]float64 + setErr error +} + +func (s *sessionLimitCacheHotpathStub) GetWindowCostBatch(ctx context.Context, accountIDs []int64) (map[int64]float64, error) { + if s.batchErr != nil { + return nil, s.batchErr + } + out := make(map[int64]float64, len(accountIDs)) + for _, id := range accountIDs { + if v, ok := s.batchData[id]; ok { + out[id] = v + } + } + return out, nil +} + +func (s *sessionLimitCacheHotpathStub) SetWindowCost(ctx context.Context, accountID int64, cost float64) error { + if s.setErr != nil { + return s.setErr + } + if s.setData == nil { + s.setData = make(map[int64]float64) + } + s.setData[accountID] = cost + return nil +} + +type modelsListAccountRepoStub struct { + AccountRepository + + byGroup map[int64][]Account + all []Account + err error + + listByGroupCalls atomic.Int64 + listAllCalls atomic.Int64 +} + +type stickyGatewayCacheHotpathStub struct { + GatewayCache + + stickyID int64 + getCalls atomic.Int64 +} + +func (s *stickyGatewayCacheHotpathStub) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) 
{ + s.getCalls.Add(1) + if s.stickyID > 0 { + return s.stickyID, nil + } + return 0, errors.New("not found") +} + +func (s *stickyGatewayCacheHotpathStub) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error { + return nil +} + +func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { + return nil +} + +func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { + s.listByGroupCalls.Add(1) + if s.err != nil { + return nil, s.err + } + accounts, ok := s.byGroup[groupID] + if !ok { + return nil, nil + } + out := make([]Account, len(accounts)) + copy(out, accounts) + return out, nil +} + +func (s *modelsListAccountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) { + s.listAllCalls.Add(1) + if s.err != nil { + return nil, s.err + } + out := make([]Account, len(s.all)) + copy(out, s.all) + return out, nil +} + +func resetGatewayHotpathStatsForTest() { + windowCostPrefetchCacheHitTotal.Store(0) + windowCostPrefetchCacheMissTotal.Store(0) + windowCostPrefetchBatchSQLTotal.Store(0) + windowCostPrefetchFallbackTotal.Store(0) + windowCostPrefetchErrorTotal.Store(0) + + userGroupRateCacheHitTotal.Store(0) + userGroupRateCacheMissTotal.Store(0) + userGroupRateCacheLoadTotal.Store(0) + userGroupRateCacheSFSharedTotal.Store(0) + userGroupRateCacheFallbackTotal.Store(0) + + modelsListCacheHitTotal.Store(0) + modelsListCacheMissTotal.Store(0) + modelsListCacheStoreTotal.Store(0) +} + +func TestGetUserGroupRateMultiplier_UsesCacheAndSingleflight(t *testing.T) { + resetGatewayHotpathStatsForTest() + + rate := 1.7 + unblock := make(chan struct{}) + repo := &userGroupRateRepoHotpathStub{ + rate: &rate, + wait: unblock, + 
} + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + const concurrent = 12 + results := make([]float64, concurrent) + start := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(concurrent) + for i := 0; i < concurrent; i++ { + go func(idx int) { + defer wg.Done() + <-start + results[idx] = svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + }(i) + } + + close(start) + time.Sleep(20 * time.Millisecond) + close(unblock) + wg.Wait() + + for _, got := range results { + require.Equal(t, rate, got) + } + require.Equal(t, int64(1), repo.calls.Load()) + + // 再次读取应命中缓存,不再回源。 + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.2) + require.Equal(t, rate, got) + require.Equal(t, int64(1), repo.calls.Load()) + + hit, miss, load, sfShared, fallback := GatewayUserGroupRateCacheStats() + require.GreaterOrEqual(t, hit, int64(1)) + require.Equal(t, int64(12), miss) + require.Equal(t, int64(1), load) + require.GreaterOrEqual(t, sfShared, int64(1)) + require.Equal(t, int64(0), fallback) +} + +func TestGetUserGroupRateMultiplier_FallbackOnRepoError(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := &userGroupRateRepoHotpathStub{ + err: errors.New("db down"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 30, + }, + }, + } + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.25) + require.Equal(t, 1.25, got) + require.Equal(t, int64(1), repo.calls.Load()) + + _, _, _, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), fallback) +} + +func TestGetUserGroupRateMultiplier_CacheHitAndNilRepo(t *testing.T) { + resetGatewayHotpathStatsForTest() + + repo := 
&userGroupRateRepoHotpathStub{ + err: errors.New("should not be called"), + } + svc := &GatewayService{ + userGroupRateRepo: repo, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + key := "101:202" + svc.userGroupRateCache.Set(key, 2.3, time.Minute) + + got := svc.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.1) + require.Equal(t, 2.3, got) + + hit, miss, load, _, fallback := GatewayUserGroupRateCacheStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), load) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), repo.calls.Load()) + + // 无 repo 时直接返回分组默认倍率 + svc2 := &GatewayService{ + userGroupRateCache: gocache.New(time.Minute, time.Minute), + } + svc2.userGroupRateCache.Set(key, 1.9, time.Minute) + require.Equal(t, 1.9, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 0, 202, 1.4)) + svc2.userGroupRateCache.Delete(key) + require.Equal(t, 1.4, svc2.getUserGroupRateMultiplier(context.Background(), 101, 202, 1.4)) +} + +func TestWithWindowCostPrefetch_BatchReadAndContextReuse(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{"window_cost_limit": 100.0}, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, 
+ }, + } + repo := &usageLogWindowBatchRepoStub{ + batchResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 22.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + require.NotNil(t, outCtx) + + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + require.True(t, ok1) + require.Equal(t, 11.0, cost1) + + cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok2) + require.Equal(t, 22.0, cost2) + + _, ok3 := windowCostFromPrefetchContext(outCtx, 3) + require.False(t, ok3) + + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, 22.0, cache.setData[2]) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), hit) + require.Equal(t, int64(1), miss) + require.Equal(t, int64(1), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_AllHitNoSQL(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{ + batchData: map[int64]float64{ + 1: 11.0, + 2: 22.0, + }, + } + repo := &usageLogWindowBatchRepoStub{} + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost1, ok1 := windowCostFromPrefetchContext(outCtx, 1) + 
cost2, ok2 := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok1) + require.True(t, ok2) + require.Equal(t, 11.0, cost1) + require.Equal(t, 22.0, cost2) + require.Equal(t, int64(0), repo.batchCalls.Load()) + require.Equal(t, int64(0), repo.singleCalls.Load()) + + hit, miss, batchSQL, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(0), miss) + require.Equal(t, int64(0), batchSQL) + require.Equal(t, int64(0), fallback) + require.Equal(t, int64(0), errCount) +} + +func TestWithWindowCostPrefetch_BatchErrorFallbackSingleQuery(t *testing.T) { + resetGatewayHotpathStatsForTest() + + windowStart := time.Now().Add(-30 * time.Minute).Truncate(time.Hour) + windowEnd := windowStart.Add(5 * time.Hour) + accounts := []Account{ + { + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeSetupToken, + Extra: map[string]any{"window_cost_limit": 100.0}, + SessionWindowStart: &windowStart, + SessionWindowEnd: &windowEnd, + }, + } + + cache := &sessionLimitCacheHotpathStub{} + repo := &usageLogWindowBatchRepoStub{ + batchErr: errors.New("batch failed"), + singleResult: map[int64]*usagestats.AccountStats{ + 2: {StandardCost: 33.0}, + }, + } + svc := &GatewayService{ + sessionLimitCache: cache, + usageLogRepo: repo, + } + + outCtx := svc.withWindowCostPrefetch(context.Background(), accounts) + cost, ok := windowCostFromPrefetchContext(outCtx, 2) + require.True(t, ok) + require.Equal(t, 33.0, cost) + require.Equal(t, int64(1), repo.batchCalls.Load()) + require.Equal(t, int64(1), repo.singleCalls.Load()) + + _, _, _, fallback, errCount := GatewayWindowCostPrefetchStats() + require.Equal(t, int64(1), fallback) + require.Equal(t, int64(1), errCount) +} + +func TestGetAvailableModels_UsesShortCacheAndSupportsInvalidation(t *testing.T) { + resetGatewayHotpathStatsForTest() + + groupID := int64(9) + repo := &modelsListAccountRepoStub{ + byGroup: map[int64][]Account{ + groupID: { + { + ID: 1, + Platform: 
PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + "claude-3-5-haiku": "claude-3-5-haiku", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + }, + } + + svc := &GatewayService{ + accountRepo: repo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + models1 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models1) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // TTL 内再次请求应命中缓存,不回源。 + models2 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, models1, models2) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + // 更新仓储数据,但缓存未失效前应继续返回旧值。 + repo.byGroup[groupID] = []Account{ + { + ID: 3, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-7-sonnet": "claude-3-7-sonnet", + }, + }, + }, + } + models3 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-5-haiku", "claude-3-5-sonnet"}, models3) + require.Equal(t, int64(1), repo.listByGroupCalls.Load()) + + svc.InvalidateAvailableModelsCache(&groupID, PlatformAnthropic) + models4 := svc.GetAvailableModels(context.Background(), &groupID, PlatformAnthropic) + require.Equal(t, []string{"claude-3-7-sonnet"}, models4) + require.Equal(t, int64(2), repo.listByGroupCalls.Load()) + + hit, miss, store := GatewayModelsListCacheStats() + require.Equal(t, int64(2), hit) + require.Equal(t, int64(2), miss) + require.Equal(t, int64(2), store) +} + +func TestGetAvailableModels_ErrorAndGlobalListBranches(t *testing.T) { + resetGatewayHotpathStatsForTest() + + errRepo := &modelsListAccountRepoStub{ + 
err: errors.New("db error"), + } + svcErr := &GatewayService{ + accountRepo: errRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + require.Nil(t, svcErr.GetAvailableModels(context.Background(), nil, "")) + + okRepo := &modelsListAccountRepoStub{ + all: []Account{ + { + ID: 1, + Platform: PlatformAnthropic, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "claude-3-5-sonnet", + }, + }, + }, + { + ID: 2, + Platform: PlatformGemini, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gemini-2.5-pro": "gemini-2.5-pro", + }, + }, + }, + }, + } + svcOK := &GatewayService{ + accountRepo: okRepo, + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + models := svcOK.GetAvailableModels(context.Background(), nil, "") + require.Equal(t, []string{"claude-3-5-sonnet", "gemini-2.5-pro"}, models) + require.Equal(t, int64(1), okRepo.listAllCalls.Load()) +} + +func TestGatewayHotpathHelpers_CacheTTLAndStickyContext(t *testing.T) { + t.Run("resolve_user_group_rate_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultUserGroupRateCacheTTL, resolveUserGroupRateCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + UserGroupRateCacheTTLSeconds: 45, + }, + } + require.Equal(t, 45*time.Second, resolveUserGroupRateCacheTTL(cfg)) + }) + + t.Run("resolve_models_list_cache_ttl", func(t *testing.T) { + require.Equal(t, defaultModelsListCacheTTL, resolveModelsListCacheTTL(nil)) + + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + ModelsListCacheTTLSeconds: 20, + }, + } + require.Equal(t, 20*time.Second, resolveModelsListCacheTTL(cfg)) + }) + + t.Run("prefetched_sticky_account_id_from_context", func(t *testing.T) { + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.TODO(), nil)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(context.Background(), nil)) + + ctx := 
context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(123)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + require.Equal(t, int64(123), prefetchedStickyAccountIDFromContext(ctx, nil)) + + groupID := int64(9) + ctx2 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, 456) + ctx2 = context.WithValue(ctx2, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(456), prefetchedStickyAccountIDFromContext(ctx2, &groupID)) + + ctx3 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, "invalid") + ctx3 = context.WithValue(ctx3, ctxkey.PrefetchedStickyGroupID, groupID) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx3, &groupID)) + + ctx4 := context.WithValue(context.Background(), ctxkey.PrefetchedStickyAccountID, int64(789)) + ctx4 = context.WithValue(ctx4, ctxkey.PrefetchedStickyGroupID, int64(10)) + require.Equal(t, int64(0), prefetchedStickyAccountIDFromContext(ctx4, &groupID)) + }) + + t.Run("window_cost_from_prefetch_context", func(t *testing.T) { + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.TODO(), 0) + return ok + }()) + require.Equal(t, false, func() bool { + _, ok := windowCostFromPrefetchContext(context.Background(), 1) + return ok + }()) + + ctx := context.WithValue(context.Background(), windowCostPrefetchContextKey, map[int64]float64{ + 9: 12.34, + }) + cost, ok := windowCostFromPrefetchContext(ctx, 9) + require.True(t, ok) + require.Equal(t, 12.34, cost) + }) +} + +func TestInvalidateAvailableModelsCache_ByDimensions(t *testing.T) { + svc := &GatewayService{ + modelsListCache: gocache.New(time.Minute, time.Minute), + } + group9 := int64(9) + group10 := int64(10) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + 
svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + svc.modelsListCache.Set("invalid-key", []string{"d"}, time.Minute) + + t.Run("invalidate_group_and_platform", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, PlatformAnthropic) + _, found := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + require.False(t, found) + _, stillFound := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.True(t, stillFound) + }) + + t.Run("invalidate_group_only", func(t *testing.T) { + svc.InvalidateAvailableModelsCache(&group9, "") + _, foundA := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, foundB := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, foundA) + require.False(t, foundB) + _, foundOtherGroup := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + require.True(t, foundOtherGroup) + }) + + t.Run("invalidate_platform_only", func(t *testing.T) { + // 重建数据后仅按 platform 失效 + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformAnthropic), []string{"a"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group9, PlatformGemini), []string{"b"}, time.Minute) + svc.modelsListCache.Set(modelsListCacheKey(&group10, PlatformAnthropic), []string{"c"}, time.Minute) + + svc.InvalidateAvailableModelsCache(nil, PlatformAnthropic) + _, found9Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformAnthropic)) + _, found10Anthropic := svc.modelsListCache.Get(modelsListCacheKey(&group10, PlatformAnthropic)) + _, found9Gemini := svc.modelsListCache.Get(modelsListCacheKey(&group9, PlatformGemini)) + require.False(t, found9Anthropic) + require.False(t, found10Anthropic) + require.True(t, found9Gemini) + }) +} + +func TestSelectAccountWithLoadAwareness_StickyReadReuse(t *testing.T) { + now := time.Now().Add(-time.Minute) + account := Account{ + ID: 88, + 
Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 4, + Priority: 1, + LastUsedAt: &now, + } + + repo := stubOpenAIAccountRepo{accounts: []Account{account}} + concurrency := NewConcurrencyService(stubConcurrencyCache{}) + + cfg := &config.Config{ + RunMode: config.RunModeStandard, + Gateway: config.GatewayConfig{ + Scheduling: config.GatewaySchedulingConfig{ + LoadBatchEnabled: true, + StickySessionMaxWaiting: 3, + StickySessionWaitTimeout: time.Second, + FallbackWaitTimeout: time.Second, + FallbackMaxWaiting: 10, + }, + }, + } + + baseCtx := context.WithValue(context.Background(), ctxkey.ForcePlatform, PlatformAnthropic) + + t.Run("without_prefetch_reads_cache_once", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + result, err := svc.SelectAccountWithLoadAwareness(baseCtx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_skips_cache_read", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, account.ID) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(0)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, 
nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(0), cache.getCalls.Load()) + }) + + t.Run("with_prefetch_group_mismatch_reads_cache", func(t *testing.T) { + cache := &stickyGatewayCacheHotpathStub{stickyID: account.ID} + svc := &GatewayService{ + accountRepo: repo, + cache: cache, + cfg: cfg, + concurrencyService: concurrency, + userGroupRateCache: gocache.New(time.Minute, time.Minute), + modelsListCache: gocache.New(time.Minute, time.Minute), + modelsListCacheTTL: time.Minute, + } + + ctx := context.WithValue(baseCtx, ctxkey.PrefetchedStickyAccountID, int64(999)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(77)) + result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "sess-hash", "", nil, "") + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.Account) + require.Equal(t, account.ID, result.Account.ID) + require.Equal(t, int64(1), cache.getCalls.Load()) + }) +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 09fda60e..70d5068b 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -77,6 +77,11 @@ func (m *mockAccountRepoForPlatform) Create(ctx context.Context, account *Accoun func (m *mockAccountRepoForPlatform) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/gateway_request.go b/backend/internal/service/gateway_request.go index 743dd738..f8096a0e 100644 
--- a/backend/internal/service/gateway_request.go +++ b/backend/internal/service/gateway_request.go @@ -5,9 +5,28 @@ import ( "encoding/json" "fmt" "math" + "unsafe" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +var ( + // 这些字节模式用于 fast-path 判断,避免每次 []byte("...") 产生临时分配。 + patternTypeThinking = []byte(`"type":"thinking"`) + patternTypeThinkingSpaced = []byte(`"type": "thinking"`) + patternTypeRedactedThinking = []byte(`"type":"redacted_thinking"`) + patternTypeRedactedSpaced = []byte(`"type": "redacted_thinking"`) + + patternThinkingField = []byte(`"thinking":`) + patternThinkingFieldSpaced = []byte(`"thinking" :`) + + patternEmptyContent = []byte(`"content":[]`) + patternEmptyContentSpaced = []byte(`"content": []`) + patternEmptyContentSp1 = []byte(`"content" : []`) + patternEmptyContentSp2 = []byte(`"content" :[]`) ) // SessionContext 粘性会话上下文,用于区分不同来源的请求。 @@ -48,113 +67,127 @@ type ParsedRequest struct { // protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini), // 不同协议使用不同的 system/messages 字段名。 func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) { - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { - return nil, err + // 保持与旧实现一致:请求体必须是合法 JSON。 + // 注意:gjson.GetBytes 对非法 JSON 不会报错,因此需要显式校验。 + if !gjson.ValidBytes(body) { + return nil, fmt.Errorf("invalid json") } + // 性能: + // - gjson.GetBytes 会把匹配的 Raw/Str 安全复制成 string(对于巨大 messages 会产生额外拷贝)。 + // - 这里将 body 通过 unsafe 零拷贝视为 string,仅在本函数内使用,且 body 不会被修改。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + parsed := &ParsedRequest{ Body: body, } - if rawModel, exists := req["model"]; exists { - model, ok := rawModel.(string) - if !ok { + // --- gjson 提取简单字段(避免完整 Unmarshal) --- + + // model: 需要严格类型校验,非 string 返回错误 + modelResult := gjson.Get(jsonStr, "model") + if modelResult.Exists() { + if modelResult.Type != gjson.String { return 
nil, fmt.Errorf("invalid model field type") } - parsed.Model = model + parsed.Model = modelResult.String() } - if rawStream, exists := req["stream"]; exists { - stream, ok := rawStream.(bool) - if !ok { + + // stream: 需要严格类型校验,非 bool 返回错误 + streamResult := gjson.Get(jsonStr, "stream") + if streamResult.Exists() { + if streamResult.Type != gjson.True && streamResult.Type != gjson.False { return nil, fmt.Errorf("invalid stream field type") } - parsed.Stream = stream + parsed.Stream = streamResult.Bool() } - if metadata, ok := req["metadata"].(map[string]any); ok { - if userID, ok := metadata["user_id"].(string); ok { - parsed.MetadataUserID = userID + + // metadata.user_id: 直接路径提取,不需要严格类型校验 + parsed.MetadataUserID = gjson.Get(jsonStr, "metadata.user_id").String() + + // thinking.type: enabled/adaptive 都视为开启 + thinkingType := gjson.Get(jsonStr, "thinking.type").String() + if thinkingType == "enabled" || thinkingType == "adaptive" { + parsed.ThinkingEnabled = true + } + + // max_tokens: 仅接受整数值 + maxTokensResult := gjson.Get(jsonStr, "max_tokens") + if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number { + f := maxTokensResult.Float() + if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) && + f <= float64(math.MaxInt) && f >= float64(math.MinInt) { + parsed.MaxTokens = int(f) } } + // --- system/messages 提取 --- + // 避免把整个 body Unmarshal 到 map(会产生大量 map/接口分配)。 + // 使用 gjson 抽取目标字段的 Raw,再对该子树进行 Unmarshal。 + switch protocol { case domain.PlatformGemini: // Gemini 原生格式: systemInstruction.parts / contents - if sysInst, ok := req["systemInstruction"].(map[string]any); ok { - if parts, ok := sysInst["parts"].([]any); ok { - parsed.System = parts + if sysParts := gjson.Get(jsonStr, "systemInstruction.parts"); sysParts.Exists() && sysParts.IsArray() { + var parts []any + if err := json.Unmarshal(sliceRawFromBody(body, sysParts), &parts); err != nil { + return nil, err } + parsed.System = parts } - if contents, ok := req["contents"].([]any); ok { - 
parsed.Messages = contents + + if contents := gjson.Get(jsonStr, "contents"); contents.Exists() && contents.IsArray() { + var msgs []any + if err := json.Unmarshal(sliceRawFromBody(body, contents), &msgs); err != nil { + return nil, err + } + parsed.Messages = msgs } default: // Anthropic / OpenAI 格式: system / messages // system 字段只要存在就视为显式提供(即使为 null), // 以避免客户端传 null 时被默认 system 误注入。 - if system, ok := req["system"]; ok { + if sys := gjson.Get(jsonStr, "system"); sys.Exists() { parsed.HasSystem = true - parsed.System = system + switch sys.Type { + case gjson.Null: + parsed.System = nil + case gjson.String: + // 与 encoding/json 的 Unmarshal 行为一致:返回解码后的字符串。 + parsed.System = sys.String() + default: + var system any + if err := json.Unmarshal(sliceRawFromBody(body, sys), &system); err != nil { + return nil, err + } + parsed.System = system + } } - if messages, ok := req["messages"].([]any); ok { + + if msgs := gjson.Get(jsonStr, "messages"); msgs.Exists() && msgs.IsArray() { + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgs), &messages); err != nil { + return nil, err + } parsed.Messages = messages } } - // thinking: {type: "enabled" | "adaptive"} - if rawThinking, ok := req["thinking"].(map[string]any); ok { - if t, ok := rawThinking["type"].(string); ok && (t == "enabled" || t == "adaptive") { - parsed.ThinkingEnabled = true - } - } - - // max_tokens - if rawMaxTokens, exists := req["max_tokens"]; exists { - if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok { - parsed.MaxTokens = maxTokens - } - } - return parsed, nil } -// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 -// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 -func parseIntegralNumber(raw any) (int, bool) { - switch v := raw.(type) { - case float64: - if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { - return 0, false +// sliceRawFromBody 返回 Result.Raw 对应的原始字节切片。 +// 优先使用 Result.Index 直接从 body 切片,避免对大字段(如 messages)产生额外拷贝。 +// 当 Index 不可用时,退化为复制(理论上极少发生)。 +func 
sliceRawFromBody(body []byte, r gjson.Result) []byte { + if r.Index > 0 { + end := r.Index + len(r.Raw) + if end <= len(body) { + return body[r.Index:end] } - if v > float64(math.MaxInt) || v < float64(math.MinInt) { - return 0, false - } - return int(v), true - case int: - return v, true - case int8: - return int(v), true - case int16: - return int(v), true - case int32: - return int(v), true - case int64: - if v > int64(math.MaxInt) || v < int64(math.MinInt) { - return 0, false - } - return int(v), true - case json.Number: - i64, err := v.Int64() - if err != nil { - return 0, false - } - if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { - return 0, false - } - return int(i64), true - default: - return 0, false } + // fallback: 不影响正确性,但会产生一次拷贝 + return []byte(r.Raw) } // FilterThinkingBlocks removes thinking blocks from request body @@ -184,49 +217,63 @@ func FilterThinkingBlocks(body []byte) []byte { // - Remove `redacted_thinking` blocks (cannot be converted to text). // - Ensure no message ends up with empty content. func FilterThinkingBlocksForRetry(body []byte) []byte { - hasThinkingContent := bytes.Contains(body, []byte(`"type":"thinking"`)) || - bytes.Contains(body, []byte(`"type": "thinking"`)) || - bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) || - bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) || - bytes.Contains(body, []byte(`"thinking":`)) || - bytes.Contains(body, []byte(`"thinking" :`)) + hasThinkingContent := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingField) || + bytes.Contains(body, patternThinkingFieldSpaced) // Also check for empty content arrays that need fixing. // Note: This is a heuristic check; the actual empty content handling is done below. 
- hasEmptyContent := bytes.Contains(body, []byte(`"content":[]`)) || - bytes.Contains(body, []byte(`"content": []`)) || - bytes.Contains(body, []byte(`"content" : []`)) || - bytes.Contains(body, []byte(`"content" :[]`)) + hasEmptyContent := bytes.Contains(body, patternEmptyContent) || + bytes.Contains(body, patternEmptyContentSpaced) || + bytes.Contains(body, patternEmptyContentSp1) || + bytes.Contains(body, patternEmptyContentSp2) // Fast path: nothing to process if !hasThinkingContent && !hasEmptyContent { return body } - var req map[string]any - if err := json.Unmarshal(body, &req); err != nil { + // 尽量避免把整个 body Unmarshal 成 map(会产生大量 map/接口分配)。 + // 这里先用 gjson 把 messages 子树摘出来,后续只对 messages 做 Unmarshal/Marshal。 + jsonStr := *(*string)(unsafe.Pointer(&body)) + msgsRes := gjson.Get(jsonStr, "messages") + if !msgsRes.Exists() || !msgsRes.IsArray() { + return body + } + + // Fast path:只需要删除顶层 thinking,不需要改 messages。 + // 注意:patternThinkingField 可能来自嵌套字段(如 tool_use.input.thinking),因此必须用 gjson 判断顶层字段是否存在。 + containsThinkingBlocks := bytes.Contains(body, patternTypeThinking) || + bytes.Contains(body, patternTypeThinkingSpaced) || + bytes.Contains(body, patternTypeRedactedThinking) || + bytes.Contains(body, patternTypeRedactedSpaced) || + bytes.Contains(body, patternThinkingFieldSpaced) + if !hasEmptyContent && !containsThinkingBlocks { + if topThinking := gjson.Get(jsonStr, "thinking"); topThinking.Exists() { + if out, err := sjson.DeleteBytes(body, "thinking"); err == nil { + return out + } + return body + } + return body + } + + var messages []any + if err := json.Unmarshal(sliceRawFromBody(body, msgsRes), &messages); err != nil { return body } modified := false - messages, ok := req["messages"].([]any) - if !ok { - return body - } - // Disable top-level thinking mode for retry to avoid structural/signature constraints upstream. 
- if _, exists := req["thinking"]; exists { - delete(req, "thinking") - modified = true - } + deleteTopLevelThinking := gjson.Get(jsonStr, "thinking").Exists() - newMessages := make([]any, 0, len(messages)) - - for _, msg := range messages { - msgMap, ok := msg.(map[string]any) + for i := 0; i < len(messages); i++ { + msgMap, ok := messages[i].(map[string]any) if !ok { - newMessages = append(newMessages, msg) continue } @@ -234,17 +281,30 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { content, ok := msgMap["content"].([]any) if !ok { // String content or other format - keep as is - newMessages = append(newMessages, msg) continue } - newContent := make([]any, 0, len(content)) + // 延迟分配:只有检测到需要修改的块,才构建新 slice。 + var newContent []any modifiedThisMsg := false - for _, block := range content { + ensureNewContent := func(prefixLen int) { + if newContent != nil { + return + } + newContent = make([]any, 0, len(content)) + if prefixLen > 0 { + newContent = append(newContent, content[:prefixLen]...) 
+ } + } + + for bi := 0; bi < len(content); bi++ { + block := content[bi] blockMap, ok := block.(map[string]any) if !ok { - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } continue } @@ -254,17 +314,15 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { switch blockType { case "thinking": modifiedThisMsg = true + ensureNewContent(bi) thinkingText, _ := blockMap["thinking"].(string) - if thinkingText == "" { - continue + if thinkingText != "" { + newContent = append(newContent, map[string]any{"type": "text", "text": thinkingText}) } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": thinkingText, - }) continue case "redacted_thinking": modifiedThisMsg = true + ensureNewContent(bi) continue } @@ -272,6 +330,7 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { if blockType == "" { if rawThinking, hasThinking := blockMap["thinking"]; hasThinking { modifiedThisMsg = true + ensureNewContent(bi) switch v := rawThinking.(type) { case string: if v != "" { @@ -286,40 +345,64 @@ func FilterThinkingBlocksForRetry(body []byte) []byte { } } - newContent = append(newContent, block) + if newContent != nil { + newContent = append(newContent, block) + } } // Handle empty content: either from filtering or originally empty + if newContent == nil { + if len(content) == 0 { + modified = true + placeholder := "(content removed)" + if role == "assistant" { + placeholder = "(assistant content removed)" + } + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + } + continue + } + if len(newContent) == 0 { modified = true placeholder := "(content removed)" if role == "assistant" { placeholder = "(assistant content removed)" } - newContent = append(newContent, map[string]any{ - "type": "text", - "text": placeholder, - }) - msgMap["content"] = newContent - } else if modifiedThisMsg { + msgMap["content"] = []any{map[string]any{"type": "text", "text": placeholder}} + 
continue + } + + if modifiedThisMsg { modified = true msgMap["content"] = newContent } - newMessages = append(newMessages, msgMap) } - if modified { - req["messages"] = newMessages - } else { + if !modified && !deleteTopLevelThinking { // Avoid rewriting JSON when no changes are needed. return body } - newBody, err := json.Marshal(req) - if err != nil { - return body + out := body + if deleteTopLevelThinking { + if b, err := sjson.DeleteBytes(out, "thinking"); err == nil { + out = b + } else { + return body + } } - return newBody + if modified { + msgsBytes, err := json.Marshal(messages) + if err != nil { + return body + } + out, err = sjson.SetRawBytes(out, "messages", msgsBytes) + if err != nil { + return body + } + } + return out } // FilterSignatureSensitiveBlocksForRetry is a stronger retry filter for cases where upstream errors indicate diff --git a/backend/internal/service/gateway_request_test.go b/backend/internal/service/gateway_request_test.go index 5b85e752..2a9b4017 100644 --- a/backend/internal/service/gateway_request_test.go +++ b/backend/internal/service/gateway_request_test.go @@ -1,7 +1,11 @@ +//go:build unit + package service import ( "encoding/json" + "fmt" + "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -434,3 +438,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) { require.Contains(t, content0["text"], "tool_use") require.Contains(t, content1["text"], "tool_result") } + +// ============ Group 7: ParseGatewayRequest 补充单元测试 ============ + +// Task 7.1 — 类型校验边界测试 +func TestParseGatewayRequest_TypeValidation(t *testing.T) { + tests := []struct { + name string + body string + wantErr bool + errSubstr string // 期望的错误信息子串(为空则不检查) + }{ + { + name: "model 为 int", + body: `{"model":123}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 array", + body: `{"model":[]}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 bool", + 
body: `{"model":true}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + { + name: "model 为 null — gjson Null 类型触发类型校验错误", + body: `{"model":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误 + errSubstr: "invalid model field type", + }, + { + name: "stream 为 string", + body: `{"stream":"true"}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 int", + body: `{"stream":1}`, + wantErr: true, + errSubstr: "invalid stream field type", + }, + { + name: "stream 为 null — gjson Null 类型触发类型校验错误", + body: `{"stream":null}`, + wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误 + errSubstr: "invalid stream field type", + }, + { + name: "model 为 object", + body: `{"model":{}}`, + wantErr: true, + errSubstr: "invalid model field type", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + if tt.errSubstr != "" { + require.Contains(t, err.Error(), tt.errSubstr) + } + } else { + require.NoError(t, err) + } + }) + } +} + +// Task 7.2 — 可选字段缺失测试 +func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) { + tests := []struct { + name string + body string + wantModel string + wantStream bool + wantMetadataUID string + wantHasSystem bool + wantThinking bool + wantMaxTokens int + wantMessagesNil bool + wantMessagesLen int + }{ + { + name: "完全空 JSON — 所有字段零值", + body: `{}`, + wantModel: "", + wantStream: false, + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + wantMaxTokens: 0, + wantMessagesNil: true, + }, + { + name: "metadata 无 user_id", + body: `{"model":"test"}`, + wantModel: "test", + wantMetadataUID: "", + wantHasSystem: false, + wantThinking: false, + }, + { + name: "thinking 非 enabled(type=disabled)", + body: `{"model":"test","thinking":{"type":"disabled"}}`, + wantModel: "test", + wantThinking: false, + }, + { + name: 
"thinking 字段缺失", + body: `{"model":"test"}`, + wantModel: "test", + wantThinking: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + require.NoError(t, err) + + require.Equal(t, tt.wantModel, parsed.Model) + require.Equal(t, tt.wantStream, parsed.Stream) + require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID) + require.Equal(t, tt.wantHasSystem, parsed.HasSystem) + require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + + if tt.wantMessagesNil { + require.Nil(t, parsed.Messages) + } + if tt.wantMessagesLen > 0 { + require.Len(t, parsed.Messages, tt.wantMessagesLen) + } + }) + } +} + +// Task 7.3 — Gemini 协议分支测试 +// 已有测试覆盖: +// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents +// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents +// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction) +// 因此跳过。 + +// Task 7.4 — max_tokens 边界测试 +func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) { + tests := []struct { + name string + body string + wantMaxTokens int + wantErr bool + }{ + { + name: "正常整数", + body: `{"max_tokens":1024}`, + wantMaxTokens: 1024, + }, + { + name: "浮点数(非整数)被忽略", + body: `{"max_tokens":10.5}`, + wantMaxTokens: 0, + }, + { + name: "负整数可以通过", + body: `{"max_tokens":-1}`, + wantMaxTokens: -1, + }, + { + name: "超大值不 panic", + body: `{"max_tokens":9999999999999999}`, + wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16 + }, + { + name: "null 值被忽略", + body: `{"max_tokens":null}`, + wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parsed, err := ParseGatewayRequest([]byte(tt.body), "") + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens) + }) + } 
+} + +// ============ Task 7.5: Benchmark 测试 ============ + +// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。 +// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。 +func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) { + parsed := &ParsedRequest{ + Body: body, + } + + var req map[string]any + if err := json.Unmarshal(body, &req); err != nil { + return nil, err + } + + // model + if raw, ok := req["model"]; ok { + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("invalid model field type") + } + parsed.Model = s + } + + // stream + if raw, ok := req["stream"]; ok { + b, ok := raw.(bool) + if !ok { + return nil, fmt.Errorf("invalid stream field type") + } + parsed.Stream = b + } + + // metadata.user_id + if meta, ok := req["metadata"].(map[string]any); ok { + if uid, ok := meta["user_id"].(string); ok { + parsed.MetadataUserID = uid + } + } + + // thinking.type + if thinking, ok := req["thinking"].(map[string]any); ok { + if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" { + parsed.ThinkingEnabled = true + } + } + + // max_tokens + if raw, ok := req["max_tokens"]; ok { + if n, ok := parseIntegralNumber(raw); ok { + parsed.MaxTokens = n + } + } + + // system / messages(按协议分支) + switch protocol { + case domain.PlatformGemini: + if sysInst, ok := req["systemInstruction"].(map[string]any); ok { + if parts, ok := sysInst["parts"].([]any); ok { + parsed.System = parts + } + } + if contents, ok := req["contents"].([]any); ok { + parsed.Messages = contents + } + default: + if system, ok := req["system"]; ok { + parsed.HasSystem = true + parsed.System = system + } + if messages, ok := req["messages"].([]any); ok { + parsed.Messages = messages + } + } + + return parsed, nil +} + +// buildSmallJSON 构建 ~500B 的小型测试 JSON +func buildSmallJSON() []byte { + return 
[]byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`) +} + +// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages) +func buildLargeJSON() []byte { + var b strings.Builder + b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`) + + msgCount := 200 + for i := 0; i < msgCount; i++ { + if i > 0 { + b.WriteByte(',') + } + if i%2 == 0 { + b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i)) + } else { + b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. 
I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i)) + } + } + + b.WriteString(`]}`) + return []byte(b.String()) +} + +func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Small(b *testing.B) { + data := buildSmallJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} + +func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = parseGatewayRequestOld(data, "") + } +} + +func BenchmarkParseGatewayRequest_New_Large(b *testing.B) { + data := buildLargeJSON() + b.SetBytes(int64(len(data))) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = ParseGatewayRequest(data, "") + } +} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 4d1dbad0..e55940ee 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "log" "log/slog" mathrand "math/rand" "net/http" @@ -24,12 +23,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/cespare/xxhash/v2" "github.com/google/uuid" + gocache "github.com/patrickmn/go-cache" "github.com/tidwall/gjson" "github.com/tidwall/sjson" + "golang.org/x/sync/singleflight" "github.com/gin-gonic/gin" ) @@ 
-44,6 +47,9 @@ const ( // separator between system blocks, we add "\n\n" at concatenation time. claudeCodeSystemPrompt = "You are Claude Code, Anthropic's official CLI for Claude." maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量 + + defaultUserGroupRateCacheTTL = 30 * time.Second + defaultModelsListCacheTTL = 15 * time.Second ) const ( @@ -62,6 +68,53 @@ type accountWithLoad struct { var ForceCacheBillingContextKey = forceCacheBillingKeyType{} +var ( + windowCostPrefetchCacheHitTotal atomic.Int64 + windowCostPrefetchCacheMissTotal atomic.Int64 + windowCostPrefetchBatchSQLTotal atomic.Int64 + windowCostPrefetchFallbackTotal atomic.Int64 + windowCostPrefetchErrorTotal atomic.Int64 + + userGroupRateCacheHitTotal atomic.Int64 + userGroupRateCacheMissTotal atomic.Int64 + userGroupRateCacheLoadTotal atomic.Int64 + userGroupRateCacheSFSharedTotal atomic.Int64 + userGroupRateCacheFallbackTotal atomic.Int64 + + modelsListCacheHitTotal atomic.Int64 + modelsListCacheMissTotal atomic.Int64 + modelsListCacheStoreTotal atomic.Int64 +) + +func GatewayWindowCostPrefetchStats() (cacheHit, cacheMiss, batchSQL, fallback, errCount int64) { + return windowCostPrefetchCacheHitTotal.Load(), + windowCostPrefetchCacheMissTotal.Load(), + windowCostPrefetchBatchSQLTotal.Load(), + windowCostPrefetchFallbackTotal.Load(), + windowCostPrefetchErrorTotal.Load() +} + +func GatewayUserGroupRateCacheStats() (cacheHit, cacheMiss, load, singleflightShared, fallback int64) { + return userGroupRateCacheHitTotal.Load(), + userGroupRateCacheMissTotal.Load(), + userGroupRateCacheLoadTotal.Load(), + userGroupRateCacheSFSharedTotal.Load(), + userGroupRateCacheFallbackTotal.Load() +} + +func GatewayModelsListCacheStats() (cacheHit, cacheMiss, store int64) { + return modelsListCacheHitTotal.Load(), modelsListCacheMissTotal.Load(), modelsListCacheStoreTotal.Load() +} + +func cloneStringSlice(src []string) []string { + if len(src) == 0 { + return nil + } + dst := make([]string, len(src)) + 
copy(dst, src) + return dst +} + // IsForceCacheBilling 检查是否启用强制缓存计费 func IsForceCacheBilling(ctx context.Context) bool { v, _ := ctx.Value(ForceCacheBillingContextKey).(bool) @@ -213,7 +266,7 @@ func logClaudeMimicDebug(req *http.Request, body []byte, account *Account, token if line == "" { return } - log.Printf("[ClaudeMimicDebug] %s", line) + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebug] %s", line) } func isClaudeCodeCredentialScopeError(msg string) bool { @@ -302,6 +355,60 @@ func derefGroupID(groupID *int64) int64 { return *groupID } +func resolveUserGroupRateCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.UserGroupRateCacheTTLSeconds <= 0 { + return defaultUserGroupRateCacheTTL + } + return time.Duration(cfg.Gateway.UserGroupRateCacheTTLSeconds) * time.Second +} + +func resolveModelsListCacheTTL(cfg *config.Config) time.Duration { + if cfg == nil || cfg.Gateway.ModelsListCacheTTLSeconds <= 0 { + return defaultModelsListCacheTTL + } + return time.Duration(cfg.Gateway.ModelsListCacheTTLSeconds) * time.Second +} + +func modelsListCacheKey(groupID *int64, platform string) string { + return fmt.Sprintf("%d|%s", derefGroupID(groupID), strings.TrimSpace(platform)) +} + +func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + } + return 0, false +} + +func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { + if ctx == nil { + return 0 + } + prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) + if !ok || prefetchedGroupID != derefGroupID(groupID) { + return 0 + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + case int64: + if t > 0 { + return t + } + case int: + if t > 0 { + return int64(t) + } + } + return 0 +} + // shouldClearStickySession 
检查账号是否处于不可调度状态,需要清理粘性会话绑定。 // 当账号状态为错误、禁用、不可调度、处于临时不可调度期间, // 或请求的模型处于限流状态时,返回 true。 @@ -366,14 +473,19 @@ type ForwardResult struct { // 图片生成计费字段(仅 gemini-3-pro-image 使用) ImageCount int // 生成的图片数量 ImageSize string // 图片尺寸 "1K", "2K", "4K" + + // Sora 媒体字段 + MediaType string // image / video / prompt + MediaURL string // 生成后的媒体地址(可选) } // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { StatusCode int - ResponseBody []byte // 上游响应体,用于错误透传规则匹配 - ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true - RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 + ResponseHeaders http.Header // 上游响应头,用于透传 cf-ray/cf-mitigated/content-type 等诊断信息 + ForceCacheBilling bool // Antigravity 粘性会话切换时设为 true + RetryableOnSameAccount bool // 临时性错误(如 Google 间歇性 400、空响应),应在同一账号上重试 N 次再切换 } func (e *UpstreamFailoverError) Error() string { @@ -416,6 +528,10 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration } // NewGatewayService creates a new GatewayService @@ -440,6 +556,9 @@ func NewGatewayService( sessionLimitCache SessionLimitCache, digestStore *DigestSessionStore, ) *GatewayService { + userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) + modelsListTTL := resolveModelsListCacheTTL(cfg) + return &GatewayService{ accountRepo: accountRepo, groupRepo: groupRepo, @@ -460,6 +579,9 @@ func NewGatewayService( deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, 
} } @@ -931,13 +1053,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro cfg := s.schedulingConfig() - var stickyAccountID int64 - if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { - stickyAccountID = accountID - } - } - // 检查 Claude Code 客户端限制(可能会替换 groupID 为降级分组) group, groupID, err := s.checkClaudeCodeRestriction(ctx, groupID) if err != nil { @@ -945,12 +1060,21 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } ctx = s.withGroupContext(ctx, group) + var stickyAccountID int64 + if prefetch := prefetchedStickyAccountIDFromContext(ctx, groupID); prefetch > 0 { + stickyAccountID = prefetch + } else if sessionHash != "" && s.cache != nil { + if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash); err == nil { + stickyAccountID = accountID + } + } + if s.debugModelRoutingEnabled() && requestedModel != "" { groupPlatform := "" if group != nil { groupPlatform = group.Platform } - log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v", derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil) } @@ -1020,7 +1144,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } preferOAuth := platform == PlatformGemini if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" { - log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) + 
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform) } accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) @@ -1030,6 +1154,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(accounts) == 0 { return nil, errors.New("no available accounts") } + ctx = s.withWindowCostPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1050,7 +1175,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic { routingAccountIDs = group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID) if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 { keys := make([]string, 0, len(group.ModelRouting)) @@ -1062,7 +1187,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if len(keys) > maxKeys { keys = keys[:maxKeys] } - log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys) } } } @@ -1109,20 +1234,19 @@ func (s *GatewayService) 
SelectAccountWithLoadAwareness(ctx context.Context, gro } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d window_cost=%d)", derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates), filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping, filteredWindowCost) if len(modelScopeSkippedIDs) > 0 { - log.Printf("[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] model_rate_limited accounts skipped: group_id=%v model=%s account_ids=%v", derefGroupID(groupID), requestedModel, modelScopeSkippedIDs) } } if len(routingCandidates) > 0 { // 1.5. 
在路由账号范围内检查粘性会话 - if sessionHash != "" && s.cache != nil { - stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { + if sessionHash != "" && stickyAccountID > 0 { + if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { if stickyAccount.IsSchedulable() && @@ -1138,7 +1262,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 继续到负载感知选择 } else { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID) } return &AccountSelectionResult{ Account: stickyAccount, @@ -1231,7 +1355,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL) } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1248,7 +1372,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue // 会话限制已满,尝试下一个 } if 
s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) } return &AccountSelectionResult{ Account: item.account, @@ -1263,14 +1387,14 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 - log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) } } // ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============ - if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + if len(routingAccountIDs) == 0 && sessionHash != "" && stickyAccountID > 0 && !isExcluded(stickyAccountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, ok := accountByID[accountID] if ok { // 检查账户是否需要清理粘性会话绑定 @@ -1524,20 +1648,20 @@ func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupI group, err := s.resolveGroupByID(ctx, *groupID) if err != nil || group == nil { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), 
requestedModel, platform, err) } return nil } // Preserve existing behavior: model routing only applies to anthropic groups. if group.Platform != PlatformAnthropic { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel) } return nil } ids := group.GetRoutingAccountIDs(requestedModel) if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v", group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids) } return ids @@ -1755,6 +1879,129 @@ func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID in return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) } +type usageLogWindowStatsBatchProvider interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + +type windowCostPrefetchContextKeyType struct{} + +var windowCostPrefetchContextKey = windowCostPrefetchContextKeyType{} + +func windowCostFromPrefetchContext(ctx context.Context, accountID int64) (float64, bool) { + if ctx == nil || accountID <= 0 { + return 0, false + } + m, ok := ctx.Value(windowCostPrefetchContextKey).(map[int64]float64) + if !ok || len(m) == 0 { + return 0, false + } + v, exists := m[accountID] + return v, exists +} + +func (s *GatewayService) withWindowCostPrefetch(ctx context.Context, accounts []Account) context.Context { + if ctx == nil || len(accounts) == 0 || s.sessionLimitCache == nil || s.usageLogRepo == nil { + 
return ctx + } + + accountByID := make(map[int64]*Account) + accountIDs := make([]int64, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if account == nil || !account.IsAnthropicOAuthOrSetupToken() { + continue + } + if account.GetWindowCostLimit() <= 0 { + continue + } + accountByID[account.ID] = account + accountIDs = append(accountIDs, account.ID) + } + if len(accountIDs) == 0 { + return ctx + } + + costs := make(map[int64]float64, len(accountIDs)) + cacheValues, err := s.sessionLimitCache.GetWindowCostBatch(ctx, accountIDs) + if err == nil { + for accountID, cost := range cacheValues { + costs[accountID] = cost + } + windowCostPrefetchCacheHitTotal.Add(int64(len(cacheValues))) + } else { + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch cache read failed: %v", err) + } + cacheMissCount := len(accountIDs) - len(costs) + if cacheMissCount < 0 { + cacheMissCount = 0 + } + windowCostPrefetchCacheMissTotal.Add(int64(cacheMissCount)) + + missingByStart := make(map[int64][]int64) + startTimes := make(map[int64]time.Time) + for _, accountID := range accountIDs { + if _, ok := costs[accountID]; ok { + continue + } + account := accountByID[accountID] + if account == nil { + continue + } + startTime := account.GetCurrentWindowStartTime() + startKey := startTime.Unix() + missingByStart[startKey] = append(missingByStart[startKey], accountID) + startTimes[startKey] = startTime + } + if len(missingByStart) == 0 { + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) + } + + batchReader, hasBatch := s.usageLogRepo.(usageLogWindowStatsBatchProvider) + for startKey, ids := range missingByStart { + startTime := startTimes[startKey] + + if hasBatch { + windowCostPrefetchBatchSQLTotal.Add(1) + queryStart := time.Now() + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, ids, startTime) + if err == nil { + slog.Debug("window_cost_batch_query_ok", + "accounts", len(ids), + 
"window_start", startTime.Format(time.RFC3339), + "duration_ms", time.Since(queryStart).Milliseconds()) + for _, accountID := range ids { + stats := statsByAccount[accountID] + cost := 0.0 + if stats != nil { + cost = stats.StandardCost + } + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + continue + } + windowCostPrefetchErrorTotal.Add(1) + logger.LegacyPrintf("service.gateway", "window_cost batch db query failed: start=%s err=%v", startTime.Format(time.RFC3339), err) + } + + // 回退路径:缺少批量仓储能力或批量查询失败时,按账号单查(失败开放)。 + windowCostPrefetchFallbackTotal.Add(int64(len(ids))) + for _, accountID := range ids { + stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, accountID, startTime) + if err != nil { + windowCostPrefetchErrorTotal.Add(1) + continue + } + cost := stats.StandardCost + costs[accountID] = cost + _ = s.sessionLimitCache.SetWindowCost(ctx, accountID, cost) + } + } + + return context.WithValue(ctx, windowCostPrefetchContextKey, costs) +} + // isAccountSchedulableForWindowCost 检查账号是否可根据窗口费用进行调度 // 仅适用于 Anthropic OAuth/SetupToken 账号 // 返回 true 表示可调度,false 表示不可调度 @@ -1771,6 +2018,10 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 尝试从缓存获取窗口费用 var currentCost float64 + if cost, ok := windowCostFromPrefetchContext(ctx, account.ID); ok { + currentCost = cost + goto checkSchedulability + } if s.sessionLimitCache != nil { if cost, hit, err := s.sessionLimitCache.GetWindowCost(ctx, account.ID); err == nil && hit { currentCost = cost @@ -1968,7 +2219,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { return a.LastUsedAt.Before(*b.LastUsedAt) } }) - shuffleWithinPriorityAndLastUsed(accounts) + shuffleWithinPriorityAndLastUsed(accounts, preferOAuth) } // shuffleWithinSortGroups 对排序后的 accountWithLoad 切片,按 (Priority, LoadRate, LastUsedAt) 分组后组内随机打乱。 @@ -2004,7 +2255,12 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { } // 
shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 -func shuffleWithinPriorityAndLastUsed(accounts []*Account) { +// +// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 +// 因此这里采用“组内分区 + 分区内 shuffle”的方式: +// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; +// - 再分别在各段内随机打散,避免热点。 +func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { if len(accounts) <= 1 { return } @@ -2015,9 +2271,29 @@ func shuffleWithinPriorityAndLastUsed(accounts []*Account) { j++ } if j-i > 1 { - mathrand.Shuffle(j-i, func(a, b int) { - accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] - }) + if preferOAuth { + oauth := make([]*Account, 0, j-i) + others := make([]*Account, 0, j-i) + for _, acc := range accounts[i:j] { + if acc.Type == AccountTypeOAuth { + oauth = append(oauth, acc) + } else { + others = append(others, acc) + } + } + if len(oauth) > 1 { + mathrand.Shuffle(len(oauth), func(a, b int) { oauth[a], oauth[b] = oauth[b], oauth[a] }) + } + if len(others) > 1 { + mathrand.Shuffle(len(others), func(a, b int) { others[a], others[b] = others[b], others[a] }) + } + copy(accounts[i:], oauth) + copy(accounts[i+len(oauth):], others) + } else { + mathrand.Shuffle(j-i, func(a, b int) { + accounts[i+a], accounts[i+b] = accounts[i+b], accounts[i+a] + }) + } } i = j } @@ -2106,7 +2382,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // so switching model can switch upstream account within the same sticky session. 
if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. @@ -2123,7 +2399,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2198,15 +2474,15 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", 
derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 @@ -2294,7 +2570,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -2313,7 +2589,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // ============ Model Routing (legacy path): apply before sticky session ============ if len(routingAccountIDs) > 0 { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v", derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs) } // 1) Sticky session only applies if the bound account is within the routing set. 
@@ -2331,7 +2607,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } return account, nil } @@ -2407,15 +2683,15 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if selected != nil { if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) + logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID) } return selected, nil } - log.Printf("[ModelRouting] No routed accounts available for model=%s, 
falling back to normal selection", requestedModel) + logger.LegacyPrintf("service.gateway", "[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel) } // 1. 查询粘性会话 @@ -2505,7 +2781,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g // 4. 建立粘性绑定 if sessionHash != "" && s.cache != nil { if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil { - log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) + logger.LegacyPrintf("service.gateway", "set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err) } } @@ -2820,7 +3096,7 @@ func injectClaudeCodePrompt(body []byte, system any) []byte { result, err := sjson.SetBytes(body, "system", newSystem) if err != nil { - log.Printf("Warning: failed to inject Claude Code prompt: %v", err) + logger.LegacyPrintf("service.gateway", "Warning: failed to inject Claude Code prompt: %v", err) return body } return result @@ -2976,7 +3252,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in system") + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in system") } } } @@ -2993,7 +3269,7 @@ func removeCacheControlFromThinkingBlocks(data map[string]any) { if blockType, _ := m["type"].(string); blockType == "thinking" { if _, has := m["cache_control"]; has { delete(m, "cache_control") - log.Printf("[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", msgIdx, contentIdx) + logger.LegacyPrintf("service.gateway", "[Warning] Removed illegal cache_control from thinking block in messages[%d].content[%d]", 
msgIdx, contentIdx) } } } @@ -3011,6 +3287,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return nil, fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body, parsed.Model, parsed.Stream, startTime) + } + body := parsed.Body reqModel := parsed.Model reqStream := parsed.Stream @@ -3072,7 +3352,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 替换请求体中的模型名 body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "Model mapping applied: %s -> %s (account: %s, source=%s)", originalModel, mappedModel, account.Name, mappingSource) } // 获取凭证 @@ -3088,16 +3368,16 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 调试日志:记录即将转发的账号信息 - log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) // 重试循环 var resp *http.Response retryStart := time.Now() for attempt := 1; attempt <= maxRetryAttempts; attempt++ { // 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取) - // Capture upstream request body for ops retry of this attempt. 
- c.Set(OpsUpstreamRequestBodyKey, string(body)) upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if err != nil { return nil, err @@ -3168,7 +3448,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) break } - log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID) // Conservative two-stage fallback: // 1) Disable thinking + thinking->text (preserve content) @@ -3181,7 +3461,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { - log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry succeeded (thinking downgraded)", account.ID) resp = retryResp break } @@ -3206,7 +3486,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) msg2 := extractUpstreamErrorMessage(retryRespBody) if looksLikeToolSignatureError(msg2) && time.Since(retryStart) < maxRetryElapsed { - log.Printf("Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: signature retry still failing and looks tool-related, retrying with tool blocks downgraded", account.ID) filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, 
shouldMimicClaudeCode) if buildErr2 == nil { @@ -3226,9 +3506,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A Kind: "signature_retry_tools_request_error", Message: sanitizeUpstreamErrorMessage(retryErr2.Error()), }) - log.Printf("Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry failed: %v", account.ID, retryErr2) } else { - log.Printf("Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) + logger.LegacyPrintf("service.gateway", "Account %d: tool-downgrade signature retry build failed: %v", account.ID, buildErr2) } } } @@ -3244,9 +3524,9 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A if retryResp != nil && retryResp.Body != nil { _ = retryResp.Body.Close() } - log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry failed: %v", account.ID, retryErr) } else { - log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr) + logger.LegacyPrintf("service.gateway", "Account %d: signature error retry build request failed: %v", account.ID, buildErr) } // Retry failed: restore original response body and continue handling. 
@@ -3292,7 +3572,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - log.Printf("Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) if err := sleepWithContext(ctx, delay); err != nil { return nil, err @@ -3305,10 +3585,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A // 不需要重试(成功或不可重试的错误),跳出循环 // DEBUG: 输出响应 headers(用于检测 rate limit 信息) - if account.Platform == PlatformGemini && resp.StatusCode < 400 { - log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID) + if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID) for k, v := range resp.Header { - log.Printf("[DEBUG] %s: %v", k, v) + logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v) } } break @@ -3326,7 +3606,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印重试耗尽后的错误响应 - log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleRetryExhaustedSideEffects(ctx, resp, account) @@ -3357,7 +3637,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A resp.Body = io.NopCloser(bytes.NewReader(respBody)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error 
(failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) s.handleFailoverSideEffects(ctx, resp, account) @@ -3411,13 +3691,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }) if s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover: %s", account.ID, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes), ) } else { - log.Printf("Account %d: 400 error, attempting failover", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} @@ -3461,6 +3741,602 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A }, nil } +func (s *GatewayService) forwardAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reqStream bool, + startTime time.Time, +) (*ForwardResult, error) { + token, tokenType, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + if tokenType != "apikey" { + return nil, fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + logger.LegacyPrintf("service.gateway", "[Anthropic 自动透传] 命中 API Key 透传分支: account=%d name=%s model=%s stream=%v", + account.ID, account.Name, reqModel, reqStream) + + if c != nil { + c.Set("anthropic_passthrough", true) + } + // 重试间复用同一请求体,避免每次 string(body) 产生额外分配。 + setOpsUpstreamRequestBody(c, body) + + 
var resp *http.Response + retryStart := time.Now() + for attempt := 1; attempt <= maxRetryAttempts; attempt++ { + upstreamReq, err := s.buildUpstreamRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + if resp != nil && resp.Body != nil { + _ = resp.Body.Close() + } + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + + // 透传分支禁止 400 请求体降级重试(该重试会改写请求体) + if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if attempt < maxRetryAttempts { + elapsed := time.Since(retryStart) + if elapsed >= maxRetryElapsed { + break + } + + delay := retryBackoffDelay(attempt) + remaining := maxRetryElapsed - elapsed + if delay > remaining { + delay = remaining + } + if delay <= 0 { + break + } + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), 
s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + logger.LegacyPrintf("service.gateway", "Anthropic passthrough account %d: upstream error %d, retry %d/%d after %v (elapsed=%v/%v)", + account.ID, resp.StatusCode, attempt, maxRetryAttempts, delay, elapsed, maxRetryElapsed) + if err := sleepWithContext(ctx, delay); err != nil { + return nil, err + } + continue + } + break + } + + break + } + if resp == nil || resp.Body == nil { + return nil, errors.New("upstream request failed: empty response") + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (retry exhausted, failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleRetryExhaustedSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "retry_exhausted_failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + return s.handleRetryExhaustedError(ctx, resp, c, account) + } + + if resp.StatusCode >= 400 && s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := 
io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + logger.LegacyPrintf("service.gateway", "[Anthropic Passthrough] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + + s.handleFailoverSideEffects(ctx, resp, account) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "failover", + Message: extractUpstreamErrorMessage(respBody), + Detail: func() string { + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + return truncateString(string(respBody), s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes) + } + return "" + }(), + }) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} + } + + if resp.StatusCode >= 400 { + return s.handleErrorResponse(ctx, resp, c, account) + } + + var usage *ClaudeUsage + var firstTokenMs *int + var clientDisconnect bool + if reqStream { + streamResult, err := s.handleStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account, startTime, reqModel) + if err != nil { + return nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + clientDisconnect = streamResult.clientDisconnect + } else { + usage, err = s.handleNonStreamingResponseAnthropicAPIKeyPassthrough(ctx, resp, c, account) + if err != nil { + return nil, err + } + } + if usage == nil { + usage = &ClaudeUsage{} + } + + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ClientDisconnect: clientDisconnect, + }, nil +} + +func (s *GatewayService) 
buildUpstreamRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPIURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/v1/messages" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + +func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, + model string, +) (*streamingResult, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "text/event-stream" + } + c.Header("Content-Type", contentType) + if c.Writer.Header().Get("Cache-Control") == "" { + c.Header("Cache-Control", "no-cache") + } + if c.Writer.Header().Get("Connection") == "" { + c.Header("Connection", "keep-alive") + } + 
c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &ClaudeUsage{} + var firstTokenMs *int + clientDisconnected := false + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + + type scanEvent struct { + line string + err error + } + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) + + streamInterval := time.Duration(0) + if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { + streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second + } + var intervalTicker *time.Ticker + if streamInterval > 0 { + intervalTicker = time.NewTicker(streamInterval) + defer intervalTicker.Stop() + } + var intervalCh <-chan time.Time + if intervalTicker != nil { + intervalCh = intervalTicker.C + } + + for { + select { + case ev, ok := <-events: + if !ok { + if !clientDisconnected { + // 兜底补刷,确保最后一个未以空行结尾的事件也能及时送达客户端。 + flusher.Flush() + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil + } + if ev.err != nil { + if 
clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] 流读取被取消: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, resp.Header.Get("x-request-id"), ev.err, ctx.Err()) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + if errors.Is(ev.err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + if data, ok := extractAnthropicSSEDataLine(line); ok { + trimmed := strings.TrimSpace(data) + if firstTokenMs == nil && trimmed != "" && trimmed != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsagePassthrough(data, usage) + } + + if !clientDisconnected { + if _, err := io.WriteString(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if _, err := io.WriteString(w, "\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else if line == "" { + // 按 SSE 事件边界刷出,减少每行 flush 带来的 syscall 开销。 + flusher.Flush() + } + } + + case <-intervalCh: + 
lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + if clientDisconnected { + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Upstream timeout after client disconnect: account=%d model=%s", account.ID, model) + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil + } + logger.LegacyPrintf("service.gateway", "[Anthropic passthrough] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, model) + } + return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + } + } +} + +func extractAnthropicSSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != '\t' { + break + } + start++ + } + return line[start:], true +} + +func (s *GatewayService) parseSSEUsagePassthrough(data string, usage *ClaudeUsage) { + if usage == nil || data == "" || data == "[DONE]" { + return + } + + parsed := gjson.Parse(data) + switch parsed.Get("type").String() { + case "message_start": + msgUsage := parsed.Get("message.usage") + if msgUsage.Exists() { + usage.InputTokens = int(msgUsage.Get("input_tokens").Int()) + usage.CacheCreationInputTokens = int(msgUsage.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(msgUsage.Get("cache_read_input_tokens").Int()) + + // 保持与通用解析一致:message_start 允许覆盖 5m/1h 明细(包括 0)。 + cc5m := msgUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := msgUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() || cc1h.Exists() { + usage.CacheCreation5mTokens = int(cc5m.Int()) + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + case "message_delta": + deltaUsage := 
parsed.Get("usage") + if deltaUsage.Exists() { + if v := deltaUsage.Get("input_tokens").Int(); v > 0 { + usage.InputTokens = int(v) + } + if v := deltaUsage.Get("output_tokens").Int(); v > 0 { + usage.OutputTokens = int(v) + } + if v := deltaUsage.Get("cache_creation_input_tokens").Int(); v > 0 { + usage.CacheCreationInputTokens = int(v) + } + if v := deltaUsage.Get("cache_read_input_tokens").Int(); v > 0 { + usage.CacheReadInputTokens = int(v) + } + + cc5m := deltaUsage.Get("cache_creation.ephemeral_5m_input_tokens") + cc1h := deltaUsage.Get("cache_creation.ephemeral_1h_input_tokens") + if cc5m.Exists() && cc5m.Int() > 0 { + usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { + usage.CacheCreation1hTokens = int(cc1h.Int()) + } + } + } + + if usage.CacheReadInputTokens == 0 { + if cached := parsed.Get("message.usage.cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + if cached := parsed.Get("usage.cached_tokens").Int(); usage.CacheReadInputTokens == 0 && cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + if usage.CacheCreationInputTokens == 0 { + cc5m := parsed.Get("message.usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := parsed.Get("message.usage.cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m == 0 && cc1h == 0 { + cc5m = parsed.Get("usage.cache_creation.ephemeral_5m_input_tokens").Int() + cc1h = parsed.Get("usage.cache_creation.ephemeral_1h_input_tokens").Int() + } + total := cc5m + cc1h + if total > 0 { + usage.CacheCreationInputTokens = int(total) + } + } +} + +func parseClaudeUsageFromResponseBody(body []byte) *ClaudeUsage { + usage := &ClaudeUsage{} + if len(body) == 0 { + return usage + } + + parsed := gjson.ParseBytes(body) + usageNode := parsed.Get("usage") + if !usageNode.Exists() { + return usage + } + + usage.InputTokens = int(usageNode.Get("input_tokens").Int()) + usage.OutputTokens = int(usageNode.Get("output_tokens").Int()) + 
usage.CacheCreationInputTokens = int(usageNode.Get("cache_creation_input_tokens").Int()) + usage.CacheReadInputTokens = int(usageNode.Get("cache_read_input_tokens").Int()) + + cc5m := usageNode.Get("cache_creation.ephemeral_5m_input_tokens").Int() + cc1h := usageNode.Get("cache_creation.ephemeral_1h_input_tokens").Int() + if cc5m > 0 || cc1h > 0 { + usage.CacheCreation5mTokens = int(cc5m) + usage.CacheCreation1hTokens = int(cc1h) + } + if usage.CacheCreationInputTokens == 0 && (cc5m > 0 || cc1h > 0) { + usage.CacheCreationInputTokens = int(cc5m + cc1h) + } + if usage.CacheReadInputTokens == 0 { + if cached := usageNode.Get("cached_tokens").Int(); cached > 0 { + usage.CacheReadInputTokens = int(cached) + } + } + return usage +} + +func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, +) (*ClaudeUsage, error) { + if s.rateLimitService != nil { + s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) + } + + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := parseClaudeUsageFromResponseBody(body) + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { + if dst == nil || src == nil { + return + } + if cfg != 
nil { + responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + return + } + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + if v := strings.TrimSpace(src.Get("x-request-id")); v != "" { + dst.Set("x-request-id", v) + } +} + func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, reqStream bool, mimicClaudeCode bool) (*http.Request, error) { // 确定目标URL targetURL := claudeAPIURL @@ -3486,7 +4362,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // 1. 获取或创建指纹(包含随机生成的ClientID) fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, clientHeaders) if err != nil { - log.Printf("Warning: failed to get fingerprint for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.gateway", "Warning: failed to get fingerprint for account %d: %v", account.ID, err) // 失败时降级为透传原始headers } else { fingerprint = fp @@ -3775,33 +4651,33 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } // Log for debugging - log.Printf("[SignatureCheck] Checking error message: %s", msg) + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Checking error message: %s", msg) // 检测signature相关的错误(更宽松的匹配) // 例如: "Invalid `signature` in `thinking` block", "***.signature" 等 if strings.Contains(msg, "signature") { - log.Printf("[SignatureCheck] Detected signature error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected signature error") return true } // 检测 thinking block 顺序/类型错误 // 例如: "Expected `thinking` or `redacted_thinking`, but found `text`" if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block type error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block type error") return 
true } // 检测 thinking block 被修改的错误 // 例如: "thinking or redacted_thinking blocks in the latest assistant message cannot be modified" if strings.Contains(msg, "cannot be modified") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) { - log.Printf("[SignatureCheck] Detected thinking block modification error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected thinking block modification error") return true } // 检测空消息内容错误(可能是过滤 thinking blocks 后导致的) // 例如: "all messages must have non-empty content" if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") { - log.Printf("[SignatureCheck] Detected empty content error") + logger.LegacyPrintf("service.gateway", "[SignatureCheck] Detected empty content error") return true } @@ -3862,7 +4738,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) // 调试日志:打印上游错误响应 - log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + logger.LegacyPrintf("service.gateway", "[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -3873,7 +4749,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -3913,7 +4789,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp 
*http.Res // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -4014,10 +4890,10 @@ func (s *GatewayService) handleRetryExhaustedSideEffects(ctx context.Context, re // OAuth/Setup Token 账号的 403:标记账号异常 if account.IsOAuth() && statusCode == 403 { s.rateLimitService.HandleUpstreamError(ctx, account, statusCode, resp.Header, body) - log.Printf("Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) + logger.LegacyPrintf("service.gateway", "Account %d: marked as error after %d retries for status %d", account.ID, maxRetryAttempts, statusCode) } else { // API Key 未配置错误码:不标记账号状态 - log.Printf("Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) + logger.LegacyPrintf("service.gateway", "Account %d: upstream error %d after %d retries (not marking account)", account.ID, statusCode, maxRetryAttempts) } } @@ -4043,7 +4919,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht if isClaudeCodeCredentialScopeError(upstreamMsg) && c != nil { if v, ok := c.Get(claudeMimicDebugInfoKey); ok { if line, ok := v.(string); ok && strings.TrimSpace(line) != "" { - log.Printf("[ClaudeMimicDebugOnError] status=%d request_id=%s %s", + logger.LegacyPrintf("service.gateway", "[ClaudeMimicDebugOnError] status=%d request_id=%s %s", resp.StatusCode, resp.Header.Get("x-request-id"), line, @@ -4072,7 +4948,7 @@ func (s *GatewayService) handleRetryExhaustedError(ctx context.Context, resp *ht }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "Upstream error %d retries_exhausted (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -4164,7 +5040,8 @@ func (s *GatewayService) handleStreamingResponse(ctx 
context.Context, resp *http if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -4183,7 +5060,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -4194,7 +5072,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval := time.Duration(0) @@ -4347,17 +5225,17 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if ev.err != nil { // 检测 context 取消(客户端断开会导致 context 取消,进而影响上游读取) if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") + logger.LegacyPrintf("service.gateway", "Context canceled during streaming, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端已通过写入失败检测到断开,上游也出错了,返回已收集的 usage if clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + logger.LegacyPrintf("service.gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } // 客户端未断开,正常的错误处理 if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + 
logger.LegacyPrintf("service.gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -4385,7 +5263,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if !clientDisconnected { if _, werr := fmt.Fprint(w, block); werr != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.gateway", "Client disconnected during streaming, continuing to drain upstream for billing") break } flusher.Flush() @@ -4410,10 +5288,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if clientDisconnected { // 客户端已断开,上游也超时了,返回已收集的 usage - log.Printf("Upstream timeout after client disconnect, returning collected usage") + logger.LegacyPrintf("service.gateway", "Upstream timeout after client disconnect, returning collected usage") return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) @@ -4477,8 +5355,10 @@ func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens") cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() || cc1h.Exists() { + if cc5m.Exists() && cc5m.Int() > 0 { usage.CacheCreation5mTokens = int(cc5m.Int()) + } + if cc1h.Exists() && cc1h.Int() > 0 { 
usage.CacheCreation1hTokens = int(cc1h.Int()) } } @@ -4540,8 +5420,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "type": "error", + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -4607,24 +5498,76 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h } // replaceModelInResponseBody 替换响应体中的model字段 +// 使用 gjson/sjson 精确替换,避免全量 JSON 反序列化 func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody + } + return body +} + +func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID, groupID int64, groupDefaultMultiplier float64) float64 { + if s == nil || userID <= 0 || groupID <= 0 { + return groupDefaultMultiplier } - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body + key := fmt.Sprintf("%d:%d", userID, groupID) + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier + } + } } + if s.userGroupRateRepo == nil { + return groupDefaultMultiplier + } + userGroupRateCacheMissTotal.Add(1) - 
resp["model"] = toModel - newBody, err := json.Marshal(resp) + value, err, shared := s.userGroupRateSF.Do(key, func() (any, error) { + if s.userGroupRateCache != nil { + if cached, ok := s.userGroupRateCache.Get(key); ok { + if multiplier, castOK := cached.(float64); castOK { + userGroupRateCacheHitTotal.Add(1) + return multiplier, nil + } + } + } + + userGroupRateCacheLoadTotal.Add(1) + userRate, repoErr := s.userGroupRateRepo.GetByUserAndGroup(ctx, userID, groupID) + if repoErr != nil { + return nil, repoErr + } + multiplier := groupDefaultMultiplier + if userRate != nil { + multiplier = *userRate + } + if s.userGroupRateCache != nil { + s.userGroupRateCache.Set(key, multiplier, resolveUserGroupRateCacheTTL(s.cfg)) + } + return multiplier, nil + }) + if shared { + userGroupRateCacheSFSharedTotal.Add(1) + } if err != nil { - return body + userGroupRateCacheFallbackTotal.Add(1) + logger.LegacyPrintf("service.gateway", "get user group rate failed, fallback to group default: user=%d group=%d err=%v", userID, groupID, err) + return groupDefaultMultiplier } - return newBody + multiplier, ok := value.(float64) + if !ok { + userGroupRateCacheFallbackTotal.Add(1) + return groupDefaultMultiplier + } + return multiplier } // RecordUsageInput 记录使用量的输入参数 @@ -4656,7 +5599,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 @@ -4670,22 +5613,36 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := 
s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier - - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown // 根据请求类型选择计费方式 - if result.ImageCount > 0 { + if result.MediaType == "image" || result.MediaType == "video" { + var soraConfig *SoraPriceConfig + if apiKey.Group != nil { + soraConfig = &SoraPriceConfig{ + ImagePrice360: apiKey.Group.SoraImagePrice360, + ImagePrice540: apiKey.Group.SoraImagePrice540, + VideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest, + VideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD, + } + } + if result.MediaType == "image" { + cost = s.billingService.CalculateSoraImageCost(result.ImageSize, result.ImageCount, soraConfig, multiplier) + } else { + cost = s.billingService.CalculateSoraVideoCost(result.Model, soraConfig, multiplier) + } + } else if result.MediaType == "prompt" { + cost = &CostBreakdown{} + } else if result.ImageCount > 0 { // 图片生成计费 var groupConfig *ImagePriceConfig if apiKey.Group != nil { @@ -4709,7 +5666,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu var err error cost, err = s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4727,6 +5684,10 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu if result.ImageSize != "" { imageSize = &result.ImageSize } + var mediaType 
*string + if strings.TrimSpace(result.MediaType) != "" { + mediaType = &result.MediaType + } accountRateMultiplier := account.BillingRateMultiplier() usageLog := &UsageLog{ UserID: user.ID, @@ -4754,6 +5715,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu FirstTokenMs: result.FirstTokenMs, ImageCount: result.ImageCount, ImageSize: imageSize, + MediaType: mediaType, CacheTTLOverridden: cacheTTLOverridden, CreatedAt: time.Now(), } @@ -4778,11 +5740,11 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4794,7 +5756,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) @@ -4803,7 +5765,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", 
err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: %v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) @@ -4813,7 +5775,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu // 更新 API Key 配额(如果设置了配额限制) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.gateway", "Update API key quota failed: %v", err) } } @@ -4849,7 +5811,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 强制缓存计费:将 input_tokens 转为 cache_read_input_tokens // 用于粘性会话切换时的特殊计费处理 if input.ForceCacheBilling && result.Usage.InputTokens > 0 { - log.Printf("force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", + logger.LegacyPrintf("service.gateway", "force_cache_billing: %d input_tokens → cache_read_input_tokens (account=%d)", result.Usage.InputTokens, account.ID) result.Usage.CacheReadInputTokens += result.Usage.InputTokens result.Usage.InputTokens = 0 @@ -4863,16 +5825,13 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * } // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) - multiplier := s.cfg.Default.RateMultiplier + multiplier := 1.0 + if s.cfg != nil { + multiplier = s.cfg.Default.RateMultiplier + } if apiKey.GroupID != nil && apiKey.Group != nil { - multiplier = apiKey.Group.RateMultiplier - - // 检查用户专属倍率 - if s.userGroupRateRepo != nil { - if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { - multiplier = *userRate - } - } + groupDefault := apiKey.Group.RateMultiplier + multiplier = s.getUserGroupRateMultiplier(ctx, user.ID, *apiKey.GroupID, groupDefault) } var cost *CostBreakdown @@ -4902,7 +5861,7 @@ func (s *GatewayService) 
RecordUsageWithLongContext(ctx context.Context, input * var err error cost, err = s.billingService.CalculateCostWithLongContext(result.Model, tokens, multiplier, input.LongContextThreshold, input.LongContextMultiplier) if err != nil { - log.Printf("Calculate cost failed: %v", err) + logger.LegacyPrintf("service.gateway", "Calculate cost failed: %v", err) cost = &CostBreakdown{ActualCost: 0} } } @@ -4971,11 +5930,11 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - log.Printf("Create usage log failed: %v", err) + logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -4987,7 +5946,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率) if shouldBill && cost.TotalCost > 0 { if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil { - log.Printf("Increment subscription usage failed: %v", err) + logger.LegacyPrintf("service.gateway", "Increment subscription usage failed: %v", err) } // 异步更新订阅缓存 s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) @@ -4996,14 +5955,14 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * // 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用) if shouldBill && cost.ActualCost > 0 { if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil { - log.Printf("Deduct balance failed: %v", err) + logger.LegacyPrintf("service.gateway", "Deduct balance failed: 
%v", err) } // 异步更新余额缓存 s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) // API Key 独立配额扣费 if input.APIKeyService != nil && apiKey.Quota > 0 { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Add API key quota used failed: %v", err) + logger.LegacyPrintf("service.gateway", "Add API key quota used failed: %v", err) } } } @@ -5023,6 +5982,10 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return fmt.Errorf("parse request: empty request") } + if account != nil && account.IsAnthropicAPIKeyPassthroughEnabled() { + return s.forwardCountTokensAnthropicAPIKeyPassthrough(ctx, c, account, parsed.Body) + } + body := parsed.Body reqModel := parsed.Model @@ -5062,7 +6025,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, if mappedModel != reqModel { body = s.replaceModelInBody(body, mappedModel) reqModel = mappedModel - log.Printf("CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) + logger.LegacyPrintf("service.gateway", "CountTokens model mapping applied: %s -> %s (account: %s, source=%s)", parsed.Model, mappedModel, account.Name, mappingSource) } } @@ -5095,16 +6058,22 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 读取响应体 - respBody, err := io.ReadAll(resp.Body) + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } // 检测 thinking block 签名错误(400)并重试一次(过滤 
thinking blocks) if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) { - log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) + logger.LegacyPrintf("service.gateway", "Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID) filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) @@ -5112,9 +6081,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp - respBody, err = io.ReadAll(resp.Body) + respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) _ = resp.Body.Close() if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") return err } @@ -5141,7 +6115,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, // 记录上游错误摘要便于排障(不回显请求内容) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.gateway", "count_tokens upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -5171,6 +6145,158 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, return nil } +func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx context.Context, c *gin.Context, account *Account, body []byte) error { + token, tokenType, err := 
s.GetAccessToken(ctx, account) + if err != nil { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to get access token") + return err + } + if tokenType != "apikey" { + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Invalid account token type") + return fmt.Errorf("anthropic api key passthrough requires apikey token, got: %s", tokenType) + } + + upstreamReq, err := s.buildCountTokensRequestAnthropicAPIKeyPassthrough(ctx, c, account, body, token) + if err != nil { + s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request") + return err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) + if err != nil { + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: sanitizeUpstreamErrorMessage(err.Error()), + }) + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") + return fmt.Errorf("upstream request failed: %w", err) + } + + maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes) + _ = resp.Body.Close() + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large") + return err + } + s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response") + return err + } + + if resp.StatusCode >= 400 { + if s.rateLimitService != nil { + 
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody) + } + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + errMsg := "Upstream request failed" + switch resp.StatusCode { + case 429: + errMsg = "Rate limit exceeded" + case 529: + errMsg = "Service overloaded" + } + s.countTokensError(c, resp.StatusCode, "upstream_error", errMsg) + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) + } + + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, respBody) + return nil +} + +func (s *GatewayService) buildCountTokensRequestAnthropicAPIKeyPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := claudeAPICountTokensURL + baseURL := account.GetBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + 
"/v1/messages/count_tokens" + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if !allowedHeaders[lowerKey] { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Del("cookie") + req.Header.Set("x-api-key", token) + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + if req.Header.Get("anthropic-version") == "" { + req.Header.Set("anthropic-version", "2023-06-01") + } + + return req, nil +} + // buildCountTokensRequest 构建 count_tokens 上游请求 func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string, mimicClaudeCode bool) (*http.Request, error) { // 确定目标 URL @@ -5319,6 +6445,17 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) { // GetAvailableModels returns the list of models available for a group // It aggregates model_mapping keys from all schedulable accounts in the group func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string { + cacheKey := modelsListCacheKey(groupID, platform) + if s.modelsListCache != nil { + if cached, found := s.modelsListCache.Get(cacheKey); found { + if models, ok := cached.([]string); ok { + modelsListCacheHitTotal.Add(1) + return cloneStringSlice(models) + } + } + } + modelsListCacheMissTotal.Add(1) + var accounts []Account var err error @@ -5359,6 +6496,10 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, // If no account has model_mapping, return nil (use default) if !hasAnyMapping { + if s.modelsListCache != nil { 
+ s.modelsListCache.Set(cacheKey, []string(nil), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } return nil } @@ -5367,8 +6508,45 @@ func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, for model := range modelSet { models = append(models, model) } + sort.Strings(models) - return models + if s.modelsListCache != nil { + s.modelsListCache.Set(cacheKey, cloneStringSlice(models), s.modelsListCacheTTL) + modelsListCacheStoreTotal.Add(1) + } + return cloneStringSlice(models) +} + +func (s *GatewayService) InvalidateAvailableModelsCache(groupID *int64, platform string) { + if s == nil || s.modelsListCache == nil { + return + } + + normalizedPlatform := strings.TrimSpace(platform) + // 完整匹配时精准失效;否则按维度批量失效。 + if groupID != nil && normalizedPlatform != "" { + s.modelsListCache.Delete(modelsListCacheKey(groupID, normalizedPlatform)) + return + } + + targetGroup := derefGroupID(groupID) + for key := range s.modelsListCache.Items() { + parts := strings.SplitN(key, "|", 2) + if len(parts) != 2 { + continue + } + groupPart, parseErr := strconv.ParseInt(parts[0], 10, 64) + if parseErr != nil { + continue + } + if groupID != nil && groupPart != targetGroup { + continue + } + if normalizedPlatform != "" && parts[1] != normalizedPlatform { + continue + } + s.modelsListCache.Delete(key) + } } // reconcileCachedTokens 兼容 Kimi 等上游: diff --git a/backend/internal/service/gateway_service_streaming_test.go b/backend/internal/service/gateway_service_streaming_test.go new file mode 100644 index 00000000..c8803d39 --- /dev/null +++ b/backend/internal/service/gateway_service_streaming_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestGatewayService_StreamingReusesScannerBufferAndStillParsesUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg 
:= &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + + svc := &GatewayService{ + cfg: cfg, + rateLimitService: &RateLimitService{}, + } + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // Minimal SSE event to trigger parseSSEUsage + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":3}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":7}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 3, result.usage.InputTokens) + require.Equal(t, 7, result.usage.OutputTokens) +} diff --git a/backend/internal/service/gateway_streaming_test.go b/backend/internal/service/gateway_streaming_test.go new file mode 100644 index 00000000..cd690cbd --- /dev/null +++ b/backend/internal/service/gateway_streaming_test.go @@ -0,0 +1,219 @@ +//go:build unit + +package service + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +// --- parseSSEUsage 测试 --- + +func newMinimalGatewayService() *GatewayService { + return &GatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + MaxLineSize: defaultMaxLineSize, + }, + }, + rateLimitService: &RateLimitService{}, + } +} + +func TestParseSSEUsage_MessageStart(t 
*testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation_input_tokens":50,"cache_read_input_tokens":200}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.CacheCreationInputTokens) + require.Equal(t, 200, usage.CacheReadInputTokens) + require.Equal(t, 0, usage.OutputTokens, "message_start 不应设置 output_tokens") +} + +func TestParseSSEUsage_MessageDelta(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + data := `{"type":"message_delta","usage":{"output_tokens":42}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 42, usage.OutputTokens) + require.Equal(t, 0, usage.InputTokens, "message_delta 的 output_tokens 不应影响已有的 input_tokens") +} + +func TestParseSSEUsage_DeltaDoesNotOverwriteStartValues(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先处理 message_start + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100}}}`, usage) + require.Equal(t, 100, usage.InputTokens) + + // 再处理 message_delta(output_tokens > 0, input_tokens = 0) + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":50}}`, usage) + require.Equal(t, 100, usage.InputTokens, "delta 中 input_tokens=0 不应覆盖 start 中的值") + require.Equal(t, 50, usage.OutputTokens) +} + +func TestParseSSEUsage_DeltaOverwritesWithNonZero(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // GLM 等 API 会在 delta 中包含所有 usage 信息 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"input_tokens":200,"output_tokens":100,"cache_creation_input_tokens":30,"cache_read_input_tokens":60}}`, usage) + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 100, usage.OutputTokens) + require.Equal(t, 30, usage.CacheCreationInputTokens) + require.Equal(t, 60, usage.CacheReadInputTokens) +} + +func 
TestParseSSEUsage_DeltaDoesNotResetCacheCreationBreakdown(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 先在 message_start 中写入非零 5m/1h 明细 + svc.parseSSEUsage(`{"type":"message_start","message":{"usage":{"input_tokens":100,"cache_creation":{"ephemeral_5m_input_tokens":30,"ephemeral_1h_input_tokens":70}}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens) + require.Equal(t, 70, usage.CacheCreation1hTokens) + + // 后续 delta 带默认 0,不应覆盖已有非零值 + svc.parseSSEUsage(`{"type":"message_delta","usage":{"output_tokens":12,"cache_creation":{"ephemeral_5m_input_tokens":0,"ephemeral_1h_input_tokens":0}}}`, usage) + require.Equal(t, 30, usage.CacheCreation5mTokens, "delta 的 0 值不应重置 5m 明细") + require.Equal(t, 70, usage.CacheCreation1hTokens, "delta 的 0 值不应重置 1h 明细") + require.Equal(t, 12, usage.OutputTokens) +} + +func TestParseSSEUsage_InvalidJSON(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 无效 JSON 不应 panic + svc.parseSSEUsage("not json", usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_UnknownType(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // 不是 message_start 或 message_delta 的类型 + svc.parseSSEUsage(`{"type":"content_block_delta","delta":{"text":"hello"}}`, usage) + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) +} + +func TestParseSSEUsage_EmptyString(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + svc.parseSSEUsage("", usage) + require.Equal(t, 0, usage.InputTokens) +} + +func TestParseSSEUsage_DoneEvent(t *testing.T) { + svc := newMinimalGatewayService() + usage := &ClaudeUsage{} + + // [DONE] 事件不应影响 usage + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 0, usage.InputTokens) +} + +// --- 流式响应端到端测试 --- + +func TestHandleStreamingResponse_CacheTokens(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := 
newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":10,\"cache_creation_input_tokens\":20,\"cache_read_input_tokens\":30}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":15}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 10, result.usage.InputTokens) + require.Equal(t, 15, result.usage.OutputTokens) + require.Equal(t, 20, result.usage.CacheCreationInputTokens) + require.Equal(t, 30, result.usage.CacheReadInputTokens) +} + +func TestHandleStreamingResponse_EmptyStream(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + // 直接关闭,不发送任何事件 + _ = pw.Close() + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) +} + +func TestHandleStreamingResponse_SpecialCharactersInJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + svc := newMinimalGatewayService() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + pr, pw := io.Pipe() + resp := &http.Response{StatusCode: http.StatusOK, Header: http.Header{}, Body: pr} + + go func() { + defer func() { _ = pw.Close() }() + // 包含特殊字符的 content_block_delta(引号、换行、Unicode) + _, _ = pw.Write([]byte("data: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello \\\"world\\\"\\n你好\"}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_start\",\"message\":{\"usage\":{\"input_tokens\":5}}}\n\n")) + _, _ = pw.Write([]byte("data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":3}}\n\n")) + _, _ = pw.Write([]byte("data: [DONE]\n\n")) + }() + + result, err := svc.handleStreamingResponse(context.Background(), resp, c, &Account{ID: 1}, time.Now(), "model", "model", false) + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 5, result.usage.InputTokens) + require.Equal(t, 3, result.usage.OutputTokens) + + // 验证响应中包含转发的数据 + body := rec.Body.String() + require.Contains(t, body, "content_block_delta", "响应应包含转发的 SSE 事件") +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go new file mode 100644 index 00000000..0ed95c87 --- /dev/null +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -0,0 +1,120 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDecrementWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + // 不应 panic + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + // DecrementWaitCount 使用 background 
context,错误只记录日志不传播 + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_NilCache 确保 nil cache 不会 panic +func TestDecrementAccountWaitCount_NilCache(t *testing.T) { + svc := &ConcurrencyService{cache: nil} + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestDecrementAccountWaitCount_CacheError 确保 cache 错误不会传播 +func TestDecrementAccountWaitCount_CacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{} + svc := NewConcurrencyService(cache) + svc.DecrementAccountWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_IncrementThenDecrement 测试完整的等待队列增减流程 +func TestWaitingQueueFlow_IncrementThenDecrement(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入等待队列 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.True(t, allowed) + + // 离开等待队列(不应 panic) + svc.DecrementWaitCount(context.Background(), 1) +} + +// TestWaitingQueueFlow_AccountLevel 测试账号级等待队列流程 +func TestWaitingQueueFlow_AccountLevel(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitAllowed: true} + svc := NewConcurrencyService(cache) + + // 进入账号等待队列 + allowed, err := svc.IncrementAccountWaitCount(context.Background(), 42, 10) + require.NoError(t, err) + require.True(t, allowed) + + // 离开账号等待队列 + svc.DecrementAccountWaitCount(context.Background(), 42) +} + +// TestWaitingQueueFull_Returns429Signal 测试等待队列满时返回 false +func TestWaitingQueueFull_Returns429Signal(t *testing.T) { + // waitAllowed=false 模拟队列已满 + cache := &stubConcurrencyCacheForTest{waitAllowed: false} + svc := NewConcurrencyService(cache) + + // 用户级等待队列满 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err) + require.False(t, allowed, "等待队列满时应返回 false(调用方根据此返回 429)") + + // 账号级等待队列满 + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err) + require.False(t, 
allowed, "账号等待队列满时应返回 false") +} + +// TestWaitingQueue_FailOpen_OnCacheError 测试 Redis 故障时 fail-open +func TestWaitingQueue_FailOpen_OnCacheError(t *testing.T) { + cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis connection refused")} + svc := NewConcurrencyService(cache) + + // 用户级:Redis 错误时允许通过 + allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") + + // 账号级:同样 fail-open + allowed, err = svc.IncrementAccountWaitCount(context.Background(), 1, 10) + require.NoError(t, err, "Redis 错误不应向调用方传播") + require.True(t, allowed, "Redis 故障时应 fail-open 放行") +} + +// TestCalculateMaxWait_Scenarios 测试最大等待队列大小计算 +func TestCalculateMaxWait_Scenarios(t *testing.T) { + tests := []struct { + concurrency int + expected int + }{ + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 + {100, 120}, // 100 + 20 + } + for _, tt := range tests { + result := CalculateMaxWait(tt.concurrency) + require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency) + } +} diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index f3abd1dc..8670f99a 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -22,10 +22,12 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" ) const geminiStickySessionTTL = time.Hour @@ -281,7 +283,7 @@ func (s *GeminiMessagesCompatService) 
passesRateLimitPreCheck(ctx context.Contex } ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) if err != nil { - log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) } return ok } @@ -697,7 +699,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -753,7 +755,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex } retryGeminiReq, txErr := convertClaudeMessagesToGeminiGenerateContent(strippedClaudeBody) if txErr == nil { - log.Printf("Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: detected signature-related 400, retrying with downgraded Claude blocks (%s)", account.ID, stageName) geminiReq = retryGeminiReq // Consume one retry budget attempt and continue with the updated request payload. 
sleepGeminiBackoff(1) @@ -820,7 +822,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -968,7 +970,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream") } - claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel) + collectedBytes, _ := json.Marshal(collected) + claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes) c.JSON(http.StatusOK, claudeResp) usage = usageObj2 if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) { @@ -1195,7 +1198,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: safeErr, }) if attempt < geminiMaxRetries { - log.Printf("Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream request failed, retry %d/%d: %v", account.ID, attempt, geminiMaxRetries, err) sleepGeminiBackoff(attempt) continue } @@ -1264,7 +1267,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. 
Detail: upstreamDetail, }) - log.Printf("Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) + logger.LegacyPrintf("service.gemini_messages_compat", "Gemini account %d: upstream status %d, retry %d/%d", account.ID, resp.StatusCode, attempt, geminiMaxRetries) sleepGeminiBackoff(attempt) continue } @@ -1424,7 +1427,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. maxBytes = 2048 } upstreamDetail = truncateString(string(respBody), maxBytes) - log.Printf("[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] native upstream error %d: %s", resp.StatusCode, truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ @@ -1601,7 +1604,7 @@ func (s *GeminiMessagesCompatService) writeGeminiMappedError(c *gin.Context, acc }) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf("[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini] upstream error %d: %s", upstreamStatus, truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes)) } if status, errType, errMsg, matched := applyErrorPassthroughRule( @@ -1821,12 +1824,17 @@ func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context, return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response") } - geminiResp, err := unwrapGeminiResponse(body) + unwrappedBody, err := unwrapGeminiResponse(body) if err != nil { return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") } - claudeResp, usage := 
convertGeminiToClaudeMessage(geminiResp, originalModel) + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response") + } + + claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody) c.JSON(http.StatusOK, claudeResp) return usage, nil @@ -1899,11 +1907,16 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re continue } - geminiResp, err := unwrapGeminiResponse([]byte(payload)) + unwrappedBytes, err := unwrapGeminiResponse([]byte(payload)) if err != nil { continue } + var geminiResp map[string]any + if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil { + continue + } + if fr := extractGeminiFinishReason(geminiResp); fr != "" { finishReason = fr } @@ -2030,7 +2043,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re } } - if u := extractGeminiUsage(geminiResp); u != nil { + if u := extractGeminiUsage(unwrappedBytes); u != nil { usage = *u } @@ -2121,11 +2134,7 @@ func unwrapIfNeeded(isOAuth bool, raw []byte) []byte { if err != nil { return raw } - b, err := json.Marshal(inner) - if err != nil { - return raw - } - return b + return inner } func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) { @@ -2149,17 +2158,20 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag } default: var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawBytes = innerBytes + _ = json.Unmarshal(innerBytes, &parsed) } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) + _ = json.Unmarshal(rawBytes, &parsed) } if parsed != nil { last = parsed - 
if u := extractGeminiUsage(parsed); u != nil { + if u := extractGeminiUsage(rawBytes); u != nil { usage = u } if parts := extractGeminiParts(parsed); len(parts) > 0 { @@ -2288,53 +2300,27 @@ func isGeminiInsufficientScope(headers http.Header, body []byte) bool { } func estimateGeminiCountTokens(reqBody []byte) int { - var obj map[string]any - if err := json.Unmarshal(reqBody, &obj); err != nil { - return 0 - } - - var texts []string + total := 0 // systemInstruction.parts[].text - if si, ok := obj["systemInstruction"].(map[string]any); ok { - if parts, ok := si["parts"].([]any); ok { - for _, p := range parts { - if pm, ok := p.(map[string]any); ok { - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } + gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - } + return true + }) // contents[].parts[].text - if contents, ok := obj["contents"].([]any); ok { - for _, c := range contents { - cm, ok := c.(map[string]any) - if !ok { - continue + gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool { + content.Get("parts").ForEach(func(_, part gjson.Result) bool { + if t := strings.TrimSpace(part.Get("text").String()); t != "" { + total += estimateTokensForText(t) } - parts, ok := cm["parts"].([]any) - if !ok { - continue - } - for _, p := range parts { - pm, ok := p.(map[string]any) - if !ok { - continue - } - if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" { - texts = append(texts, t) - } - } - } - } + return true + }) + return true + }) - total := 0 - for _, t := range texts { - total += estimateTokensForText(t) - } if total < 0 { return 0 } @@ -2372,28 +2358,36 @@ type UpstreamHTTPResult struct { } func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) 
(*ClaudeUsage, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================") } - log.Printf("[GeminiAPI] ========================================") - respBody, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } - var parsed map[string]any if isOAuth { - parsed, err = unwrapGeminiResponse(respBody) - if err == nil && parsed != nil { - respBody, _ = json.Marshal(parsed) + unwrappedBody, uwErr := unwrapGeminiResponse(respBody) + if uwErr == nil { + respBody = unwrappedBody } - } else { - _ = json.Unmarshal(respBody, &parsed) } responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -2404,23 +2398,22 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } c.Data(resp.StatusCode, contentType, respBody) - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - return u, nil - } + if u := 
extractGeminiUsage(respBody); u != nil { + return u, nil } return &ClaudeUsage{}, nil } func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) { - // Log response headers for debugging - log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========") - for key, values := range resp.Header { - if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { - log.Printf("[GeminiAPI] %s: %v", key, values) + if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========") + for key, values := range resp.Header { + if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") { + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values) + } } + logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - log.Printf("[GeminiAPI] ====================================================") if s.cfg != nil { responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) @@ -2460,23 +2453,19 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte var rawToWrite string rawToWrite = payload - var parsed map[string]any + var rawBytes []byte if isOAuth { - inner, err := unwrapGeminiResponse([]byte(payload)) - if err == nil && inner != nil { - parsed = inner - if b, err := json.Marshal(inner); err == nil { - rawToWrite = string(b) - } + innerBytes, err := unwrapGeminiResponse([]byte(payload)) + if err == nil { + rawToWrite = string(innerBytes) + rawBytes = innerBytes } } else { - _ = json.Unmarshal([]byte(payload), &parsed) + rawBytes = []byte(payload) } - if parsed != nil { - if u := extractGeminiUsage(parsed); u != nil { - usage = u - } + if u := extractGeminiUsage(rawBytes); u != nil { + 
usage = u } if firstTokenMs == nil { @@ -2579,19 +2568,18 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac }, nil } -func unwrapGeminiResponse(raw []byte) (map[string]any, error) { - var outer map[string]any - if err := json.Unmarshal(raw, &outer); err != nil { - return nil, err +// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段 +// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal +func unwrapGeminiResponse(raw []byte) ([]byte, error) { + result := gjson.GetBytes(raw, "response") + if result.Exists() && result.Type == gjson.JSON { + return []byte(result.Raw), nil } - if resp, ok := outer["response"].(map[string]any); ok && resp != nil { - return resp, nil - } - return outer, nil + return raw, nil } -func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) { - usage := extractGeminiUsage(geminiResp) +func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) { + usage := extractGeminiUsage(rawData) if usage == nil { usage = &ClaudeUsage{} } @@ -2655,15 +2643,15 @@ func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel strin return resp, usage } -func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage { - usageMeta, ok := geminiResp["usageMetadata"].(map[string]any) - if !ok || usageMeta == nil { +func extractGeminiUsage(data []byte) *ClaudeUsage { + usage := gjson.GetBytes(data, "usageMetadata") + if !usage.Exists() { return nil } - prompt, _ := asInt(usageMeta["promptTokenCount"]) - cand, _ := asInt(usageMeta["candidatesTokenCount"]) - cached, _ := asInt(usageMeta["cachedContentTokenCount"]) - thoughts, _ := asInt(usageMeta["thoughtsTokenCount"]) + prompt := int(usage.Get("promptTokenCount").Int()) + cand := int(usage.Get("candidatesTokenCount").Int()) + cached := int(usage.Get("cachedContentTokenCount").Int()) + thoughts := int(usage.Get("thoughtsTokenCount").Int()) // 注意:Gemini 的 
promptTokenCount 包含 cachedContentTokenCount, // 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去 return &ClaudeUsage{ @@ -2721,16 +2709,16 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont cooldown = s.rateLimitService.GeminiCooldown(ctx, account) } ra = time.Now().Add(cooldown) - log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second)) } else { // API Key / AI Studio OAuth: PST 午夜 if ts := nextGeminiDailyResetUnix(); ts != nil { ra = time.Unix(*ts, 0) - log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra) } else { // 兜底:5 分钟 ra = time.Now().Add(5 * time.Minute) - log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited, fallback to 5min", account.ID) } } _ = s.accountRepo.SetRateLimited(ctx, account.ID, ra) @@ -2740,45 +2728,41 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont // 使用解析到的重置时间 resetTime := time.Unix(*resetAt, 0) _ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime) - log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)", account.ID, resetTime, oauthType, tierID) } // ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 
响应,返回重置时间的 Unix 时间戳 func ParseGeminiRateLimitResetTime(body []byte) *int64 { - // Try to parse metadata.quotaResetDelay like "12.345s" - var parsed map[string]any - if err := json.Unmarshal(body, &parsed); err == nil { - if errObj, ok := parsed["error"].(map[string]any); ok { - if msg, ok := errObj["message"].(string); ok { - if looksLikeGeminiDailyQuota(msg) { - if ts := nextGeminiDailyResetUnix(); ts != nil { - return ts - } - } - } - if details, ok := errObj["details"].([]any); ok { - for _, d := range details { - dm, ok := d.(map[string]any) - if !ok { - continue - } - if meta, ok := dm["metadata"].(map[string]any); ok { - if v, ok := meta["quotaResetDelay"].(string); ok { - if dur, err := time.ParseDuration(v); err == nil { - // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), - // which can affect scheduling decisions around thresholds (like 10s). - ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) - return &ts - } - } - } - } - } + // 第一阶段:gjson 结构化提取 + errMsg := gjson.GetBytes(body, "error.message").String() + if looksLikeGeminiDailyQuota(errMsg) { + if ts := nextGeminiDailyResetUnix(); ts != nil { + return ts } } - // Match "Please retry in Xs" + // 遍历 error.details 查找 quotaResetDelay + var found *int64 + gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool { + v := detail.Get("metadata.quotaResetDelay").String() + if v == "" { + return true + } + if dur, err := time.ParseDuration(v); err == nil { + // Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s), + // which can affect scheduling decisions around thresholds (like 10s). 
+ ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds())) + found = &ts + return false + } + return true + }) + if found != nil { + return found + } + + // 第二阶段:regex 回退匹配 "Please retry in Xs" matches := retryInRegex.FindStringSubmatch(string(body)) if len(matches) == 2 { if dur, err := time.ParseDuration(matches[1] + "s"); err == nil { diff --git a/backend/internal/service/gemini_messages_compat_service_test.go b/backend/internal/service/gemini_messages_compat_service_test.go index 5bc26973..7560f480 100644 --- a/backend/internal/service/gemini_messages_compat_service_test.go +++ b/backend/internal/service/gemini_messages_compat_service_test.go @@ -2,9 +2,16 @@ package service import ( "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" "strings" "testing" + "time" + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -131,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) { } } +func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + svc := &GeminiMessagesCompatService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + GeminiDebugResponseHeaders: false, + }, + }, + } + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "X-RateLimit-Limit": []string{"60"}, + }, + Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)), + } + + usage, err := svc.handleNativeNonStreamingResponse(c, resp, false) + require.NoError(t, err) + require.NotNil(t, usage) + require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志") +} + func 
TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) { claudeReq := map[string]any{ "model": "claude-haiku-4-5-20251001", @@ -206,69 +245,323 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing } } -func TestExtractGeminiUsage_ThoughtsTokenCount(t *testing.T) { +// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景 +// 关键区别:只有 response 为 JSON 对象/数组时才解包 +func TestUnwrapGeminiResponse(t *testing.T) { + // 构造 >50KB 的大型 JSON 对象 + largePadding := strings.Repeat("x", 50*1024) + largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding)) + largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding) + tests := []struct { - name string - resp map[string]any - wantInput int - wantOutput int - wantCacheRead int - wantNil bool + name string + input []byte + expected string + wantErr bool }{ { - name: "with thoughtsTokenCount", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - "thoughtsTokenCount": float64(50), - }, - }, - wantInput: 100, - wantOutput: 70, + name: "正常 response 包装(JSON 对象)", + input: []byte(`{"response":{"key":"val"}}`), + expected: `{"key":"val"}`, }, { - name: "with thoughtsTokenCount and cache", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - "cachedContentTokenCount": float64(30), - "thoughtsTokenCount": float64(50), - }, - }, - wantInput: 70, - wantOutput: 70, - wantCacheRead: 30, + name: "无包装直接返回", + input: []byte(`{"key":"val"}`), + expected: `{"key":"val"}`, }, { - name: "without thoughtsTokenCount (old model)", - resp: map[string]any{ - "usageMetadata": map[string]any{ - "promptTokenCount": float64(100), - "candidatesTokenCount": float64(20), - }, - }, - wantInput: 100, - wantOutput: 20, + name: "空 JSON", + input: []byte(`{}`), + expected: `{}`, }, { - name: "no 
usageMetadata", - resp: map[string]any{}, - wantNil: true, + name: "null response 返回原始 body", + input: []byte(`{"response":null}`), + expected: `{"response":null}`, + }, + { + name: "非法 JSON 返回原始 body", + input: []byte(`not json`), + expected: `not json`, + }, + { + name: "response 为基础类型 string 返回原始 body", + input: []byte(`{"response":"hello"}`), + expected: `{"response":"hello"}`, + }, + { + name: "嵌套 response 只解一层", + input: []byte(`{"response":{"response":{"inner":true}}}`), + expected: `{"response":{"inner":true}}`, + }, + { + name: "大型 JSON >50KB", + input: largeInput, + expected: largeExpected, }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - usage := extractGeminiUsage(tt.resp) - if tt.wantNil { - require.Nil(t, usage) + got, err := unwrapGeminiResponse(tt.input) + if tt.wantErr { + require.Error(t, err) return } - require.NotNil(t, usage) - require.Equal(t, tt.wantInput, usage.InputTokens) - require.Equal(t, tt.wantOutput, usage.OutputTokens) - require.Equal(t, tt.wantCacheRead, usage.CacheReadInputTokens) + require.NoError(t, err) + require.Equal(t, tt.expected, strings.TrimSpace(string(got))) + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.1 — extractGeminiUsage 测试 +// --------------------------------------------------------------------------- + +func TestExtractGeminiUsage(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + wantUsage *ClaudeUsage + }{ + { + name: "完整 usageMetadata", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 80, + OutputTokens: 50, + CacheReadInputTokens: 20, + }, + }, + { + name: "包含 thoughtsTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 70, + 
CacheReadInputTokens: 0, + }, + }, + { + name: "包含 thoughtsTokenCount 与缓存", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":20,"cachedContentTokenCount":30,"thoughtsTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 70, + OutputTokens: 70, + CacheReadInputTokens: 30, + }, + }, + { + name: "缺失 cachedContentTokenCount", + input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 100, + OutputTokens: 50, + CacheReadInputTokens: 0, + }, + }, + { + name: "无 usageMetadata", + input: `{"candidates":[]}`, + wantNil: true, + }, + { + // gjson 对 null 返回 Exists()=true,因此函数不会返回 nil, + // 而是返回全零的 ClaudeUsage。 + name: "null usageMetadata — gjson Exists 为 true", + input: `{"usageMetadata":null}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + { + name: "零值字段", + input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`, + wantNil: false, + wantUsage: &ClaudeUsage{ + InputTokens: 0, + OutputTokens: 0, + CacheReadInputTokens: 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractGeminiUsage([]byte(tt.input)) + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %+v", got) + } + return + } + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + if got.InputTokens != tt.wantUsage.InputTokens { + t.Errorf("InputTokens: 期望 %d,实际 %d", tt.wantUsage.InputTokens, got.InputTokens) + } + if got.OutputTokens != tt.wantUsage.OutputTokens { + t.Errorf("OutputTokens: 期望 %d,实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens) + } + if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens { + t.Errorf("CacheReadInputTokens: 期望 %d,实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens) + } + }) + } +} + +// 
--------------------------------------------------------------------------- +// Task 8.2 — estimateGeminiCountTokens 测试 +// --------------------------------------------------------------------------- + +func TestEstimateGeminiCountTokens(t *testing.T) { + tests := []struct { + name string + input string + wantGt0 bool // 期望结果 > 0 + wantExact *int // 如果非 nil,期望精确匹配 + }{ + { + name: "含 systemInstruction 和 contents", + input: `{ + "systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]}, + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "仅 contents,无 systemInstruction", + input: `{ + "contents":[{"parts":[{"text":"Hello, how are you?"}]}] + }`, + wantGt0: true, + }, + { + name: "空 parts", + input: `{"contents":[{"parts":[]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "非文本 parts(inlineData)", + input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + { + name: "空白文本", + input: `{"contents":[{"parts":[{"text":" "}]}]}`, + wantGt0: false, + wantExact: intPtr(0), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := estimateGeminiCountTokens([]byte(tt.input)) + if tt.wantExact != nil { + if got != *tt.wantExact { + t.Errorf("期望精确值 %d,实际 %d", *tt.wantExact, got) + } + return + } + if tt.wantGt0 && got <= 0 { + t.Errorf("期望返回 > 0,实际 %d", got) + } + if !tt.wantGt0 && got != 0 { + t.Errorf("期望返回 0,实际 %d", got) + } + }) + } +} + +// --------------------------------------------------------------------------- +// Task 8.3 — ParseGeminiRateLimitResetTime 测试 +// --------------------------------------------------------------------------- + +func TestParseGeminiRateLimitResetTime(t *testing.T) { + tests := []struct { + name string + input string + wantNil bool + approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒 + }{ + { + name: "正常 quotaResetDelay", + input: 
`{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`, + wantNil: false, + approxDelta: 13, // 向上取整 12.345 -> 13 + }, + { + name: "daily quota", + input: `{"error":{"message":"quota per day exceeded"}}`, + wantNil: false, + approxDelta: -1, // 不检查精确 delta,仅检查非 nil + }, + { + name: "无 details 且无 regex 匹配", + input: `{"error":{"message":"rate limit"}}`, + wantNil: true, + }, + { + name: "regex 回退匹配", + input: `Please retry in 30s`, + wantNil: false, + approxDelta: 30, + }, + { + name: "完全无匹配", + input: `{"error":{"code":429}}`, + wantNil: true, + }, + { + name: "非法 JSON 但 regex 回退仍工作", + input: `not json but Please retry in 10s`, + wantNil: false, + approxDelta: 10, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + now := time.Now().Unix() + got := ParseGeminiRateLimitResetTime([]byte(tt.input)) + + if tt.wantNil { + if got != nil { + t.Fatalf("期望返回 nil,实际返回 %d", *got) + } + return + } + + if got == nil { + t.Fatalf("期望返回非 nil,实际返回 nil") + } + + // approxDelta == -1 表示只检查非 nil,不检查具体值(如 daily quota 场景) + if tt.approxDelta == -1 { + // 仅验证返回的时间戳在合理范围内(未来的某个时间) + if *got < now { + t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got) + } + return + } + + // 使用 +/-2 秒容差进行范围检查 + delta := *got - now + if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 { + t.Errorf("期望 delta 约为 %d 秒(+/-2),实际 delta = %d 秒(返回值=%d, now=%d)", + tt.approxDelta, delta, *got, now) + } }) } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 6b1fcecc..86bc9476 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -66,6 +66,11 @@ func (m *mockAccountRepoForGemini) Create(ctx context.Context, account *Account) func (m *mockAccountRepoForGemini) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) { return nil, nil } + +func (m *mockAccountRepoForGemini) FindByExtraField(ctx 
context.Context, key string, value any) ([]Account, error) { + return nil, nil +} + func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index fd2932e6..0b9734f6 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "log" "net/http" "regexp" "strconv" @@ -16,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) const ( @@ -81,8 +81,7 @@ func (s *GeminiOAuthService) GetOAuthConfig() *GeminiOAuthCapabilities { // AI Studio OAuth is only enabled when the operator configures a custom OAuth client. clientID := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientID) clientSecret := strings.TrimSpace(s.cfg.Gemini.OAuth.ClientSecret) - enabled := clientID != "" && clientSecret != "" && - (clientID != geminicli.GeminiCLIOAuthClientID || clientSecret != geminicli.GeminiCLIOAuthClientSecret) + enabled := clientID != "" && clientSecret != "" && clientID != geminicli.GeminiCLIOAuthClientID return &GeminiOAuthCapabilities{ AIStudioOAuthEnabled: enabled, @@ -151,8 +150,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID // AI Studio OAuth requires a user-provided OAuth client (built-in Gemini CLI client is scope-restricted). 
if oauthType == "ai_studio" && isBuiltinClient { @@ -330,27 +328,27 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string // inferGoogleOneTier infers Google One tier from Drive storage limit func inferGoogleOneTier(storageBytes int64) string { - log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB)) if storageBytes <= 0 { - log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN") return GeminiTierGoogleOneUnknown } if storageBytes > StorageTierUnlimited { - log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited) return GeminiTierGoogleAIUltra } if storageBytes >= StorageTierAIPremium { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium) return GeminiTierGoogleAIPro } if storageBytes >= StorageTierFree { - log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree) return GeminiTierGoogleOneFree } - log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] inferGoogleOneTier - < %d 
bytes (15GB), returning UNKNOWN", StorageTierFree) return GeminiTierGoogleOneUnknown } @@ -360,30 +358,30 @@ func inferGoogleOneTier(storageBytes int64) string { // 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com // 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) { - log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)") // Use Drive API to infer tier from storage quota (requires drive.readonly scope) - log.Printf("[GeminiOAuth] Calling Drive API for storage quota...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Calling Drive API for storage quota...") driveClient := geminicli.NewDriveClient() storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL) if err != nil { // Check if it's a 403 (scope not granted) if strings.Contains(err.Error(), "status 403") { - log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API scope not available (403): %v", err) return GeminiTierGoogleOneUnknown, nil, err } // Other errors - log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Failed to fetch Drive storage: %v", err) return GeminiTierGoogleOneUnknown, nil, err } - log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", storageInfo.Limit, float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) tierID := 
inferGoogleOneTier(storageInfo.Limit) - log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Inferred tier from storage: %s", tierID) return tierID, storageInfo, nil } @@ -443,16 +441,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier( } func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) { - log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========") - log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] SessionID: %s", input.SessionID) session, ok := s.sessionStore.Get(input.SessionID) if !ok { - log.Printf("[GeminiOAuth] ERROR: Session not found or expired") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Session not found or expired") return nil, fmt.Errorf("session not found or expired") } if strings.TrimSpace(input.State) == "" || input.State != session.State { - log.Printf("[GeminiOAuth] ERROR: Invalid state") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Invalid state") return nil, fmt.Errorf("invalid state") } @@ -463,7 +461,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch proxyURL = proxy.URL() } } - log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ProxyURL: %s", proxyURL) redirectURI := session.RedirectURI @@ -472,8 +470,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if oauthType == "" { oauthType = "code_assist" } - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) - log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) + 
logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Project ID from session: %s", session.ProjectID) // If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured. if oauthType == "ai_studio" { @@ -485,26 +483,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { return nil, err } - isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID && - effectiveCfg.ClientSecret == geminicli.GeminiCLIOAuthClientSecret + isBuiltinClient := effectiveCfg.ClientID == geminicli.GeminiCLIOAuthClientID if isBuiltinClient { return nil, fmt.Errorf("AI Studio OAuth requires a custom OAuth Client. Please use an AI Studio API Key account, or configure GEMINI_OAUTH_CLIENT_ID / GEMINI_OAUTH_CLIENT_SECRET and re-authorize") } } - // code_assist always uses the built-in client and its fixed redirect URI. - if oauthType == "code_assist" { + // code_assist/google_one always uses the built-in client and its fixed redirect URI. 
+ if oauthType == "code_assist" || oauthType == "google_one" { redirectURI = geminicli.GeminiCLIRedirectURI } tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to exchange code: %v", err) return nil, fmt.Errorf("failed to exchange code: %w", err) } - log.Printf("[GeminiOAuth] Token exchange successful") - log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope) - log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token exchange successful") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token scope: %s", tokenResp.Scope) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn) sessionProjectID := strings.TrimSpace(session.ProjectID) s.sessionStore.Delete(input.SessionID) @@ -526,40 +523,40 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID) } - log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========") - log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection START ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] OAuth Type: %s", oauthType) // 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API // 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别 // 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API) switch oauthType { case "code_assist": - log.Printf("[GeminiOAuth] Processing code_assist OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing code_assist OAuth type") if projectID == "" { - 
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { // 记录警告但不阻断流程,允许后续补充 project_id fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err) } else { - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID) } } else { - log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID) // 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID _, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err) } else { tierID = fetchedTierID - log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched tier_id: %s", tierID) } } if strings.TrimSpace(projectID) == "" { - log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth") return nil, fmt.Errorf("missing project_id 
for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project") } // Prefer auto-detected tier; fall back to user-selected tier. @@ -567,31 +564,31 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGCPStandard - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } - log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID) case "google_one": - log.Printf("[GeminiOAuth] Processing google_one OAuth type") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing google_one OAuth type") // Google One accounts use cloudaicompanion API, which requires a project_id. // For personal accounts, Google auto-assigns a project_id via the LoadCodeAssist API. 
if projectID == "" { - log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...") var err error projectID, _, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL) if err != nil { - log.Printf("[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ERROR: Failed to fetch project_id: %v", err) return nil, fmt.Errorf("google One accounts require a project_id, failed to auto-detect: %w", err) } - log.Printf("[GeminiOAuth] Successfully fetched project_id: %s", projectID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched project_id: %s", projectID) } - log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Attempting to fetch Google One tier from Drive API...") // Attempt to fetch Drive storage tier var storageInfo *geminicli.DriveStorageInfo var err error @@ -599,12 +596,12 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if err != nil { // Log warning but don't block - use fallback fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err) - log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err) tierID = "" } else { - log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Successfully fetched Drive tier: %s", tierID) if storageInfo != nil { - log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)", + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes 
(%.2f GB)", storageInfo.Limit, float64(storageInfo.Limit)/float64(TB), storageInfo.Usage, float64(storageInfo.Usage)/float64(GB)) } @@ -613,10 +610,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch if tierID == "" || tierID == GeminiTierGoogleOneUnknown { if fallbackTierID != "" { tierID = fallbackTierID - log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID) } else { tierID = GeminiTierGoogleOneFree - log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Using default tier_id: %s", tierID) } } fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID) @@ -639,7 +636,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch "drive_tier_updated_at": time.Now().Format(time.RFC3339), }, } - log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========") return tokenInfo, nil } @@ -652,10 +649,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch } default: - log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType) } - log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== Account Type Detection END ==========") result := &GeminiTokenInfo{ AccessToken: tokenResp.AccessToken, @@ -668,8 +665,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch TierID: tierID, OAuthType: oauthType, } - log.Printf("[GeminiOAuth] 
Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) - log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] ========== ExchangeCode END ==========") return result, nil } @@ -952,23 +949,23 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr registeredTierID := strings.TrimSpace(loadResp.GetTier()) if registeredTierID != "" { // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 - log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) // Try to get project from Cloud Resource Manager fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) if fbErr == nil && strings.TrimSpace(fallback) != "" { - log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) return strings.TrimSpace(fallback), tierID, nil } // No project found - user must provide project_id manually - log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. 
Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) } } // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser - log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + logger.LegacyPrintf("service.gemini_oauth", "[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) req := &geminicli.OnboardUserRequest{ TierID: tierID, diff --git a/backend/internal/service/gemini_oauth_service_test.go b/backend/internal/service/gemini_oauth_service_test.go index 5591eb39..c58a5930 100644 --- a/backend/internal/service/gemini_oauth_service_test.go +++ b/backend/internal/service/gemini_oauth_service_test.go @@ -1,17 +1,29 @@ +//go:build unit + package service import ( "context" + "fmt" "net/url" "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" ) +// ===================== +// 保留原有测试 +// ===================== + func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { - t.Parallel() + // NOTE: This test sets process env; it must not run in parallel. + // The built-in Gemini CLI client secret is not embedded in this repository. + // Tests set a dummy secret via env to simulate operator-provided configuration. 
+ t.Setenv(geminicli.GeminiCLIOAuthClientSecretEnv, "test-built-in-secret") type testCase struct { name string @@ -128,3 +140,1324 @@ func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) { }) } } + +// ===================== +// 新增测试:validateTierID +// ===================== + +func TestValidateTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + tierID string + wantErr bool + }{ + {name: "空字符串合法", tierID: "", wantErr: false}, + {name: "正常 tier_id", tierID: "google_one_free", wantErr: false}, + {name: "包含斜杠", tierID: "tier/sub", wantErr: false}, + {name: "包含连字符", tierID: "gcp-standard", wantErr: false}, + {name: "纯数字", tierID: "12345", wantErr: false}, + {name: "超长字符串(65个字符)", tierID: strings.Repeat("a", 65), wantErr: true}, + {name: "刚好64个字符", tierID: strings.Repeat("b", 64), wantErr: false}, + {name: "非法字符_空格", tierID: "tier id", wantErr: true}, + {name: "非法字符_中文", tierID: "tier_中文", wantErr: true}, + {name: "非法字符_特殊符号", tierID: "tier@id", wantErr: true}, + {name: "非法字符_感叹号", tierID: "tier!id", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := validateTierID(tt.tierID) + if tt.wantErr && err == nil { + t.Fatalf("期望返回错误,但返回 nil") + } + if !tt.wantErr && err != nil { + t.Fatalf("不期望返回错误,但返回: %v", err) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierID +// ===================== + +func TestCanonicalGeminiTierID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + raw string + want string + }{ + // 空值 + {name: "空字符串", raw: "", want: ""}, + {name: "纯空白", raw: " ", want: ""}, + + // 已规范化的值(直接返回) + {name: "google_one_free", raw: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_ai_pro", raw: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_ai_ultra", raw: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "gcp_standard", raw: "gcp_standard", want: 
GeminiTierGCPStandard}, + {name: "gcp_enterprise", raw: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "aistudio_free", raw: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "aistudio_paid", raw: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "google_one_unknown", raw: "google_one_unknown", want: GeminiTierGoogleOneUnknown}, + + // 大小写不敏感 + {name: "Google_One_Free 大写", raw: "Google_One_Free", want: GeminiTierGoogleOneFree}, + {name: "GCP_STANDARD 全大写", raw: "GCP_STANDARD", want: GeminiTierGCPStandard}, + + // legacy 映射: Google One + {name: "AI_PREMIUM -> google_ai_pro", raw: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + {name: "FREE -> google_one_free", raw: "FREE", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_BASIC -> google_one_free", raw: "GOOGLE_ONE_BASIC", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_STANDARD -> google_one_free", raw: "GOOGLE_ONE_STANDARD", want: GeminiTierGoogleOneFree}, + {name: "GOOGLE_ONE_UNLIMITED -> google_ai_ultra", raw: "GOOGLE_ONE_UNLIMITED", want: GeminiTierGoogleAIUltra}, + {name: "GOOGLE_ONE_UNKNOWN -> google_one_unknown", raw: "GOOGLE_ONE_UNKNOWN", want: GeminiTierGoogleOneUnknown}, + + // legacy 映射: Code Assist + {name: "STANDARD -> gcp_standard", raw: "STANDARD", want: GeminiTierGCPStandard}, + {name: "PRO -> gcp_standard", raw: "PRO", want: GeminiTierGCPStandard}, + {name: "LEGACY -> gcp_standard", raw: "LEGACY", want: GeminiTierGCPStandard}, + {name: "ENTERPRISE -> gcp_enterprise", raw: "ENTERPRISE", want: GeminiTierGCPEnterprise}, + {name: "ULTRA -> gcp_enterprise", raw: "ULTRA", want: GeminiTierGCPEnterprise}, + + // kebab-case + {name: "standard-tier -> gcp_standard", raw: "standard-tier", want: GeminiTierGCPStandard}, + {name: "pro-tier -> gcp_standard", raw: "pro-tier", want: GeminiTierGCPStandard}, + {name: "ultra-tier -> gcp_enterprise", raw: "ultra-tier", want: GeminiTierGCPEnterprise}, + + // 未知值 + {name: "unknown_value -> 空", raw: "unknown_value", want: ""}, + {name: 
"random-text -> 空", raw: "random-text", want: ""}, + + // 带空白 + {name: "带前后空白", raw: " google_one_free ", want: GeminiTierGoogleOneFree}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierID(tt.raw) + if got != tt.want { + t.Fatalf("canonicalGeminiTierID(%q) = %q, want %q", tt.raw, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:canonicalGeminiTierIDForOAuthType +// ===================== + +func TestCanonicalGeminiTierIDForOAuthType(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + oauthType string + tierID string + want string + }{ + // google_one 类型过滤 + {name: "google_one + google_one_free", oauthType: "google_one", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "google_one + google_ai_pro", oauthType: "google_one", tierID: "google_ai_pro", want: GeminiTierGoogleAIPro}, + {name: "google_one + google_ai_ultra", oauthType: "google_one", tierID: "google_ai_ultra", want: GeminiTierGoogleAIUltra}, + {name: "google_one + gcp_standard 被过滤", oauthType: "google_one", tierID: "gcp_standard", want: ""}, + {name: "google_one + aistudio_free 被过滤", oauthType: "google_one", tierID: "aistudio_free", want: ""}, + {name: "google_one + AI_PREMIUM 遗留映射", oauthType: "google_one", tierID: "AI_PREMIUM", want: GeminiTierGoogleAIPro}, + + // code_assist 类型过滤 + {name: "code_assist + gcp_standard", oauthType: "code_assist", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "code_assist + gcp_enterprise", oauthType: "code_assist", tierID: "gcp_enterprise", want: GeminiTierGCPEnterprise}, + {name: "code_assist + google_one_free 被过滤", oauthType: "code_assist", tierID: "google_one_free", want: ""}, + {name: "code_assist + aistudio_free 被过滤", oauthType: "code_assist", tierID: "aistudio_free", want: ""}, + {name: "code_assist + STANDARD 遗留映射", oauthType: "code_assist", tierID: "STANDARD", want: GeminiTierGCPStandard}, + {name: "code_assist + 
standard-tier kebab", oauthType: "code_assist", tierID: "standard-tier", want: GeminiTierGCPStandard}, + + // ai_studio 类型过滤 + {name: "ai_studio + aistudio_free", oauthType: "ai_studio", tierID: "aistudio_free", want: GeminiTierAIStudioFree}, + {name: "ai_studio + aistudio_paid", oauthType: "ai_studio", tierID: "aistudio_paid", want: GeminiTierAIStudioPaid}, + {name: "ai_studio + gcp_standard 被过滤", oauthType: "ai_studio", tierID: "gcp_standard", want: ""}, + {name: "ai_studio + google_one_free 被过滤", oauthType: "ai_studio", tierID: "google_one_free", want: ""}, + + // 空值 + {name: "空 tierID", oauthType: "google_one", tierID: "", want: ""}, + {name: "空 oauthType + 有效 tierID", oauthType: "", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + {name: "未知 oauthType 接受规范化值", oauthType: "unknown_type", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + + // oauthType 大小写和空白 + {name: "GOOGLE_ONE 大写", oauthType: "GOOGLE_ONE", tierID: "google_one_free", want: GeminiTierGoogleOneFree}, + {name: "oauthType 带空白", oauthType: " code_assist ", tierID: "gcp_standard", want: GeminiTierGCPStandard}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := canonicalGeminiTierIDForOAuthType(tt.oauthType, tt.tierID) + if got != tt.want { + t.Fatalf("canonicalGeminiTierIDForOAuthType(%q, %q) = %q, want %q", tt.oauthType, tt.tierID, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:extractTierIDFromAllowedTiers +// ===================== + +func TestExtractTierIDFromAllowedTiers(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + allowedTiers []geminicli.AllowedTier + want string + }{ + { + name: "nil 列表返回 LEGACY", + allowedTiers: nil, + want: "LEGACY", + }, + { + name: "空列表返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{}, + want: "LEGACY", + }, + { + name: "有 IsDefault 的 tier", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "PRO", IsDefault: true}, + 
{ID: "ENTERPRISE", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "没有 IsDefault 取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "STANDARD", IsDefault: false}, + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "STANDARD", + }, + { + name: "IsDefault 的 ID 为空,取第一个非空", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: true}, + {ID: "PRO", IsDefault: false}, + }, + want: "PRO", + }, + { + name: "所有 ID 都为空返回 LEGACY", + allowedTiers: []geminicli.AllowedTier{ + {ID: "", IsDefault: false}, + {ID: " ", IsDefault: false}, + }, + want: "LEGACY", + }, + { + name: "ID 带空白会被 trim", + allowedTiers: []geminicli.AllowedTier{ + {ID: " STANDARD ", IsDefault: true}, + }, + want: "STANDARD", + }, + { + name: "单个 tier 且 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: true}, + }, + want: "ENTERPRISE", + }, + { + name: "单个 tier 非 IsDefault", + allowedTiers: []geminicli.AllowedTier{ + {ID: "ENTERPRISE", IsDefault: false}, + }, + want: "ENTERPRISE", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := extractTierIDFromAllowedTiers(tt.allowedTiers) + if got != tt.want { + t.Fatalf("extractTierIDFromAllowedTiers() = %q, want %q", got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:inferGoogleOneTier +// ===================== + +func TestInferGoogleOneTier(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storageBytes int64 + want string + }{ + // 边界:<= 0 + {name: "0 bytes -> unknown", storageBytes: 0, want: GeminiTierGoogleOneUnknown}, + {name: "负数 -> unknown", storageBytes: -1, want: GeminiTierGoogleOneUnknown}, + + // > 100TB -> ultra + {name: "> 100TB -> ultra", storageBytes: int64(StorageTierUnlimited) + 1, want: GeminiTierGoogleAIUltra}, + {name: "200TB -> ultra", storageBytes: 200 * int64(TB), want: GeminiTierGoogleAIUltra}, + + // >= 2TB -> pro (但 <= 100TB) + {name: "正好 2TB -> pro", storageBytes: 
int64(StorageTierAIPremium), want: GeminiTierGoogleAIPro}, + {name: "5TB -> pro", storageBytes: 5 * int64(TB), want: GeminiTierGoogleAIPro}, + {name: "100TB 正好 -> pro (不是 > 100TB)", storageBytes: int64(StorageTierUnlimited), want: GeminiTierGoogleAIPro}, + + // >= 15GB -> free (但 < 2TB) + {name: "正好 15GB -> free", storageBytes: int64(StorageTierFree), want: GeminiTierGoogleOneFree}, + {name: "100GB -> free", storageBytes: 100 * int64(GB), want: GeminiTierGoogleOneFree}, + {name: "略低于 2TB -> free", storageBytes: int64(StorageTierAIPremium) - 1, want: GeminiTierGoogleOneFree}, + + // < 15GB -> unknown + {name: "1GB -> unknown", storageBytes: int64(GB), want: GeminiTierGoogleOneUnknown}, + {name: "略低于 15GB -> unknown", storageBytes: int64(StorageTierFree) - 1, want: GeminiTierGoogleOneUnknown}, + {name: "1 byte -> unknown", storageBytes: 1, want: GeminiTierGoogleOneUnknown}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := inferGoogleOneTier(tt.storageBytes) + if got != tt.want { + t.Fatalf("inferGoogleOneTier(%d) = %q, want %q", tt.storageBytes, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:isNonRetryableGeminiOAuthError +// ===================== + +func TestIsNonRetryableGeminiOAuthError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "invalid_grant", err: fmt.Errorf("error: invalid_grant"), want: true}, + {name: "invalid_client", err: fmt.Errorf("oauth error: invalid_client"), want: true}, + {name: "unauthorized_client", err: fmt.Errorf("unauthorized_client: mismatch"), want: true}, + {name: "access_denied", err: fmt.Errorf("access_denied by user"), want: true}, + {name: "普通网络错误", err: fmt.Errorf("connection timeout"), want: false}, + {name: "HTTP 500 错误", err: fmt.Errorf("server error 500"), want: false}, + {name: "空错误信息", err: fmt.Errorf(""), want: false}, + {name: "包含 invalid 但不是完整匹配", err: fmt.Errorf("invalid request"), 
want: false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := isNonRetryableGeminiOAuthError(tt.err) + if got != tt.want { + t.Fatalf("isNonRetryableGeminiOAuthError(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +// ===================== +// 新增测试:BuildAccountCredentials +// ===================== + +func TestGeminiOAuthService_BuildAccountCredentials(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + t.Run("完整字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "access-123", + RefreshToken: "refresh-456", + ExpiresIn: 3600, + ExpiresAt: 1700000000, + TokenType: "Bearer", + Scope: "openid email", + ProjectID: "my-project", + TierID: "gcp_standard", + OAuthType: "code_assist", + Extra: map[string]any{ + "drive_storage_limit": int64(2199023255552), + }, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "access-123") + assertCredStr(t, creds, "refresh_token", "refresh-456") + assertCredStr(t, creds, "token_type", "Bearer") + assertCredStr(t, creds, "scope", "openid email") + assertCredStr(t, creds, "project_id", "my-project") + assertCredStr(t, creds, "tier_id", "gcp_standard") + assertCredStr(t, creds, "oauth_type", "code_assist") + assertCredStr(t, creds, "expires_at", "1700000000") + + if _, ok := creds["drive_storage_limit"]; !ok { + t.Fatal("extra 字段 drive_storage_limit 未包含在 creds 中") + } + }) + + t.Run("最小字段(仅 access_token 和 expires_at)", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token-only", + ExpiresAt: 1700000000, + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + assertCredStr(t, creds, "access_token", "token-only") + assertCredStr(t, creds, "expires_at", "1700000000") + + // 可选字段不应存在 + for _, key := range []string{"refresh_token", "token_type", "scope", "project_id", "tier_id", 
"oauth_type"} { + if _, ok := creds[key]; ok { + t.Fatalf("不应包含空字段 %q", key) + } + } + }) + + t.Run("无效 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: "tier with spaces", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("无效 tier_id 不应被存入 creds") + } + }) + + t.Run("超长 tier_id 被静默跳过", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + TierID: strings.Repeat("x", 65), + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + if _, ok := creds["tier_id"]; ok { + t.Fatal("超长 tier_id 不应被存入 creds") + } + }) + + t.Run("无 extra 字段", func(t *testing.T) { + t.Parallel() + tokenInfo := &GeminiTokenInfo{ + AccessToken: "token", + ExpiresAt: 1700000000, + RefreshToken: "rt", + } + + creds := svc.BuildAccountCredentials(tokenInfo) + + // 仅包含基础字段 + if len(creds) != 3 { // access_token, expires_at, refresh_token + t.Fatalf("creds 字段数量不匹配: got=%d want=3, keys=%v", len(creds), credKeys(creds)) + } + }) +} + +// ===================== +// 新增测试:GetOAuthConfig +// ===================== + +func TestGeminiOAuthService_GetOAuthConfig(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + cfg *config.Config + wantEnabled bool + }{ + { + name: "无自定义 OAuth 客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{}, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientID 无 ClientSecret", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + }, + }, + }, + wantEnabled: false, + }, + { + name: "仅 ClientSecret 无 ClientID", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientSecret: "custom-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "使用内置 Gemini CLI ClientID(不算自定义)", + cfg: &config.Config{ + 
Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: geminicli.GeminiCLIOAuthClientID, + ClientSecret: "some-secret", + }, + }, + }, + wantEnabled: false, + }, + { + name: "自定义 OAuth 客户端(非内置 ID)", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "my-custom-client-id", + ClientSecret: "my-custom-client-secret", + }, + }, + }, + wantEnabled: true, + }, + { + name: "带空白的自定义客户端", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " my-custom-client-id ", + ClientSecret: " my-custom-client-secret ", + }, + }, + }, + wantEnabled: true, + }, + { + name: "纯空白字符串不算配置", + cfg: &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: " ", + ClientSecret: " ", + }, + }, + }, + wantEnabled: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg) + defer svc.Stop() + + result := svc.GetOAuthConfig() + if result.AIStudioOAuthEnabled != tt.wantEnabled { + t.Fatalf("AIStudioOAuthEnabled = %v, want %v", result.AIStudioOAuthEnabled, tt.wantEnabled) + } + // RequiredRedirectURIs 始终包含 AI Studio redirect URI + if len(result.RequiredRedirectURIs) != 1 || result.RequiredRedirectURIs[0] != geminicli.AIStudioOAuthRedirectURI { + t.Fatalf("RequiredRedirectURIs 不匹配: got=%v", result.RequiredRedirectURIs) + } + }) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.Stop +// ===================== + +func TestGeminiOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + + // 调用 Stop 不应 panic + svc.Stop() + // 多次调用也不应 panic + svc.Stop() +} + +// ===================== +// mock: GeminiOAuthClient +// ===================== + +type mockGeminiOAuthClient struct { + exchangeCodeFunc func(ctx context.Context, oauthType, code, codeVerifier, redirectURI, 
proxyURL string) (*geminicli.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) +} + +func (m *mockGeminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, oauthType, code, codeVerifier, redirectURI, proxyURL) + } + panic("ExchangeCode not implemented") +} + +func (m *mockGeminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, oauthType, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// ===================== +// mock: GeminiCliCodeAssistClient +// ===================== + +type mockGeminiCodeAssistClient struct { + loadCodeAssistFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) + onboardUserFunc func(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) +} + +func (m *mockGeminiCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + if m.loadCodeAssistFunc != nil { + return m.loadCodeAssistFunc(ctx, accessToken, proxyURL, req) + } + panic("LoadCodeAssist not implemented") +} + +func (m *mockGeminiCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, req *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) { + if m.onboardUserFunc != nil { + return m.onboardUserFunc(ctx, accessToken, proxyURL, req) + } + panic("OnboardUser not implemented") +} + +// ===================== +// mock: ProxyRepository (最小实现) +// ===================== + +type 
mockGeminiProxyRepo struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockGeminiProxyRepo) Create(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil { + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockGeminiProxyRepo) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) Update(ctx context.Context, proxy *Proxy) error { panic("not impl") } +func (m *mockGeminiProxyRepo) Delete(ctx context.Context, id int64) error { panic("not impl") } +func (m *mockGeminiProxyRepo) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListActive(ctx context.Context) ([]Proxy, error) { panic("not impl") } +func (m *mockGeminiProxyRepo) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("not impl") +} +func (m *mockGeminiProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + 
panic("not impl") +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshToken(含重试逻辑) +// ===================== + +func TestGeminiOAuthService_RefreshToken_Success(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "new-access", + RefreshToken: "new-refresh", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: "openid", + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "old-refresh", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if info.AccessToken != "new-access" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.RefreshToken != "new-refresh" { + t.Fatalf("RefreshToken 不匹配: got=%q", info.RefreshToken) + } + if info.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestGeminiOAuthService_RefreshToken_NonRetryableError(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token revoked") + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "code_assist", "revoked-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误(不可重试的 invalid_grant)") + } + if !strings.Contains(err.Error(), "invalid_grant") { + t.Fatalf("错误应包含 invalid_grant: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshToken_RetryableError(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, 
error) { + callCount++ + if callCount <= 2 { + return nil, fmt.Errorf("temporary network error") + } + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(nil, client, nil, &config.Config{}) + defer svc.Stop() + + info, err := svc.RefreshToken(context.Background(), "code_assist", "rt", "") + if err != nil { + t.Fatalf("RefreshToken 应在重试后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if callCount < 3 { + t.Fatalf("应至少调用 3 次(2 次失败 + 1 次成功): got=%d", callCount) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.RefreshAccountToken +// ===================== + +func TestGeminiOAuthService_RefreshAccountToken_NotGeminiOAuth(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(非 Gemini OAuth 账号)") + } + if !strings.Contains(err.Error(), "not a Gemini OAuth account") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "at", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 refresh_token)") + } + if !strings.Contains(err.Error(), "no refresh token") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_AIStudio(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx 
context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed-at", + RefreshToken: "refreshed-rt", + ExpiresIn: 3600, + TokenType: "Bearer", + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "ai_studio", + "tier_id": "aistudio_free", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.AccessToken != "refreshed-at" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } + if info.OAuthType != "ai_studio" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_WithProjectID(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + RefreshToken: "new-rt", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-at", + "refresh_token": "old-rt", + "oauth_type": "code_assist", + "project_id": "my-project", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "my-project" { + t.Fatalf("ProjectID 应保留: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + 
t.Fatalf("TierID 不匹配: got=%q want=%q", info.TierID, GeminiTierGCPStandard) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 不匹配: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_DefaultOAuthType(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if oauthType != "code_assist" { + t.Errorf("默认 oauthType 应为 code_assist: got=%q", oauthType) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + // 无 oauth_type 凭据的旧账号 + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old-rt", + "project_id": "proj", + "tier_id": "STANDARD", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.OAuthType != "code_assist" { + t.Fatalf("OAuthType 应默认为 code_assist: got=%q", info.OAuthType) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockGeminiProxyRepo{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "http", + Host: "proxy.test", + Port: 3128, + }, nil + }, + } + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + if proxyURL != "http://proxy.test:3128" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &geminicli.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(proxyRepo, client, nil, &config.Config{}) + defer svc.Stop() + + proxyID := int64(5) + account := 
&Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_AutoDetect(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CloudAICompanionProject: "auto-project-123", + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + // 无 project_id,触发 fetchProjectID + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if info.ProjectID != "auto-project-123" { + t.Fatalf("ProjectID 应为自动检测值: got=%q", info.ProjectID) + } + if info.TierID != GeminiTierGCPStandard { + t.Fatalf("TierID 不匹配: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_CodeAssist_NoProjectID_FailsEmpty(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) 
(*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + // 返回有 currentTier 但无 cloudaicompanionProject 的响应, + // 使 fetchProjectID 走"已注册用户"路径(尝试 Cloud Resource Manager -> 失败 -> 返回错误), + // 避免走 onboardUser 路径(5 次重试 x 2 秒 = 10 秒超时) + codeAssist := &mockGeminiCodeAssistClient{ + loadCodeAssistFunc: func(ctx context.Context, accessToken, proxyURL string, req *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) { + return &geminicli.LoadCodeAssistResponse{ + CurrentTier: &geminicli.TierInfo{ID: "STANDARD"}, + // 无 CloudAICompanionProject + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, codeAssist, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无法检测 project_id)") + } + if !strings.Contains(err.Error(), "project_id") { + t.Fatalf("错误信息应包含 project_id: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_FreshCache(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + "tier_id": "google_ai_pro", + }, + Extra: map[string]any{ + // 缓存刷新时间在 24 小时内 + "drive_tier_updated_at": time.Now().Add(-1 * time.Hour).Format(time.RFC3339), 
+ }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // 缓存新鲜,应使用已有的 tier_id + if info.TierID != GeminiTierGoogleAIPro { + t.Fatalf("TierID 应使用缓存值: got=%q want=%q", info.TierID, GeminiTierGoogleAIPro) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_GoogleOne_NoTierID_DefaultsFree(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return &geminicli.TokenResponse{ + AccessToken: "at", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "google_one", + "project_id": "proj", + // 无 tier_id + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + // FetchGoogleOneTier 会被调用但 oauthClient(此处 mock)不实现 Drive API, + // svc.FetchGoogleOneTier 使用真实 DriveClient 会失败,最终回退到默认值。 + // 由于没有 tier_id 且 FetchGoogleOneTier 失败,应默认为 google_one_free + if info.TierID != GeminiTierGoogleOneFree { + t.Fatalf("TierID 应为默认 free: got=%q", info.TierID) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_Fallback(t *testing.T) { + t.Parallel() + + callCount := 0 + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + callCount++ + if oauthType == "code_assist" { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + } + // ai_studio 路径成功 + return &geminicli.TokenResponse{ + AccessToken: "recovered", + ExpiresIn: 3600, + }, nil + }, + } + + // 启用自定义 OAuth 客户端以触发 fallback 路径 + cfg 
:= &config.Config{ + Gemini: config.GeminiConfig{ + OAuth: config.GeminiOAuthConfig{ + ClientID: "custom-id", + ClientSecret: "custom-secret", + }, + }, + } + + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, cfg) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + "tier_id": "gcp_standard", + }, + } + + info, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 应在 fallback 后成功: %v", err) + } + if info.AccessToken != "recovered" { + t.Fatalf("AccessToken 不匹配: got=%q", info.AccessToken) + } +} + +func TestGeminiOAuthService_RefreshAccountToken_UnauthorizedClient_NoFallback(t *testing.T) { + t.Parallel() + + client := &mockGeminiOAuthClient{ + refreshTokenFunc: func(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { + return nil, fmt.Errorf("unauthorized_client: client mismatch") + }, + } + + // 无自定义 OAuth 客户端,无法 fallback + svc := NewGeminiOAuthService(&mockGeminiProxyRepo{}, client, nil, &config.Config{}) + defer svc.Stop() + + account := &Account{ + Platform: PlatformGemini, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "rt", + "oauth_type": "code_assist", + "project_id": "proj", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("应返回错误(无 fallback)") + } + if !strings.Contains(err.Error(), "OAuth client mismatch") { + t.Fatalf("错误应包含 OAuth client mismatch: got=%q", err.Error()) + } +} + +// ===================== +// 新增测试:GeminiOAuthService.ExchangeCode +// ===================== + +func TestGeminiOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), 
&GeminiExchangeCodeInput{ + SessionID: "nonexistent", + State: "some-state", + Code: "some-code", + }) + if err == nil { + t.Fatal("应返回错误(session 不存在)") + } + if !strings.Contains(err.Error(), "session not found") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_InvalidState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + // 手动创建 session(必须设置 CreatedAt,否则会因 TTL 过期被拒绝) + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + OAuthType: "ai_studio", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "wrong-state", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(state 不匹配)") + } + if !strings.Contains(err.Error(), "invalid state") { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestGeminiOAuthService_ExchangeCode_EmptyState(t *testing.T) { + t.Parallel() + + svc := NewGeminiOAuthService(nil, nil, nil, &config.Config{}) + defer svc.Stop() + + svc.sessionStore.Set("test-session", &geminicli.OAuthSession{ + State: "correct-state", + CodeVerifier: "verifier", + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &GeminiExchangeCodeInput{ + SessionID: "test-session", + State: "", + Code: "code", + }) + if err == nil { + t.Fatal("应返回错误(空 state)") + } +} + +// ===================== +// 辅助函数 +// ===================== + +func assertCredStr(t *testing.T, creds map[string]any, key, want string) { + t.Helper() + raw, ok := creds[key] + if !ok { + t.Fatalf("creds 缺少 key=%q", key) + } + got, ok := raw.(string) + if !ok { + t.Fatalf("creds[%q] 不是 string: %T", key, raw) + } + if got != want { + t.Fatalf("creds[%q] = %q, want %q", key, got, want) + } +} + +func credKeys(m map[string]any) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = 
append(keys, k) + } + return keys +} diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index e9423ddb..86ece03f 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -26,6 +26,12 @@ type Group struct { ImagePrice2K *float64 ImagePrice4K *float64 + // Sora 按次计费配置(阶段 1) + SoraImagePrice360 *float64 + SoraImagePrice540 *float64 + SoraVideoPricePerRequest *float64 + SoraVideoPricePerRequestHD *float64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 @@ -95,6 +101,18 @@ func (g *Group) GetImagePrice(imageSize string) *float64 { } } +// GetSoraImagePrice 根据 Sora 图片尺寸返回价格(360/540) +func (g *Group) GetSoraImagePrice(imageSize string) *float64 { + switch imageSize { + case "360": + return g.SoraImagePrice360 + case "540": + return g.SoraImagePrice540 + default: + return g.SoraImagePrice360 + } +} + // IsGroupContextValid reports whether a group from context has the fields required for routing decisions. func IsGroupContextValid(group *Group) bool { if group == nil { diff --git a/backend/internal/service/idempotency.go b/backend/internal/service/idempotency.go new file mode 100644 index 00000000..2a86bd60 --- /dev/null +++ b/backend/internal/service/idempotency.go @@ -0,0 +1,471 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "strconv" + "strings" + "sync" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const ( + IdempotencyStatusProcessing = "processing" + IdempotencyStatusSucceeded = "succeeded" + IdempotencyStatusFailedRetryable = "failed_retryable" +) + +var ( + ErrIdempotencyKeyRequired = infraerrors.BadRequest("IDEMPOTENCY_KEY_REQUIRED", "idempotency key is required") + ErrIdempotencyKeyInvalid = infraerrors.BadRequest("IDEMPOTENCY_KEY_INVALID", "idempotency key is invalid") + ErrIdempotencyKeyConflict = 
infraerrors.Conflict("IDEMPOTENCY_KEY_CONFLICT", "idempotency key reused with different payload") + ErrIdempotencyInProgress = infraerrors.Conflict("IDEMPOTENCY_IN_PROGRESS", "idempotent request is still processing") + ErrIdempotencyRetryBackoff = infraerrors.Conflict("IDEMPOTENCY_RETRY_BACKOFF", "idempotent request is in retry backoff window") + ErrIdempotencyStoreUnavail = infraerrors.ServiceUnavailable("IDEMPOTENCY_STORE_UNAVAILABLE", "idempotency store unavailable") + ErrIdempotencyInvalidPayload = infraerrors.BadRequest("IDEMPOTENCY_PAYLOAD_INVALID", "failed to normalize request payload") +) + +type IdempotencyRecord struct { + ID int64 + Scope string + IdempotencyKeyHash string + RequestFingerprint string + Status string + ResponseStatus *int + ResponseBody *string + ErrorReason *string + LockedUntil *time.Time + ExpiresAt time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +type IdempotencyRepository interface { + CreateProcessing(ctx context.Context, record *IdempotencyRecord) (bool, error) + GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*IdempotencyRecord, error) + TryReclaim(ctx context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) + ExtendProcessingLock(ctx context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) + MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error + MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error + DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) +} + +type IdempotencyConfig struct { + DefaultTTL time.Duration + SystemOperationTTL time.Duration + ProcessingTimeout time.Duration + FailedRetryBackoff time.Duration + MaxStoredResponseLen int + ObserveOnly bool +} + +func DefaultIdempotencyConfig() IdempotencyConfig { + return IdempotencyConfig{ + DefaultTTL: 24 * 
time.Hour, + SystemOperationTTL: 1 * time.Hour, + ProcessingTimeout: 30 * time.Second, + FailedRetryBackoff: 5 * time.Second, + MaxStoredResponseLen: 64 * 1024, + ObserveOnly: true, // 默认先观察再强制,避免老客户端立刻中断 + } +} + +type IdempotencyExecuteOptions struct { + Scope string + ActorScope string + Method string + Route string + IdempotencyKey string + Payload any + TTL time.Duration + RequireKey bool +} + +type IdempotencyExecuteResult struct { + Data any + Replayed bool +} + +type IdempotencyCoordinator struct { + repo IdempotencyRepository + cfg IdempotencyConfig +} + +var ( + defaultIdempotencyMu sync.RWMutex + defaultIdempotencySvc *IdempotencyCoordinator +) + +func SetDefaultIdempotencyCoordinator(svc *IdempotencyCoordinator) { + defaultIdempotencyMu.Lock() + defaultIdempotencySvc = svc + defaultIdempotencyMu.Unlock() +} + +func DefaultIdempotencyCoordinator() *IdempotencyCoordinator { + defaultIdempotencyMu.RLock() + defer defaultIdempotencyMu.RUnlock() + return defaultIdempotencySvc +} + +func DefaultWriteIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().DefaultTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.DefaultTTL > 0 { + return coordinator.cfg.DefaultTTL + } + return defaultTTL +} + +func DefaultSystemOperationIdempotencyTTL() time.Duration { + defaultTTL := DefaultIdempotencyConfig().SystemOperationTTL + if coordinator := DefaultIdempotencyCoordinator(); coordinator != nil && coordinator.cfg.SystemOperationTTL > 0 { + return coordinator.cfg.SystemOperationTTL + } + return defaultTTL +} + +func NewIdempotencyCoordinator(repo IdempotencyRepository, cfg IdempotencyConfig) *IdempotencyCoordinator { + return &IdempotencyCoordinator{ + repo: repo, + cfg: cfg, + } +} + +func NormalizeIdempotencyKey(raw string) (string, error) { + key := strings.TrimSpace(raw) + if key == "" { + return "", nil + } + if len(key) > 128 { + return "", ErrIdempotencyKeyInvalid + } + for _, r := range key { + if r < 
33 || r > 126 { + return "", ErrIdempotencyKeyInvalid + } + } + return key, nil +} + +func HashIdempotencyKey(key string) string { + sum := sha256.Sum256([]byte(key)) + return hex.EncodeToString(sum[:]) +} + +func BuildIdempotencyFingerprint(method, route, actorScope string, payload any) (string, error) { + if method == "" { + method = "POST" + } + if route == "" { + route = "/" + } + if actorScope == "" { + actorScope = "anonymous" + } + + raw, err := json.Marshal(payload) + if err != nil { + return "", ErrIdempotencyInvalidPayload.WithCause(err) + } + sum := sha256.Sum256([]byte( + strings.ToUpper(method) + "\n" + route + "\n" + actorScope + "\n" + string(raw), + )) + return hex.EncodeToString(sum[:]), nil +} + +func RetryAfterSecondsFromError(err error) int { + appErr := new(infraerrors.ApplicationError) + if !errors.As(err, &appErr) || appErr == nil || appErr.Metadata == nil { + return 0 + } + v := strings.TrimSpace(appErr.Metadata["retry_after"]) + if v == "" { + return 0 + } + seconds, convErr := strconv.Atoi(v) + if convErr != nil || seconds <= 0 { + return 0 + } + return seconds +} + +func (c *IdempotencyCoordinator) Execute( + ctx context.Context, + opts IdempotencyExecuteOptions, + execute func(context.Context) (any, error), +) (*IdempotencyExecuteResult, error) { + if execute == nil { + return nil, infraerrors.InternalServer("IDEMPOTENCY_EXECUTOR_NIL", "idempotency executor is nil") + } + + key, err := NormalizeIdempotencyKey(opts.IdempotencyKey) + if err != nil { + return nil, err + } + if key == "" { + if opts.RequireKey && !c.cfg.ObserveOnly { + return nil, ErrIdempotencyKeyRequired + } + data, execErr := execute(ctx) + if execErr != nil { + return nil, execErr + } + return &IdempotencyExecuteResult{Data: data}, nil + } + if c.repo == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "repo_nil") + return nil, ErrIdempotencyStoreUnavail + } + + if opts.Scope == "" { + return nil, infraerrors.BadRequest("IDEMPOTENCY_SCOPE_REQUIRED", 
"idempotency scope is required") + } + + fingerprint, err := BuildIdempotencyFingerprint(opts.Method, opts.Route, opts.ActorScope, opts.Payload) + if err != nil { + return nil, err + } + + ttl := opts.TTL + if ttl <= 0 { + ttl = c.cfg.DefaultTTL + } + now := time.Now() + expiresAt := now.Add(ttl) + lockedUntil := now.Add(c.cfg.ProcessingTimeout) + keyHash := HashIdempotencyKey(key) + + record := &IdempotencyRecord{ + Scope: opts.Scope, + IdempotencyKeyHash: keyHash, + RequestFingerprint: fingerprint, + Status: IdempotencyStatusProcessing, + LockedUntil: &lockedUntil, + ExpiresAt: expiresAt, + } + + owner, err := c.repo.CreateProcessing(ctx, record) + if err != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "create_processing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "create_processing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(err) + } + if owner { + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "new_claim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "none->processing", false, map[string]string{ + "claim_mode": "new", + }) + } + if !owner { + existing, getErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if getErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(getErr) + } + if existing == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing", + }) + return nil, ErrIdempotencyStoreUnavail + } + if existing.RequestFingerprint != fingerprint { + 
recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + reclaimedByExpired := false + if !existing.ExpiresAt.After(now) { + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, existing.Status, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_expired_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->store_unavailable", false, map[string]string{ + "operation": "try_reclaim_expired", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if taken { + reclaimedByExpired = true + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "expired_reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, existing.Status+"->processing", false, map[string]string{ + "claim_mode": "expired_reclaim", + }) + record.ID = existing.ID + } else { + latest, latestErr := c.repo.GetByScopeAndKeyHash(ctx, opts.Scope, keyHash) + if latestErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "get_existing_after_expired_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "get_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(latestErr) + } + if latest == nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "missing_existing_after_expired_reclaim") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "unknown->store_unavailable", false, map[string]string{ + "operation": "missing_existing_after_expired_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail + } + if latest.RequestFingerprint != fingerprint { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": 
"fingerprint_mismatch"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->fingerprint_mismatch", false, nil) + return nil, ErrIdempotencyKeyConflict + } + existing = latest + } + } + + if !reclaimedByExpired { + switch existing.Status { + case IdempotencyStatusSucceeded: + data, parseErr := c.decodeStoredResponse(existing.ResponseBody) + if parseErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "decode_stored_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->store_unavailable", false, map[string]string{ + "operation": "decode_stored_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(parseErr) + } + recordIdempotencyReplay(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "succeeded->replayed", true, nil) + return &IdempotencyExecuteResult{Data: data, Replayed: true}, nil + case IdempotencyStatusProcessing: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "in_progress"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + case IdempotencyStatusFailedRetryable: + if existing.LockedUntil != nil && existing.LockedUntil.After(now) { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "retry_backoff"}) + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->retry_backoff_conflict", false, nil) + return nil, c.conflictWithRetryAfter(ErrIdempotencyRetryBackoff, existing.LockedUntil, now) + } + taken, reclaimErr := c.repo.TryReclaim(ctx, existing.ID, IdempotencyStatusFailedRetryable, now, lockedUntil, expiresAt) + if reclaimErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "try_reclaim_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, 
"failed_retryable->store_unavailable", false, map[string]string{ + "operation": "try_reclaim", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(reclaimErr) + } + if !taken { + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "reclaim_race"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->conflict", false, map[string]string{ + "conflict": "reclaim_race", + }) + return nil, c.conflictWithRetryAfter(ErrIdempotencyInProgress, existing.LockedUntil, now) + } + recordIdempotencyClaim(opts.Route, opts.Scope, map[string]string{"mode": "reclaim"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "failed_retryable->processing", false, map[string]string{ + "claim_mode": "reclaim", + }) + record.ID = existing.ID + default: + recordIdempotencyConflict(opts.Route, opts.Scope, map[string]string{"reason": "unexpected_status"}) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "existing->conflict", false, map[string]string{ + "status": existing.Status, + }) + return nil, ErrIdempotencyKeyConflict + } + } + } + + if record.ID == 0 { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "record_id_missing") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "record_id_missing", + }) + return nil, ErrIdempotencyStoreUnavail + } + + execStart := time.Now() + defer func() { + recordIdempotencyProcessingDuration(opts.Route, opts.Scope, time.Since(execStart), nil) + }() + + data, execErr := execute(ctx) + if execErr != nil { + backoffUntil := time.Now().Add(c.cfg.FailedRetryBackoff) + reason := infraerrors.Reason(execErr) + if reason == "" { + reason = "EXECUTION_FAILED" + } + recordIdempotencyRetryBackoff(opts.Route, opts.Scope, nil) + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->failed_retryable", false, map[string]string{ + "reason": reason, + }) + if markErr := c.repo.MarkFailedRetryable(ctx, record.ID, 
reason, backoffUntil, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_failed_retryable_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_failed_retryable", + }) + } + return nil, execErr + } + + storedBody, marshalErr := c.marshalStoredResponse(data) + if marshalErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "marshal_response_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "marshal_response", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(marshalErr) + } + if markErr := c.repo.MarkSucceeded(ctx, record.ID, 200, storedBody, expiresAt); markErr != nil { + RecordIdempotencyStoreUnavailable(opts.Route, opts.Scope, "mark_succeeded_error") + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->store_unavailable", false, map[string]string{ + "operation": "mark_succeeded", + }) + return nil, ErrIdempotencyStoreUnavail.WithCause(markErr) + } + logIdempotencyAudit(opts.Route, opts.Scope, keyHash, "processing->succeeded", false, nil) + + return &IdempotencyExecuteResult{Data: data}, nil +} + +func (c *IdempotencyCoordinator) conflictWithRetryAfter(base *infraerrors.ApplicationError, lockedUntil *time.Time, now time.Time) error { + if lockedUntil == nil { + return base + } + sec := int(lockedUntil.Sub(now).Seconds()) + if sec <= 0 { + sec = 1 + } + return base.WithMetadata(map[string]string{"retry_after": strconv.Itoa(sec)}) +} + +func (c *IdempotencyCoordinator) marshalStoredResponse(data any) (string, error) { + raw, err := json.Marshal(data) + if err != nil { + return "", err + } + redacted := logredact.RedactText(string(raw)) + if c.cfg.MaxStoredResponseLen > 0 && len(redacted) > c.cfg.MaxStoredResponseLen { + redacted = redacted[:c.cfg.MaxStoredResponseLen] + "...(truncated)" + } + return redacted, 
nil +} + +func (c *IdempotencyCoordinator) decodeStoredResponse(stored *string) (any, error) { + if stored == nil || strings.TrimSpace(*stored) == "" { + return map[string]any{}, nil + } + var out any + if err := json.Unmarshal([]byte(*stored), &out); err != nil { + return nil, fmt.Errorf("decode stored response: %w", err) + } + return out, nil +} diff --git a/backend/internal/service/idempotency_cleanup_service.go b/backend/internal/service/idempotency_cleanup_service.go new file mode 100644 index 00000000..aaf6949a --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service.go @@ -0,0 +1,91 @@ +package service + +import ( + "context" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyCleanupService 定期清理已过期的幂等记录,避免表无限增长。 +type IdempotencyCleanupService struct { + repo IdempotencyRepository + interval time.Duration + batch int + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewIdempotencyCleanupService(repo IdempotencyRepository, cfg *config.Config) *IdempotencyCleanupService { + interval := 60 * time.Second + batch := 500 + if cfg != nil { + if cfg.Idempotency.CleanupIntervalSeconds > 0 { + interval = time.Duration(cfg.Idempotency.CleanupIntervalSeconds) * time.Second + } + if cfg.Idempotency.CleanupBatchSize > 0 { + batch = cfg.Idempotency.CleanupBatchSize + } + } + return &IdempotencyCleanupService{ + repo: repo, + interval: interval, + batch: batch, + stopCh: make(chan struct{}), + } +} + +func (s *IdempotencyCleanupService) Start() { + if s == nil || s.repo == nil { + return + } + s.startOnce.Do(func() { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] started interval=%s batch=%d", s.interval, s.batch) + go s.runLoop() + }) +} + +func (s *IdempotencyCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + logger.LegacyPrintf("service.idempotency_cleanup", 
"[IdempotencyCleanup] stopped") + }) +} + +func (s *IdempotencyCleanupService) runLoop() { + ticker := time.NewTicker(s.interval) + defer ticker.Stop() + + // 启动后先清理一轮,防止重启后积压。 + s.cleanupOnce() + + for { + select { + case <-ticker.C: + s.cleanupOnce() + case <-s.stopCh: + return + } + } +} + +func (s *IdempotencyCleanupService) cleanupOnce() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + deleted, err := s.repo.DeleteExpired(ctx, time.Now(), s.batch) + if err != nil { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleanup failed err=%v", err) + return + } + if deleted > 0 { + logger.LegacyPrintf("service.idempotency_cleanup", "[IdempotencyCleanup] cleaned expired records count=%d", deleted) + } +} diff --git a/backend/internal/service/idempotency_cleanup_service_test.go b/backend/internal/service/idempotency_cleanup_service_test.go new file mode 100644 index 00000000..556ff364 --- /dev/null +++ b/backend/internal/service/idempotency_cleanup_service_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type idempotencyCleanupRepoStub struct { + deleteCalls int + lastLimit int + deleteErr error +} + +func (r *idempotencyCleanupRepoStub) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (r *idempotencyCleanupRepoStub) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *idempotencyCleanupRepoStub) MarkSucceeded(context.Context, int64, int, string, 
time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (r *idempotencyCleanupRepoStub) DeleteExpired(_ context.Context, _ time.Time, limit int) (int64, error) { + r.deleteCalls++ + r.lastLimit = limit + if r.deleteErr != nil { + return 0, r.deleteErr + } + return 1, nil +} + +func TestNewIdempotencyCleanupService_UsesConfig(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + cfg := &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupIntervalSeconds: 7, + CleanupBatchSize: 321, + }, + } + svc := NewIdempotencyCleanupService(repo, cfg) + require.Equal(t, 7*time.Second, svc.interval) + require.Equal(t, 321, svc.batch) +} + +func TestIdempotencyCleanupService_CleanupOnce(t *testing.T) { + repo := &idempotencyCleanupRepoStub{} + svc := NewIdempotencyCleanupService(repo, &config.Config{ + Idempotency: config.IdempotencyConfig{ + CleanupBatchSize: 99, + }, + }) + + svc.cleanupOnce() + require.Equal(t, 1, repo.deleteCalls) + require.Equal(t, 99, repo.lastLimit) +} diff --git a/backend/internal/service/idempotency_observability.go b/backend/internal/service/idempotency_observability.go new file mode 100644 index 00000000..f1bf2df2 --- /dev/null +++ b/backend/internal/service/idempotency_observability.go @@ -0,0 +1,171 @@ +package service + +import ( + "sort" + "strconv" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// IdempotencyMetricsSnapshot 提供幂等核心指标快照(进程内累计)。 +type IdempotencyMetricsSnapshot struct { + ClaimTotal uint64 `json:"claim_total"` + ReplayTotal uint64 `json:"replay_total"` + ConflictTotal uint64 `json:"conflict_total"` + RetryBackoffTotal uint64 `json:"retry_backoff_total"` + ProcessingDurationCount uint64 `json:"processing_duration_count"` + ProcessingDurationTotalMs float64 `json:"processing_duration_total_ms"` + StoreUnavailableTotal uint64 
`json:"store_unavailable_total"` +} + +type idempotencyMetrics struct { + claimTotal atomic.Uint64 + replayTotal atomic.Uint64 + conflictTotal atomic.Uint64 + retryBackoffTotal atomic.Uint64 + processingDurationCount atomic.Uint64 + processingDurationMicros atomic.Uint64 + storeUnavailableTotal atomic.Uint64 +} + +var defaultIdempotencyMetrics idempotencyMetrics + +// GetIdempotencyMetricsSnapshot 返回当前幂等指标快照。 +func GetIdempotencyMetricsSnapshot() IdempotencyMetricsSnapshot { + totalMicros := defaultIdempotencyMetrics.processingDurationMicros.Load() + return IdempotencyMetricsSnapshot{ + ClaimTotal: defaultIdempotencyMetrics.claimTotal.Load(), + ReplayTotal: defaultIdempotencyMetrics.replayTotal.Load(), + ConflictTotal: defaultIdempotencyMetrics.conflictTotal.Load(), + RetryBackoffTotal: defaultIdempotencyMetrics.retryBackoffTotal.Load(), + ProcessingDurationCount: defaultIdempotencyMetrics.processingDurationCount.Load(), + ProcessingDurationTotalMs: float64(totalMicros) / 1000.0, + StoreUnavailableTotal: defaultIdempotencyMetrics.storeUnavailableTotal.Load(), + } +} + +func recordIdempotencyClaim(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.claimTotal.Add(1) + logIdempotencyMetric("idempotency_claim_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyReplay(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.replayTotal.Add(1) + logIdempotencyMetric("idempotency_replay_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyConflict(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.conflictTotal.Add(1) + logIdempotencyMetric("idempotency_conflict_total", endpoint, scope, "1", attrs) +} + +func recordIdempotencyRetryBackoff(endpoint, scope string, attrs map[string]string) { + defaultIdempotencyMetrics.retryBackoffTotal.Add(1) + logIdempotencyMetric("idempotency_retry_backoff_total", endpoint, scope, "1", attrs) +} + +func 
recordIdempotencyProcessingDuration(endpoint, scope string, duration time.Duration, attrs map[string]string) { + if duration < 0 { + duration = 0 + } + defaultIdempotencyMetrics.processingDurationCount.Add(1) + defaultIdempotencyMetrics.processingDurationMicros.Add(uint64(duration.Microseconds())) + logIdempotencyMetric("idempotency_processing_duration_ms", endpoint, scope, strconv.FormatFloat(duration.Seconds()*1000, 'f', 3, 64), attrs) +} + +// RecordIdempotencyStoreUnavailable 记录幂等存储不可用事件(用于降级路径观测)。 +func RecordIdempotencyStoreUnavailable(endpoint, scope, strategy string) { + defaultIdempotencyMetrics.storeUnavailableTotal.Add(1) + attrs := map[string]string{} + if strategy != "" { + attrs["strategy"] = strategy + } + logIdempotencyMetric("idempotency_store_unavailable_total", endpoint, scope, "1", attrs) +} + +func logIdempotencyAudit(endpoint, scope, keyHash, stateTransition string, replayed bool, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyAudit]") + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, safeAuditField(scope)) + builderWriteString(&b, " key_hash=") + builderWriteString(&b, safeAuditField(keyHash)) + builderWriteString(&b, " state_transition=") + builderWriteString(&b, safeAuditField(stateTransition)) + builderWriteString(&b, " replayed=") + builderWriteString(&b, strconv.FormatBool(replayed)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func logIdempotencyMetric(name, endpoint, scope, value string, attrs map[string]string) { + var b strings.Builder + builderWriteString(&b, "[IdempotencyMetric]") + builderWriteString(&b, " name=") + builderWriteString(&b, safeAuditField(name)) + builderWriteString(&b, " endpoint=") + builderWriteString(&b, safeAuditField(endpoint)) + builderWriteString(&b, " scope=") + builderWriteString(&b, 
safeAuditField(scope)) + builderWriteString(&b, " value=") + builderWriteString(&b, safeAuditField(value)) + if len(attrs) > 0 { + appendSortedAttrs(&b, attrs) + } + logger.LegacyPrintf("service.idempotency", "%s", b.String()) +} + +func appendSortedAttrs(builder *strings.Builder, attrs map[string]string) { + if len(attrs) == 0 { + return + } + keys := make([]string, 0, len(attrs)) + for k := range attrs { + keys = append(keys, k) + } + sort.Strings(keys) + for _, k := range keys { + builderWriteByte(builder, ' ') + builderWriteString(builder, k) + builderWriteByte(builder, '=') + builderWriteString(builder, safeAuditField(attrs[k])) + } +} + +func safeAuditField(v string) string { + value := strings.TrimSpace(v) + if value == "" { + return "-" + } + // 日志按 key=value 输出,替换空白避免解析歧义。 + value = strings.ReplaceAll(value, "\n", "_") + value = strings.ReplaceAll(value, "\r", "_") + value = strings.ReplaceAll(value, "\t", "_") + value = strings.ReplaceAll(value, " ", "_") + return value +} + +func resetIdempotencyMetricsForTest() { + defaultIdempotencyMetrics.claimTotal.Store(0) + defaultIdempotencyMetrics.replayTotal.Store(0) + defaultIdempotencyMetrics.conflictTotal.Store(0) + defaultIdempotencyMetrics.retryBackoffTotal.Store(0) + defaultIdempotencyMetrics.processingDurationCount.Store(0) + defaultIdempotencyMetrics.processingDurationMicros.Store(0) + defaultIdempotencyMetrics.storeUnavailableTotal.Store(0) +} + +func builderWriteString(builder *strings.Builder, value string) { + _, _ = builder.WriteString(value) +} + +func builderWriteByte(builder *strings.Builder, value byte) { + _ = builder.WriteByte(value) +} diff --git a/backend/internal/service/idempotency_test.go b/backend/internal/service/idempotency_test.go new file mode 100644 index 00000000..6ff75d1c --- /dev/null +++ b/backend/internal/service/idempotency_test.go @@ -0,0 +1,805 @@ +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type inMemoryIdempotencyRepo struct { + mu sync.Mutex + nextID int64 + data map[string]*IdempotencyRecord +} + +func newInMemoryIdempotencyRepo() *inMemoryIdempotencyRepo { + return &inMemoryIdempotencyRepo{ + nextID: 1, + data: make(map[string]*IdempotencyRecord), + } +} + +func (r *inMemoryIdempotencyRepo) key(scope, hash string) string { + return scope + "|" + hash +} + +func cloneRecord(in *IdempotencyRecord) *IdempotencyRecord { + if in == nil { + return nil + } + out := *in + if in.ResponseStatus != nil { + v := *in.ResponseStatus + out.ResponseStatus = &v + } + if in.ResponseBody != nil { + v := *in.ResponseBody + out.ResponseBody = &v + } + if in.ErrorReason != nil { + v := *in.ErrorReason + out.ErrorReason = &v + } + if in.LockedUntil != nil { + v := *in.LockedUntil + out.LockedUntil = &v + } + return &out +} + +func (r *inMemoryIdempotencyRepo) CreateProcessing(_ context.Context, record *IdempotencyRecord) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + k := r.key(record.Scope, record.IdempotencyKeyHash) + if _, ok := r.data[k]; ok { + return false, nil + } + rec := cloneRecord(record) + rec.ID = r.nextID + rec.CreatedAt = time.Now() + rec.UpdatedAt = rec.CreatedAt + r.nextID++ + r.data[k] = rec + record.ID = rec.ID + record.CreatedAt = rec.CreatedAt + record.UpdatedAt = rec.UpdatedAt + return true, nil +} + +func (r *inMemoryIdempotencyRepo) GetByScopeAndKeyHash(_ context.Context, scope, keyHash string) (*IdempotencyRecord, error) { + r.mu.Lock() + defer r.mu.Unlock() + return cloneRecord(r.data[r.key(scope, keyHash)]), nil +} + +func (r *inMemoryIdempotencyRepo) TryReclaim(_ context.Context, id int64, fromStatus string, now, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != fromStatus { + return false, nil + } + if 
rec.LockedUntil != nil && rec.LockedUntil.After(now) { + return false, nil + } + rec.Status = IdempotencyStatusProcessing + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.ErrorReason = nil + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) ExtendProcessingLock(_ context.Context, id int64, requestFingerprint string, newLockedUntil, newExpiresAt time.Time) (bool, error) { + r.mu.Lock() + defer r.mu.Unlock() + + for _, rec := range r.data { + if rec.ID != id { + continue + } + if rec.Status != IdempotencyStatusProcessing || rec.RequestFingerprint != requestFingerprint { + return false, nil + } + rec.LockedUntil = &newLockedUntil + rec.ExpiresAt = newExpiresAt + rec.UpdatedAt = time.Now() + return true, nil + } + return false, nil +} + +func (r *inMemoryIdempotencyRepo) MarkSucceeded(_ context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusSucceeded + rec.LockedUntil = nil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = nil + rec.ResponseStatus = &responseStatus + rec.ResponseBody = &responseBody + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) MarkFailedRetryable(_ context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + r.mu.Lock() + defer r.mu.Unlock() + for _, rec := range r.data { + if rec.ID != id { + continue + } + rec.Status = IdempotencyStatusFailedRetryable + rec.LockedUntil = &lockedUntil + rec.ExpiresAt = expiresAt + rec.UpdatedAt = time.Now() + rec.ErrorReason = &errorReason + return nil + } + return errors.New("record not found") +} + +func (r *inMemoryIdempotencyRepo) DeleteExpired(_ context.Context, now time.Time, _ int) (int64, error) { + r.mu.Lock() + defer r.mu.Unlock() + var deleted 
int64 + for k, rec := range r.data { + if !rec.ExpiresAt.After(now) { + delete(r.data, k) + deleted++ + } + } + return deleted, nil +} + +func TestIdempotencyCoordinator_RequireKey(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ObserveOnly = false + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "admin:1", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyRequired)) +} + +func TestIdempotencyCoordinator_ReplaySucceededResult(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-1", + Payload: map[string]any{"a": 1}, + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.False(t, first.Replayed) + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.True(t, second.Replayed) + require.Equal(t, 1, execCount, "second request should replay without executing business logic") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) +} + +func TestIdempotencyCoordinator_ReclaimExpiredSucceededRecord(t *testing.T) { + 
resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.expired", + Method: "POST", + Route: "/test/expired", + ActorScope: "user:99", + RequireKey: true, + IdempotencyKey: "expired-case", + Payload: map[string]any{"k": "v"}, + } + + execCount := 0 + exec := func(ctx context.Context) (any, error) { + execCount++ + return map[string]any{"count": execCount}, nil + } + + first, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, first) + require.False(t, first.Replayed) + require.Equal(t, 1, execCount) + + keyHash := HashIdempotencyKey(opts.IdempotencyKey) + repo.mu.Lock() + existing := repo.data[repo.key(opts.Scope, keyHash)] + require.NotNil(t, existing) + existing.ExpiresAt = time.Now().Add(-time.Second) + repo.mu.Unlock() + + second, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, second) + require.False(t, second.Replayed, "expired record should be reclaimed and execute business logic again") + require.Equal(t, 2, execCount) + + third, err := coordinator.Execute(context.Background(), opts, exec) + require.NoError(t, err) + require.NotNil(t, third) + require.True(t, third.Replayed) + payload, ok := third.Data.(map[string]any) + require.True(t, ok) + require.Equal(t, float64(2), payload["count"]) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.ClaimTotal, uint64(2)) + require.GreaterOrEqual(t, metrics.ReplayTotal, uint64(1)) +} + +func TestIdempotencyCoordinator_SameKeyDifferentPayloadConflict(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(repo, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: 
"test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-2", + Payload: map[string]any{"a": 2}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyKeyConflict)) + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ConflictTotal) +} + +func TestIdempotencyCoordinator_BackoffAfterRetryableFailure(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.FailedRetryBackoff = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope", + Method: "POST", + Route: "/test", + ActorScope: "user:1", + RequireKey: true, + IdempotencyKey: "case-3", + Payload: map[string]any{"a": 1}, + } + + _, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return nil, infraerrors.InternalServer("UPSTREAM_ERROR", "upstream error") + }) + require.Error(t, err) + + _, err = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(err), infraerrors.Code(ErrIdempotencyRetryBackoff)) + require.Greater(t, RetryAfterSecondsFromError(err), 0) + + metrics := GetIdempotencyMetricsSnapshot() + require.GreaterOrEqual(t, metrics.RetryBackoffTotal, uint64(2)) + require.GreaterOrEqual(t, 
metrics.ConflictTotal, uint64(1)) + require.GreaterOrEqual(t, metrics.ProcessingDurationCount, uint64(1)) +} + +func TestIdempotencyCoordinator_ConcurrentSameKeySingleSideEffect(t *testing.T) { + resetIdempotencyMetricsForTest() + repo := newInMemoryIdempotencyRepo() + cfg := DefaultIdempotencyConfig() + cfg.ProcessingTimeout = 2 * time.Second + coordinator := NewIdempotencyCoordinator(repo, cfg) + + opts := IdempotencyExecuteOptions{ + Scope: "test.scope.concurrent", + Method: "POST", + Route: "/test/concurrent", + ActorScope: "user:7", + RequireKey: true, + IdempotencyKey: "concurrent-case", + Payload: map[string]any{"v": 1}, + } + + var execCount int32 + var wg sync.WaitGroup + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, _ = coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + time.Sleep(80 * time.Millisecond) + return map[string]any{"ok": true}, nil + }) + }() + } + wg.Wait() + + replayed, err := coordinator.Execute(context.Background(), opts, func(ctx context.Context) (any, error) { + atomic.AddInt32(&execCount, 1) + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.True(t, replayed.Replayed) + require.Equal(t, int32(1), atomic.LoadInt32(&execCount), "concurrent same-key requests should execute business side-effect once") + + metrics := GetIdempotencyMetricsSnapshot() + require.Equal(t, uint64(1), metrics.ClaimTotal) + require.Equal(t, uint64(1), metrics.ReplayTotal) + require.GreaterOrEqual(t, metrics.ConflictTotal, uint64(1)) +} + +type failingIdempotencyRepo struct{} + +func (failingIdempotencyRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, errors.New("store unavailable") +} +func (failingIdempotencyRepo) 
TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return errors.New("store unavailable") +} +func (failingIdempotencyRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, errors.New("store unavailable") +} + +func TestIdempotencyCoordinator_StoreUnavailableMetrics(t *testing.T) { + resetIdempotencyMetricsForTest() + coordinator := NewIdempotencyCoordinator(failingIdempotencyRepo{}, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "test.scope.unavailable", + Method: "POST", + Route: "/test/unavailable", + ActorScope: "admin:1", + RequireKey: true, + IdempotencyKey: "case-unavailable", + Payload: map[string]any{"v": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + require.GreaterOrEqual(t, GetIdempotencyMetricsSnapshot().StoreUnavailableTotal, uint64(1)) +} + +func TestDefaultIdempotencyCoordinatorAndTTLs(t *testing.T) { + SetDefaultIdempotencyCoordinator(nil) + require.Nil(t, DefaultIdempotencyCoordinator()) + require.Equal(t, DefaultIdempotencyConfig().DefaultTTL, DefaultWriteIdempotencyTTL()) + require.Equal(t, DefaultIdempotencyConfig().SystemOperationTTL, DefaultSystemOperationIdempotencyTTL()) + + coordinator := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + 
DefaultTTL: 2 * time.Hour, + SystemOperationTTL: 15 * time.Minute, + ProcessingTimeout: 10 * time.Second, + FailedRetryBackoff: 3 * time.Second, + ObserveOnly: false, + }) + SetDefaultIdempotencyCoordinator(coordinator) + t.Cleanup(func() { + SetDefaultIdempotencyCoordinator(nil) + }) + + require.Same(t, coordinator, DefaultIdempotencyCoordinator()) + require.Equal(t, 2*time.Hour, DefaultWriteIdempotencyTTL()) + require.Equal(t, 15*time.Minute, DefaultSystemOperationIdempotencyTTL()) +} + +func TestNormalizeIdempotencyKeyAndFingerprint(t *testing.T) { + key, err := NormalizeIdempotencyKey(" abc-123 ") + require.NoError(t, err) + require.Equal(t, "abc-123", key) + + key, err = NormalizeIdempotencyKey("") + require.NoError(t, err) + require.Equal(t, "", key) + + _, err = NormalizeIdempotencyKey(string(make([]byte, 129))) + require.Error(t, err) + + _, err = NormalizeIdempotencyKey("bad\nkey") + require.Error(t, err) + + fp1, err := BuildIdempotencyFingerprint("", "", "", map[string]any{"a": 1}) + require.NoError(t, err) + require.NotEmpty(t, fp1) + fp2, err := BuildIdempotencyFingerprint("POST", "/", "anonymous", map[string]any{"a": 1}) + require.NoError(t, err) + require.Equal(t, fp1, fp2) + + _, err = BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"bad": make(chan int)}) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInvalidPayload), infraerrors.Code(err)) +} + +func TestRetryAfterSecondsFromErrorBranches(t *testing.T) { + require.Equal(t, 0, RetryAfterSecondsFromError(nil)) + require.Equal(t, 0, RetryAfterSecondsFromError(errors.New("plain"))) + + err := ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "12"}) + require.Equal(t, 12, RetryAfterSecondsFromError(err)) + + err = ErrIdempotencyInProgress.WithMetadata(map[string]string{"retry_after": "bad"}) + require.Equal(t, 0, RetryAfterSecondsFromError(err)) +} + +func TestIdempotencyCoordinator_ExecuteNilExecutorAndNoKeyPassThrough(t *testing.T) { + 
repo := newInMemoryIdempotencyRepo() + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, nil) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_EXECUTOR_NIL", infraerrors.Reason(err)) + + called := 0 + result, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + RequireKey: true, + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + called++ + return map[string]any{"ok": true}, nil + }) + require.NoError(t, err) + require.Equal(t, 1, called) + require.NotNil(t, result) + require.False(t, result.Replayed) +} + +type noIDOwnerRepo struct{} + +func (noIDOwnerRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return true, nil +} +func (noIDOwnerRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return nil, nil +} +func (noIDOwnerRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (noIDOwnerRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { return nil } +func (noIDOwnerRepo) MarkFailedRetryable(context.Context, int64, string, time.Time, time.Time) error { + return nil +} +func (noIDOwnerRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { return 0, nil } + +func TestIdempotencyCoordinator_RepoNilScopeRequiredAndRecordIDMissing(t *testing.T) { + cfg := DefaultIdempotencyConfig() + coordinator := NewIdempotencyCoordinator(nil, cfg) + + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Payload: map[string]any{"a": 1}, + }, 
func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + coordinator = NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + IdempotencyKey: "k2", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, "IDEMPOTENCY_SCOPE_REQUIRED", infraerrors.Reason(err)) + + coordinator = NewIdempotencyCoordinator(noIDOwnerRepo{}, cfg) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-no-id", + IdempotencyKey: "k3", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) +} + +type conflictBranchRepo struct { + existing *IdempotencyRecord + tryReclaimErr error + tryReclaimOK bool +} + +func (r *conflictBranchRepo) CreateProcessing(context.Context, *IdempotencyRecord) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) GetByScopeAndKeyHash(context.Context, string, string) (*IdempotencyRecord, error) { + return cloneRecord(r.existing), nil +} +func (r *conflictBranchRepo) TryReclaim(context.Context, int64, string, time.Time, time.Time, time.Time) (bool, error) { + if r.tryReclaimErr != nil { + return false, r.tryReclaimErr + } + return r.tryReclaimOK, nil +} +func (r *conflictBranchRepo) ExtendProcessingLock(context.Context, int64, string, time.Time, time.Time) (bool, error) { + return false, nil +} +func (r *conflictBranchRepo) MarkSucceeded(context.Context, int64, int, string, time.Time) error { + return nil +} +func (r *conflictBranchRepo) MarkFailedRetryable(context.Context, int64, string, 
time.Time, time.Time) error { + return nil +} +func (r *conflictBranchRepo) DeleteExpired(context.Context, time.Time, int) (int64, error) { + return 0, nil +} + +func TestIdempotencyCoordinator_ConflictBranchesAndDecodeError(t *testing.T) { + now := time.Now() + fp, err := BuildIdempotencyFingerprint("POST", "/x", "u:1", map[string]any{"a": 1}) + require.NoError(t, err) + badBody := "{bad-json" + repo := &conflictBranchRepo{ + existing: &IdempotencyRecord{ + ID: 1, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusSucceeded, + ResponseBody: &badBody, + ExpiresAt: now.Add(time.Hour), + }, + } + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 2, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: "unknown", + ExpiresAt: now.Add(time.Hour), + } + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyKeyConflict), infraerrors.Code(err)) + + repo.existing = &IdempotencyRecord{ + ID: 3, + Scope: "scope", + IdempotencyKeyHash: HashIdempotencyKey("k"), + RequestFingerprint: fp, + Status: IdempotencyStatusFailedRetryable, + LockedUntil: ptrTime(now.Add(-time.Second)), 
+ ExpiresAt: now.Add(time.Hour), + } + repo.tryReclaimErr = errors.New("reclaim down") + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.tryReclaimErr = nil + repo.tryReclaimOK = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope", + IdempotencyKey: "k", + Method: "POST", + Route: "/x", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyInProgress), infraerrors.Code(err)) +} + +type markBehaviorRepo struct { + inMemoryIdempotencyRepo + failMarkSucceeded bool + failMarkFailed bool +} + +func (r *markBehaviorRepo) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error { + if r.failMarkSucceeded { + return errors.New("mark succeeded failed") + } + return r.inMemoryIdempotencyRepo.MarkSucceeded(ctx, id, responseStatus, responseBody, expiresAt) +} + +func (r *markBehaviorRepo) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error { + if r.failMarkFailed { + return errors.New("mark failed retryable failed") + } + return r.inMemoryIdempotencyRepo.MarkFailedRetryable(ctx, id, errorReason, lockedUntil, expiresAt) +} + +func TestIdempotencyCoordinator_MarkAndMarshalBranches(t *testing.T) { + repo := &markBehaviorRepo{inMemoryIdempotencyRepo: *newInMemoryIdempotencyRepo()} + coordinator := NewIdempotencyCoordinator(repo, DefaultIdempotencyConfig()) + + repo.failMarkSucceeded = 
true + _, err := coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-success", + IdempotencyKey: "k1", + Method: "POST", + Route: "/ok", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"ok": true}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkSucceeded = false + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-marshal", + IdempotencyKey: "k2", + Method: "POST", + Route: "/bad", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return map[string]any{"bad": make(chan int)}, nil + }) + require.Error(t, err) + require.Equal(t, infraerrors.Code(ErrIdempotencyStoreUnavail), infraerrors.Code(err)) + + repo.failMarkFailed = true + _, err = coordinator.Execute(context.Background(), IdempotencyExecuteOptions{ + Scope: "scope-fail", + IdempotencyKey: "k3", + Method: "POST", + Route: "/fail", + ActorScope: "u:1", + Payload: map[string]any{"a": 1}, + }, func(ctx context.Context) (any, error) { + return nil, errors.New("plain failure") + }) + require.Error(t, err) + require.Equal(t, "plain failure", err.Error()) +} + +func TestIdempotencyCoordinator_HelperBranches(t *testing.T) { + c := NewIdempotencyCoordinator(newInMemoryIdempotencyRepo(), IdempotencyConfig{ + DefaultTTL: time.Hour, + SystemOperationTTL: time.Hour, + ProcessingTimeout: time.Second, + FailedRetryBackoff: time.Second, + MaxStoredResponseLen: 12, + ObserveOnly: false, + }) + + // conflictWithRetryAfter without locked_until should return base error. + base := ErrIdempotencyInProgress + err := c.conflictWithRetryAfter(base, nil, time.Now()) + require.Equal(t, infraerrors.Code(base), infraerrors.Code(err)) + + // marshalStoredResponse should truncate. 
+ body, err := c.marshalStoredResponse(map[string]any{"long": "abcdefghijklmnopqrstuvwxyz"}) + require.NoError(t, err) + require.Contains(t, body, "...(truncated)") + + // decodeStoredResponse empty and invalid json. + out, err := c.decodeStoredResponse(nil) + require.NoError(t, err) + _, ok := out.(map[string]any) + require.True(t, ok) + + invalid := "{invalid" + _, err = c.decodeStoredResponse(&invalid) + require.Error(t, err) +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 261da0ef..dc59010d 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -7,13 +7,14 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "log/slog" "net/http" "regexp" "strconv" "strings" "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // 预编译正则表达式(避免每次调用重新编译) @@ -84,7 +85,7 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID cached.UserAgent = clientUA // 保存更新后的指纹 _ = s.cache.SetFingerprint(ctx, accountID, cached) - log.Printf("Updated fingerprint user-agent for account %d: %s", accountID, clientUA) + logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA) } return cached, nil } @@ -97,10 +98,10 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 保存到缓存(永不过期) if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil { - log.Printf("Warning: failed to cache fingerprint for account %d: %v", accountID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err) } - log.Printf("Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) + logger.LegacyPrintf("service.identity", "Created new fingerprint for account %d with client_id: %s", accountID, fp.ClientID) return fp, nil } @@ -277,19 +278,19 @@ func (s *IdentityService) 
RewriteUserIDWithMasking(ctx context.Context, body []b // 获取或生成固定的伪装 session ID maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID) if err != nil { - log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to get masked session ID for account %d: %v", account.ID, err) return newBody, nil } if maskedSessionID == "" { // 首次或已过期,生成新的伪装 session ID maskedSessionID = generateRandomUUID() - log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) + logger.LegacyPrintf("service.identity", "Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) } // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期) if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil { - log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err) + logger.LegacyPrintf("service.identity", "Warning: failed to set masked session ID for account %d: %v", account.ID, err) } // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 @@ -335,7 +336,7 @@ func generateClientID() string { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { // 极罕见的情况,使用时间戳+固定值作为fallback - log.Printf("Warning: crypto/rand.Read failed: %v, using fallback", err) + logger.LegacyPrintf("service.identity", "Warning: crypto/rand.Read failed: %v, using fallback", err) // 使用SHA256(当前纳秒时间)作为fallback h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) return hex.EncodeToString(h[:]) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 15543080..6f6261d8 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -14,6 +14,7 @@ import ( type OpenAIOAuthClient interface { ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, 
proxyURL string) (*openai.TokenResponse, error) + RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } // ClaudeOAuthClient handles HTTP requests for Claude OAuth flows @@ -217,7 +218,7 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( // Ensure org_uuid is set (from step 1 if not from token response) if tokenInfo.OrgUUID == "" && orgUUID != "" { tokenInfo.OrgUUID = orgUUID - log.Printf("[OAuth] Set org_uuid from cookie auth: %s", orgUUID) + log.Printf("[OAuth] Set org_uuid from cookie auth") } return tokenInfo, nil @@ -251,16 +252,16 @@ func (s *OAuthService) exchangeCodeForToken(ctx context.Context, code, codeVerif if tokenResp.Organization != nil && tokenResp.Organization.UUID != "" { tokenInfo.OrgUUID = tokenResp.Organization.UUID - log.Printf("[OAuth] Got org_uuid: %s", tokenInfo.OrgUUID) + log.Printf("[OAuth] Got org_uuid") } if tokenResp.Account != nil { if tokenResp.Account.UUID != "" { tokenInfo.AccountUUID = tokenResp.Account.UUID - log.Printf("[OAuth] Got account_uuid: %s", tokenInfo.AccountUUID) + log.Printf("[OAuth] Got account_uuid") } if tokenResp.Account.EmailAddress != "" { tokenInfo.EmailAddress = tokenResp.Account.EmailAddress - log.Printf("[OAuth] Got email_address: %s", tokenInfo.EmailAddress) + log.Printf("[OAuth] Got email_address") } } diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go new file mode 100644 index 00000000..72de4b8c --- /dev/null +++ b/backend/internal/service/oauth_service_test.go @@ -0,0 +1,607 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +// --- mock: ClaudeOAuthClient --- + +type mockClaudeOAuthClient struct { + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + 
getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) +} + +func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { + if m.getOrgUUIDFunc != nil { + return m.getOrgUUIDFunc(ctx, sessionKey, proxyURL) + } + panic("GetOrganizationUUID not implemented") +} + +func (m *mockClaudeOAuthClient) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { + if m.getAuthCodeFunc != nil { + return m.getAuthCodeFunc(ctx, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL) + } + panic("GetAuthorizationCode not implemented") +} + +func (m *mockClaudeOAuthClient) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if m.exchangeCodeFunc != nil { + return m.exchangeCodeFunc(ctx, code, codeVerifier, state, proxyURL, isSetupToken) + } + panic("ExchangeCodeForToken not implemented") +} + +func (m *mockClaudeOAuthClient) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if m.refreshTokenFunc != nil { + return m.refreshTokenFunc(ctx, refreshToken, proxyURL) + } + panic("RefreshToken not implemented") +} + +// --- mock: ProxyRepository (最小实现,仅覆盖 OAuthService 依赖的方法) --- + +type mockProxyRepoForOAuth struct { + getByIDFunc func(ctx context.Context, id int64) (*Proxy, error) +} + +func (m *mockProxyRepoForOAuth) Create(ctx context.Context, proxy *Proxy) error { + panic("Create not implemented") +} +func (m *mockProxyRepoForOAuth) GetByID(ctx context.Context, id int64) (*Proxy, error) { + if m.getByIDFunc != nil 
{ + return m.getByIDFunc(ctx, id) + } + return nil, fmt.Errorf("proxy not found") +} +func (m *mockProxyRepoForOAuth) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) { + panic("ListByIDs not implemented") +} +func (m *mockProxyRepoForOAuth) Update(ctx context.Context, proxy *Proxy) error { + panic("Update not implemented") +} +func (m *mockProxyRepoForOAuth) Delete(ctx context.Context, id int64) error { + panic("Delete not implemented") +} +func (m *mockProxyRepoForOAuth) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) { + panic("List not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) { + panic("ListWithFilters not implemented") +} +func (m *mockProxyRepoForOAuth) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) { + panic("ListWithFiltersAndAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ListActive(ctx context.Context) ([]Proxy, error) { + panic("ListActive not implemented") +} +func (m *mockProxyRepoForOAuth) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) { + panic("ListActiveWithAccountCount not implemented") +} +func (m *mockProxyRepoForOAuth) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) { + panic("ExistsByHostPortAuth not implemented") +} +func (m *mockProxyRepoForOAuth) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) { + panic("CountAccountsByProxyID not implemented") +} +func (m *mockProxyRepoForOAuth) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) { + panic("ListAccountSummariesByProxyID not implemented") +} + +// 
===================== +// 测试用例 +// ===================== + +func TestNewOAuthService(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{} + client := &mockClaudeOAuthClient{} + svc := NewOAuthService(proxyRepo, client) + + if svc == nil { + t.Fatal("NewOAuthService 返回 nil") + } + if svc.proxyRepo != proxyRepo { + t.Fatal("proxyRepo 未正确设置") + } + if svc.oauthClient != client { + t.Fatal("oauthClient 未正确设置") + } + if svc.sessionStore == nil { + t.Fatal("sessionStore 应被自动初始化") + } + + // 清理 + svc.Stop() +} + +func TestOAuthService_GenerateAuthURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateAuthURL 返回 nil") + } + if result.AuthURL == "" { + t.Fatal("AuthURL 为空") + } + if result.SessionID == "" { + t.Fatal("SessionID 为空") + } + + // 验证 session 已存储 + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeOAuth { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeOAuth) + } +} + +func TestOAuthService_GenerateAuthURL_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + ID: 1, + Protocol: "http", + Host: "proxy.example.com", + Port: 8080, + }, nil + }, + } + svc := NewOAuthService(proxyRepo, &mockClaudeOAuthClient{}) + defer svc.Stop() + + proxyID := int64(1) + result, err := svc.GenerateAuthURL(context.Background(), &proxyID) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.ProxyURL != "http://proxy.example.com:8080" { + t.Fatalf("ProxyURL 
不匹配: got=%q", session.ProxyURL) + } +} + +func TestOAuthService_GenerateSetupTokenURL(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + if result == nil { + t.Fatal("GenerateSetupTokenURL 返回 nil") + } + + // 验证 scope 是 inference + session, ok := svc.sessionStore.Get(result.SessionID) + if !ok { + t.Fatal("session 未在 sessionStore 中找到") + } + if session.Scope != oauth.ScopeInference { + t.Fatalf("scope 不匹配: got=%q want=%q", session.Scope, oauth.ScopeInference) + } +} + +func TestOAuthService_ExchangeCode_SessionNotFound(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: "nonexistent-session", + Code: "test-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误(session 不存在)") + } + if err.Error() != "session not found or expired" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_ExchangeCode_Success(t *testing.T) { + t.Parallel() + + exchangeCalled := false + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + exchangeCalled = true + if code != "auth-code-123" { + t.Errorf("code 不匹配: got=%q", code) + } + if isSetupToken { + t.Error("isSetupToken 应为 false(ScopeOAuth)") + } + return &oauth.TokenResponse{ + AccessToken: "access-token-abc", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "refresh-token-xyz", + Scope: oauth.ScopeOAuth, + Organization: &oauth.OrgInfo{UUID: "org-uuid-111"}, + Account: &oauth.AccountInfo{UUID: "acc-uuid-222", EmailAddress: "test@example.com"}, + }, nil + }, + } + + svc := 
NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 先生成 URL 以创建 session + result, err := svc.GenerateAuthURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateAuthURL 返回错误: %v", err) + } + + // 交换 code + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "auth-code-123", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + + if !exchangeCalled { + t.Fatal("ExchangeCodeForToken 未被调用") + } + if tokenInfo.AccessToken != "access-token-abc" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.TokenType != "Bearer" { + t.Fatalf("TokenType 不匹配: got=%q", tokenInfo.TokenType) + } + if tokenInfo.RefreshToken != "refresh-token-xyz" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.OrgUUID != "org-uuid-111" { + t.Fatalf("OrgUUID 不匹配: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "acc-uuid-222" { + t.Fatalf("AccountUUID 不匹配: got=%q", tokenInfo.AccountUUID) + } + if tokenInfo.EmailAddress != "test@example.com" { + t.Fatalf("EmailAddress 不匹配: got=%q", tokenInfo.EmailAddress) + } + if tokenInfo.ExpiresIn != 3600 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } + + // 验证 session 已被删除 + _, ok := svc.sessionStore.Get(result.SessionID) + if ok { + t.Fatal("session 应在交换成功后被删除") + } +} + +func TestOAuthService_ExchangeCode_SetupToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + if !isSetupToken { + t.Error("isSetupToken 应为 true(ScopeInference)") + } + return &oauth.TokenResponse{ + AccessToken: "setup-token", + TokenType: "Bearer", + ExpiresIn: 3600, + Scope: oauth.ScopeInference, + }, nil + }, + } + + svc := 
NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + // 使用 SetupToken URL(inference scope) + result, err := svc.GenerateSetupTokenURL(context.Background(), nil) + if err != nil { + t.Fatalf("GenerateSetupTokenURL 返回错误: %v", err) + } + + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "setup-code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.AccessToken != "setup-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_ExchangeCode_ClientError(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("upstream error: invalid code") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + _, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "bad-code", + }) + if err == nil { + t.Fatal("ExchangeCode 应返回错误") + } + if err.Error() != "upstream error: invalid code" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshToken(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "my-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + if proxyURL != "" { + t.Errorf("proxyURL 应为空: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "new-access-token", + TokenType: "Bearer", + ExpiresIn: 7200, + RefreshToken: "new-refresh-token", + Scope: oauth.ScopeOAuth, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + tokenInfo, err 
:= svc.RefreshToken(context.Background(), "my-refresh-token", "") + if err != nil { + t.Fatalf("RefreshToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "new-access-token" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } + if tokenInfo.RefreshToken != "new-refresh-token" { + t.Fatalf("RefreshToken 不匹配: got=%q", tokenInfo.RefreshToken) + } + if tokenInfo.ExpiresIn != 7200 { + t.Fatalf("ExpiresIn 不匹配: got=%d", tokenInfo.ExpiresIn) + } + if tokenInfo.ExpiresAt == 0 { + t.Fatal("ExpiresAt 不应为 0") + } +} + +func TestOAuthService_RefreshToken_Error(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + return nil, fmt.Errorf("invalid_grant: token expired") + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + _, err := svc.RefreshToken(context.Background(), "expired-token", "") + if err == nil { + t.Fatal("RefreshToken 应返回错误") + } +} + +func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + // 无 refresh_token 的账号 + account := &Account{ + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(无 refresh_token)") + } + if err.Error() != "no refresh token available" { + t.Fatalf("错误信息不匹配: got=%q", err.Error()) + } +} + +func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + defer svc.Stop() + + account := &Account{ + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "some-token", + 
"refresh_token": "", + }, + } + _, err := svc.RefreshAccountToken(context.Background(), account) + if err == nil { + t.Fatal("RefreshAccountToken 应返回错误(refresh_token 为空)") + } +} + +func TestOAuthService_RefreshAccountToken_Success(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if refreshToken != "account-refresh-token" { + t.Errorf("refreshToken 不匹配: got=%q", refreshToken) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed-access", + TokenType: "Bearer", + ExpiresIn: 3600, + RefreshToken: "new-refresh", + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + account := &Account{ + ID: 3, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "old-access", + "refresh_token": "account-refresh-token", + }, + } + + tokenInfo, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } + if tokenInfo.AccessToken != "refreshed-access" { + t.Fatalf("AccessToken 不匹配: got=%q", tokenInfo.AccessToken) + } +} + +func TestOAuthService_RefreshAccountToken_WithProxy(t *testing.T) { + t.Parallel() + + proxyRepo := &mockProxyRepoForOAuth{ + getByIDFunc: func(ctx context.Context, id int64) (*Proxy, error) { + return &Proxy{ + Protocol: "socks5", + Host: "socks.example.com", + Port: 1080, + Username: "user", + Password: "pass", + }, nil + }, + } + + client := &mockClaudeOAuthClient{ + refreshTokenFunc: func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { + if proxyURL != "socks5://user:pass@socks.example.com:1080" { + t.Errorf("proxyURL 不匹配: got=%q", proxyURL) + } + return &oauth.TokenResponse{ + AccessToken: "refreshed", + ExpiresIn: 3600, + }, nil + }, + } + + svc := NewOAuthService(proxyRepo, client) + defer svc.Stop() + + proxyID := 
int64(10) + account := &Account{ + ID: 4, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt-with-proxy", + }, + } + + _, err := svc.RefreshAccountToken(context.Background(), account) + if err != nil { + t.Fatalf("RefreshAccountToken 返回错误: %v", err) + } +} + +func TestOAuthService_ExchangeCode_NilOrg(t *testing.T) { + t.Parallel() + + client := &mockClaudeOAuthClient{ + exchangeCodeFunc: func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { + return &oauth.TokenResponse{ + AccessToken: "token-no-org", + TokenType: "Bearer", + ExpiresIn: 3600, + Organization: nil, + Account: nil, + }, nil + }, + } + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, client) + defer svc.Stop() + + result, _ := svc.GenerateAuthURL(context.Background(), nil) + tokenInfo, err := svc.ExchangeCode(context.Background(), &ExchangeCodeInput{ + SessionID: result.SessionID, + Code: "code", + }) + if err != nil { + t.Fatalf("ExchangeCode 返回错误: %v", err) + } + if tokenInfo.OrgUUID != "" { + t.Fatalf("OrgUUID 应为空: got=%q", tokenInfo.OrgUUID) + } + if tokenInfo.AccountUUID != "" { + t.Fatalf("AccountUUID 应为空: got=%q", tokenInfo.AccountUUID) + } +} + +func TestOAuthService_Stop_NoPanic(t *testing.T) { + t.Parallel() + + svc := NewOAuthService(&mockProxyRepoForOAuth{}, &mockClaudeOAuthClient{}) + + // 调用 Stop 不应 panic + svc.Stop() + + // 多次调用也不应 panic + svc.Stop() +} diff --git a/backend/internal/service/openai_client_restriction_detector.go b/backend/internal/service/openai_client_restriction_detector.go new file mode 100644 index 00000000..d1784e11 --- /dev/null +++ b/backend/internal/service/openai_client_restriction_detector.go @@ -0,0 +1,86 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/gin-gonic/gin" +) + +const ( + // 
CodexClientRestrictionReasonDisabled 表示账号未开启 codex_cli_only。 + CodexClientRestrictionReasonDisabled = "codex_cli_only_disabled" + // CodexClientRestrictionReasonMatchedUA 表示请求命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonMatchedUA = "official_client_user_agent_matched" + // CodexClientRestrictionReasonMatchedOriginator 表示请求命中官方客户端 originator 白名单。 + CodexClientRestrictionReasonMatchedOriginator = "official_client_originator_matched" + // CodexClientRestrictionReasonNotMatchedUA 表示请求未命中官方客户端 UA 白名单。 + CodexClientRestrictionReasonNotMatchedUA = "official_client_user_agent_not_matched" + // CodexClientRestrictionReasonForceCodexCLI 表示通过 ForceCodexCLI 配置兜底放行。 + CodexClientRestrictionReasonForceCodexCLI = "force_codex_cli_enabled" +) + +// CodexClientRestrictionDetectionResult 是 codex_cli_only 统一检测入口结果。 +type CodexClientRestrictionDetectionResult struct { + Enabled bool + Matched bool + Reason string +} + +// CodexClientRestrictionDetector 定义 codex_cli_only 统一检测入口。 +type CodexClientRestrictionDetector interface { + Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult +} + +// OpenAICodexClientRestrictionDetector 为 OpenAI OAuth codex_cli_only 的默认实现。 +type OpenAICodexClientRestrictionDetector struct { + cfg *config.Config +} + +func NewOpenAICodexClientRestrictionDetector(cfg *config.Config) *OpenAICodexClientRestrictionDetector { + return &OpenAICodexClientRestrictionDetector{cfg: cfg} +} + +func (d *OpenAICodexClientRestrictionDetector) Detect(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + if account == nil || !account.IsCodexCLIOnlyEnabled() { + return CodexClientRestrictionDetectionResult{ + Enabled: false, + Matched: false, + Reason: CodexClientRestrictionReasonDisabled, + } + } + + if d != nil && d.cfg != nil && d.cfg.Gateway.ForceCodexCLI { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonForceCodexCLI, + } + } + + userAgent := "" + 
originator := "" + if c != nil { + userAgent = c.GetHeader("User-Agent") + originator = c.GetHeader("originator") + } + if openai.IsCodexOfficialClientRequest(userAgent) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + } + } + if openai.IsCodexOfficialClientOriginator(originator) { + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedOriginator, + } + } + + return CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + } +} diff --git a/backend/internal/service/openai_client_restriction_detector_test.go b/backend/internal/service/openai_client_restriction_detector_test.go new file mode 100644 index 00000000..984b4ff6 --- /dev/null +++ b/backend/internal/service/openai_client_restriction_detector_test.go @@ -0,0 +1,124 @@ +package service + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func newCodexDetectorTestContext(ua string, originator string) *gin.Context { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + if ua != "" { + c.Request.Header.Set("User-Agent", ua) + } + if originator != "" { + c.Request.Header.Set("originator", originator) + } + return c +} + +func TestOpenAICodexClientRestrictionDetector_Detect(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("未开启开关时绕过", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", ""), account) + require.False(t, result.Enabled) + require.False(t, result.Matched) + 
require.Equal(t, CodexClientRestrictionReasonDisabled, result.Reason) + }) + + t.Run("开启后 codex_cli_rs 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_cli_rs/0.99.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_vscode 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_vscode/1.0.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 codex_app 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("codex_app/2.1.0", ""), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedUA, result.Reason) + }) + + t.Run("开启后 originator 命中", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "codex_chatgpt_desktop"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonMatchedOriginator, 
result.Reason) + }) + + t.Run("开启后非官方客户端拒绝", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(nil) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.False(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonNotMatchedUA, result.Reason) + }) + + t.Run("开启 ForceCodexCLI 时允许通过", func(t *testing.T) { + detector := NewOpenAICodexClientRestrictionDetector(&config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: true}, + }) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{"codex_cli_only": true}, + } + + result := detector.Detect(newCodexDetectorTestContext("curl/8.0", "my_client"), account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a57f0f99..16befb82 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -2,73 +2,66 @@ package service import ( _ "embed" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "path/filepath" "strings" - "time" -) - -const ( - opencodeCodexHeaderURL = "https://raw.githubusercontent.com/anomalyco/opencode/dev/packages/opencode/src/session/prompt/codex_header.txt" - codexCacheTTL = 15 * time.Minute ) //go:embed prompts/codex_cli_instructions.md var codexCLIInstructions string var codexModelMap = map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-none": "gpt-5.3", - "gpt-5.3-low": "gpt-5.3", - "gpt-5.3-medium": "gpt-5.3", - "gpt-5.3-high": "gpt-5.3", - "gpt-5.3-xhigh": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - 
"gpt-5.3-codex-low": "gpt-5.3-codex", - "gpt-5.3-codex-medium": "gpt-5.3-codex", - "gpt-5.3-codex-high": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt-5.1-codex": "gpt-5.1-codex", - "gpt-5.1-codex-low": "gpt-5.1-codex", - "gpt-5.1-codex-medium": "gpt-5.1-codex", - "gpt-5.1-codex-high": "gpt-5.1-codex", - "gpt-5.1-codex-max": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", - "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", - "gpt-5.2": "gpt-5.2", - "gpt-5.2-none": "gpt-5.2", - "gpt-5.2-low": "gpt-5.2", - "gpt-5.2-medium": "gpt-5.2", - "gpt-5.2-high": "gpt-5.2", - "gpt-5.2-xhigh": "gpt-5.2", - "gpt-5.2-codex": "gpt-5.2-codex", - "gpt-5.2-codex-low": "gpt-5.2-codex", - "gpt-5.2-codex-medium": "gpt-5.2-codex", - "gpt-5.2-codex-high": "gpt-5.2-codex", - "gpt-5.2-codex-xhigh": "gpt-5.2-codex", - "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5.1": "gpt-5.1", - "gpt-5.1-none": "gpt-5.1", - "gpt-5.1-low": "gpt-5.1", - "gpt-5.1-medium": "gpt-5.1", - "gpt-5.1-high": "gpt-5.1", - "gpt-5.1-chat-latest": "gpt-5.1", - "gpt-5-codex": "gpt-5.1-codex", - "codex-mini-latest": "gpt-5.1-codex-mini", - "gpt-5-codex-mini": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", - "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", - "gpt-5": "gpt-5.1", - "gpt-5-mini": "gpt-5.1", - "gpt-5-nano": "gpt-5.1", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-none": "gpt-5.3-codex", + "gpt-5.3-low": "gpt-5.3-codex", + "gpt-5.3-medium": "gpt-5.3-codex", + "gpt-5.3-high": "gpt-5.3-codex", + "gpt-5.3-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-low": "gpt-5.3-codex", + "gpt-5.3-codex-spark-medium": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + 
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-low": "gpt-5.3-codex", + "gpt-5.3-codex-medium": "gpt-5.3-codex", + "gpt-5.3-codex-high": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt-5.1-codex": "gpt-5.1-codex", + "gpt-5.1-codex-low": "gpt-5.1-codex", + "gpt-5.1-codex-medium": "gpt-5.1-codex", + "gpt-5.1-codex-high": "gpt-5.1-codex", + "gpt-5.1-codex-max": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-low": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-medium": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-high": "gpt-5.1-codex-max", + "gpt-5.1-codex-max-xhigh": "gpt-5.1-codex-max", + "gpt-5.2": "gpt-5.2", + "gpt-5.2-none": "gpt-5.2", + "gpt-5.2-low": "gpt-5.2", + "gpt-5.2-medium": "gpt-5.2", + "gpt-5.2-high": "gpt-5.2", + "gpt-5.2-xhigh": "gpt-5.2", + "gpt-5.2-codex": "gpt-5.2-codex", + "gpt-5.2-codex-low": "gpt-5.2-codex", + "gpt-5.2-codex-medium": "gpt-5.2-codex", + "gpt-5.2-codex-high": "gpt-5.2-codex", + "gpt-5.2-codex-xhigh": "gpt-5.2-codex", + "gpt-5.1-codex-mini": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5.1-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5.1": "gpt-5.1", + "gpt-5.1-none": "gpt-5.1", + "gpt-5.1-low": "gpt-5.1", + "gpt-5.1-medium": "gpt-5.1", + "gpt-5.1-high": "gpt-5.1", + "gpt-5.1-chat-latest": "gpt-5.1", + "gpt-5-codex": "gpt-5.1-codex", + "codex-mini-latest": "gpt-5.1-codex-mini", + "gpt-5-codex-mini": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-medium": "gpt-5.1-codex-mini", + "gpt-5-codex-mini-high": "gpt-5.1-codex-mini", + "gpt-5": "gpt-5.1", + "gpt-5-mini": "gpt-5.1", + "gpt-5-nano": "gpt-5.1", } type codexTransformResult struct { @@ -77,12 +70,6 @@ type codexTransformResult struct { PromptCacheKey string } -type opencodeCacheMetadata struct { - ETag string `json:"etag"` - LastFetch string `json:"lastFetch,omitempty"` - LastChecked int64 `json:"lastChecked"` -} - func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool) codexTransformResult { result := 
codexTransformResult{} // 工具续链需求会影响存储策略与 input 过滤逻辑。 @@ -177,7 +164,7 @@ func normalizeCodexModel(model string) string { return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.3") || strings.Contains(normalized, "gpt 5.3") { - return "gpt-5.3" + return "gpt-5.3-codex" } if strings.Contains(normalized, "gpt-5.1-codex-max") || strings.Contains(normalized, "gpt 5.1 codex max") { return "gpt-5.1-codex-max" @@ -222,54 +209,9 @@ func getNormalizedCodexModel(modelID string) string { return "" } -func getOpenCodeCachedPrompt(url, cacheFileName, metaFileName string) string { - cacheDir := codexCachePath("") - if cacheDir == "" { - return "" - } - cacheFile := filepath.Join(cacheDir, cacheFileName) - metaFile := filepath.Join(cacheDir, metaFileName) - - var cachedContent string - if content, ok := readFile(cacheFile); ok { - cachedContent = content - } - - var meta opencodeCacheMetadata - if loadJSON(metaFile, &meta) && meta.LastChecked > 0 && cachedContent != "" { - if time.Since(time.UnixMilli(meta.LastChecked)) < codexCacheTTL { - return cachedContent - } - } - - content, etag, status, err := fetchWithETag(url, meta.ETag) - if err == nil && status == http.StatusNotModified && cachedContent != "" { - return cachedContent - } - if err == nil && status >= 200 && status < 300 && content != "" { - _ = writeFile(cacheFile, content) - meta = opencodeCacheMetadata{ - ETag: etag, - LastFetch: time.Now().UTC().Format(time.RFC3339), - LastChecked: time.Now().UnixMilli(), - } - _ = writeJSON(metaFile, meta) - return content - } - - return cachedContent -} - func getOpenCodeCodexHeader() string { - // 优先从 opencode 仓库缓存获取指令。 - opencodeInstructions := getOpenCodeCachedPrompt(opencodeCodexHeaderURL, "opencode-codex-header.txt", "opencode-codex-header-meta.json") - - // 若 opencode 指令可用,直接返回。 - if opencodeInstructions != "" { - return opencodeInstructions - } - - // 否则回退使用本地 Codex CLI 指令。 + // 兼容保留:历史上这里会从 opencode 仓库拉取 codex_header.txt。 + // 现在我们与 Codex CLI 一致,直接使用仓库内置的 
instructions,避免读写缓存与外网依赖。 return getCodexCLIInstructions() } @@ -287,8 +229,8 @@ func GetCodexCLIInstructions() string { } // applyInstructions 处理 instructions 字段 -// isCodexCLI=true: 仅补充缺失的 instructions(使用 opencode 指令) -// isCodexCLI=false: 优先使用 opencode 指令覆盖 +// isCodexCLI=true: 仅补充缺失的 instructions(使用内置 Codex CLI 指令) +// isCodexCLI=false: 优先使用内置 Codex CLI 指令覆盖 func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { if isCodexCLI { return applyCodexCLIInstructions(reqBody) @@ -297,13 +239,13 @@ func applyInstructions(reqBody map[string]any, isCodexCLI bool) bool { } // applyCodexCLIInstructions 为 Codex CLI 请求补充缺失的 instructions -// 仅在 instructions 为空时添加 opencode 指令 +// 仅在 instructions 为空时添加内置 Codex CLI 指令(不依赖 opencode 缓存/回源) func applyCodexCLIInstructions(reqBody map[string]any) bool { if !isInstructionsEmpty(reqBody) { return false // 已有有效 instructions,不修改 } - instructions := strings.TrimSpace(getOpenCodeCodexHeader()) + instructions := strings.TrimSpace(getCodexCLIInstructions()) if instructions != "" { reqBody["instructions"] = instructions return true @@ -312,8 +254,8 @@ func applyCodexCLIInstructions(reqBody map[string]any) bool { return false } -// applyOpenCodeInstructions 为非 Codex CLI 请求应用 opencode 指令 -// 优先使用 opencode 指令覆盖 +// applyOpenCodeInstructions 为非 Codex CLI 请求应用内置 Codex CLI 指令(兼容历史函数名) +// 优先使用内置 Codex CLI 指令覆盖 func applyOpenCodeInstructions(reqBody map[string]any) bool { instructions := strings.TrimSpace(getOpenCodeCodexHeader()) existingInstructions, _ := reqBody["instructions"].(string) @@ -495,85 +437,3 @@ func normalizeCodexTools(reqBody map[string]any) bool { return modified } - -func codexCachePath(filename string) string { - home, err := os.UserHomeDir() - if err != nil { - return "" - } - cacheDir := filepath.Join(home, ".opencode", "cache") - if filename == "" { - return cacheDir - } - return filepath.Join(cacheDir, filename) -} - -func readFile(path string) (string, bool) { - if path == "" { - return "", false - } - data, 
err := os.ReadFile(path) - if err != nil { - return "", false - } - return string(data), true -} - -func writeFile(path, content string) error { - if path == "" { - return fmt.Errorf("empty cache path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - return os.WriteFile(path, []byte(content), 0o644) -} - -func loadJSON(path string, target any) bool { - data, err := os.ReadFile(path) - if err != nil { - return false - } - if err := json.Unmarshal(data, target); err != nil { - return false - } - return true -} - -func writeJSON(path string, value any) error { - if path == "" { - return fmt.Errorf("empty json path") - } - if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil { - return err - } - data, err := json.Marshal(value) - if err != nil { - return err - } - return os.WriteFile(path, data, 0o644) -} - -func fetchWithETag(url, etag string) (string, string, int, error) { - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - return "", "", 0, err - } - req.Header.Set("User-Agent", "sub2api-codex") - if etag != "" { - req.Header.Set("If-None-Match", etag) - } - resp, err := http.DefaultClient.Do(req) - if err != nil { - return "", "", 0, err - } - defer func() { - _ = resp.Body.Close() - }() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", "", resp.StatusCode, err - } - return string(body), resp.Header.Get("etag"), resp.StatusCode, nil -} diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index cc0acafc..27093f6c 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -1,18 +1,13 @@ package service import ( - "encoding/json" - "os" - "path/filepath" "testing" - "time" "github.com/stretchr/testify/require" ) func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { // 续链场景:保留 item_reference 与 id,但不再强制 
store=true。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.2", @@ -48,7 +43,6 @@ func TestApplyCodexOAuthTransform_ToolContinuationPreservesInput(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { // 续链场景:显式 store=false 不再强制为 true,保持 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -68,7 +62,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreFalsePreserved(t *testing.T) { func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { // 显式 store=true 也会强制为 false。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -88,7 +81,6 @@ func TestApplyCodexOAuthTransform_ExplicitStoreTrueForcedFalse(t *testing.T) { func TestApplyCodexOAuthTransform_NonContinuationDefaultsStoreFalseAndStripsIDs(t *testing.T) { // 非续链场景:未设置 store 时默认 false,并移除 input 中的 id。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -130,8 +122,6 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { } func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { - setupCodexCache(t) - reqBody := map[string]any{ "model": "gpt-5.1", "tools": []any{ @@ -162,7 +152,6 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -178,97 +167,39 @@ func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { func TestNormalizeCodexModel_Gpt53(t *testing.T) { cases := map[string]string{ - "gpt-5.3": "gpt-5.3", - "gpt-5.3-codex": "gpt-5.3-codex", - "gpt-5.3-codex-xhigh": "gpt-5.3-codex", - "gpt 5.3 codex": "gpt-5.3-codex", + "gpt-5.3": "gpt-5.3-codex", + "gpt-5.3-codex": "gpt-5.3-codex", + "gpt-5.3-codex-xhigh": "gpt-5.3-codex", + "gpt-5.3-codex-spark": "gpt-5.3-codex", + "gpt-5.3-codex-spark-high": "gpt-5.3-codex", + 
"gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex", + "gpt 5.3 codex": "gpt-5.3-codex", } for input, expected := range cases { require.Equal(t, expected, normalizeCodexModel(input)) } - } func TestApplyCodexOAuthTransform_CodexCLI_PreservesExistingInstructions(t *testing.T) { - // Codex CLI 场景:已有 instructions 时保持不变 - setupCodexCache(t) + // Codex CLI 场景:已有 instructions 时不修改 reqBody := map[string]any{ "model": "gpt-5.1", - "instructions": "user custom instructions", - "input": []any{}, + "instructions": "existing instructions", } - result := applyCodexOAuthTransform(reqBody, true) + result := applyCodexOAuthTransform(reqBody, true) // isCodexCLI=true instructions, ok := reqBody["instructions"].(string) require.True(t, ok) - require.Equal(t, "user custom instructions", instructions) - // instructions 未变,但其他字段(如 store、stream)可能被修改 - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_CodexCLI_AddsInstructionsWhenEmpty(t *testing.T) { - // Codex CLI 场景:无 instructions 时补充内置指令 - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, true) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.NotEmpty(t, instructions) - require.True(t, result.Modified) -} - -func TestApplyCodexOAuthTransform_NonCodexCLI_UsesOpenCodeInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令(缓存中有 header) - setupCodexCache(t) - - reqBody := map[string]any{ - "model": "gpt-5.1", - "input": []any{}, - } - - result := applyCodexOAuthTransform(reqBody, false) - - instructions, ok := reqBody["instructions"].(string) - require.True(t, ok) - require.Equal(t, "header", instructions) // setupCodexCache 设置的缓存内容 - require.True(t, result.Modified) -} - -func setupCodexCache(t *testing.T) { - t.Helper() - - // 使用临时 HOME 避免触发网络拉取 header。 - // Windows 使用 USERPROFILE,Unix 使用 HOME。 - tempDir := t.TempDir() - t.Setenv("HOME", tempDir) - t.Setenv("USERPROFILE", 
tempDir) - - cacheDir := filepath.Join(tempDir, ".opencode", "cache") - require.NoError(t, os.MkdirAll(cacheDir, 0o755)) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header.txt"), []byte("header"), 0o644)) - - meta := map[string]any{ - "etag": "", - "lastFetch": time.Now().UTC().Format(time.RFC3339), - "lastChecked": time.Now().UnixMilli(), - } - data, err := json.Marshal(meta) - require.NoError(t, err) - require.NoError(t, os.WriteFile(filepath.Join(cacheDir, "opencode-codex-header-meta.json"), data, 0o644)) + require.Equal(t, "existing instructions", instructions) + // Modified 仍可能为 true(因为其他字段被修改),但 instructions 应保持不变 + _ = result } func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T) { // Codex CLI 场景:无 instructions 时补充默认值 - setupCodexCache(t) reqBody := map[string]any{ "model": "gpt-5.1", @@ -284,8 +215,7 @@ func TestApplyCodexOAuthTransform_CodexCLI_SuppliesDefaultWhenEmpty(t *testing.T } func TestApplyCodexOAuthTransform_NonCodexCLI_OverridesInstructions(t *testing.T) { - // 非 Codex CLI 场景:使用 opencode 指令覆盖 - setupCodexCache(t) + // 非 Codex CLI 场景:使用内置 Codex CLI 指令覆盖 reqBody := map[string]any{ "model": "gpt-5.1", diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 6c4fe256..f26ce03f 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,9 +10,7 @@ import ( "errors" "fmt" "io" - "log" "net/http" - "regexp" "sort" "strconv" "strings" @@ -20,10 +18,14 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" ) const ( @@ -32,13 +34,15 @@ 
const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL + codexCLIUserAgent = "codex_cli_rs/0.98.0" + // codex_cli_only 拒绝时单个请求头日志长度上限(字符) + codexCLIOnlyHeaderValueMaxBytes = 256 + + // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 + OpenAIParsedRequestBodyKey = "openai_parsed_request_body" ) -// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) - -// OpenAI allowed headers whitelist (for non-OAuth accounts) +// OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ "accept-language": true, "content-type": true, @@ -48,6 +52,35 @@ var openaiAllowedHeaders = map[string]bool{ "session_id": true, } +// OpenAI passthrough allowed headers whitelist. +// 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 +var openaiPassthroughAllowedHeaders = map[string]bool{ + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, +} + +// codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) +var codexCLIOnlyDebugHeaderWhitelist = []string{ + "User-Agent", + "Content-Type", + "Accept", + "Accept-Language", + "OpenAI-Beta", + "Originator", + "Session_ID", + "Conversation_ID", + "X-Request-ID", + "X-Client-Request-ID", + "X-Forwarded-For", + "X-Real-IP", +} + // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers type OpenAICodexUsageSnapshot struct { PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"` @@ -175,6 +208,7 @@ type OpenAIGatewayService struct { userSubRepo UserSubscriptionRepository cache GatewayCache cfg *config.Config + codexDetector CodexClientRestrictionDetector schedulerSnapshot 
*SchedulerSnapshotService concurrencyService *ConcurrencyService billingService *BillingService @@ -210,6 +244,7 @@ func NewOpenAIGatewayService( userSubRepo: userSubRepo, cache: cache, cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), schedulerSnapshot: schedulerSnapshot, concurrencyService: concurrencyService, billingService: billingService, @@ -222,13 +257,228 @@ func NewOpenAIGatewayService( } } +func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { + if s != nil && s.codexDetector != nil { + return s.codexDetector + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAICodexClientRestrictionDetector(cfg) +} + +func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { + return s.getCodexClientRestrictionDetector().Detect(c, account) +} + +func getAPIKeyIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + v, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := v.(*APIKey) + if !ok || apiKey == nil { + return 0 + } + return apiKey.ID +} + +func logCodexCLIOnlyDetection(ctx context.Context, c *gin.Context, account *Account, apiKeyID int64, result CodexClientRestrictionDetectionResult, body []byte) { + if !result.Enabled { + return + } + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.Bool("codex_cli_only_enabled", result.Enabled), + zap.Bool("codex_official_client_match", result.Matched), + zap.String("reject_reason", result.Reason), + } + if apiKeyID > 0 { + fields = append(fields, zap.Int64("api_key_id", apiKeyID)) + } + if !result.Matched { + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, body) + } + log := logger.FromContext(ctx).With(fields...) 
+ if result.Matched { + return + } + log.Warn("OpenAI codex_cli_only 拒绝非官方客户端请求") +} + +func appendCodexCLIOnlyRejectedRequestFields(fields []zap.Field, c *gin.Context, body []byte) []zap.Field { + if c == nil || c.Request == nil { + return fields + } + + req := c.Request + requestModel, requestStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + fields = append(fields, + zap.String("request_method", strings.TrimSpace(req.Method)), + zap.String("request_path", strings.TrimSpace(req.URL.Path)), + zap.String("request_query", strings.TrimSpace(req.URL.RawQuery)), + zap.String("request_host", strings.TrimSpace(req.Host)), + zap.String("request_client_ip", strings.TrimSpace(c.ClientIP())), + zap.String("request_remote_addr", strings.TrimSpace(req.RemoteAddr)), + zap.String("request_user_agent", strings.TrimSpace(req.Header.Get("User-Agent"))), + zap.String("request_content_type", strings.TrimSpace(req.Header.Get("Content-Type"))), + zap.Int64("request_content_length", req.ContentLength), + zap.Bool("request_stream", requestStream), + ) + if requestModel != "" { + fields = append(fields, zap.String("request_model", requestModel)) + } + if promptCacheKey != "" { + fields = append(fields, zap.String("request_prompt_cache_key_sha256", hashSensitiveValueForLog(promptCacheKey))) + } + + if headers := snapshotCodexCLIOnlyHeaders(req.Header); len(headers) > 0 { + fields = append(fields, zap.Any("request_headers", headers)) + } + fields = append(fields, zap.Int("request_body_size", len(body))) + return fields +} + +func snapshotCodexCLIOnlyHeaders(header http.Header) map[string]string { + if len(header) == 0 { + return nil + } + result := make(map[string]string, len(codexCLIOnlyDebugHeaderWhitelist)) + for _, key := range codexCLIOnlyDebugHeaderWhitelist { + value := strings.TrimSpace(header.Get(key)) + if value == "" { + continue + } + result[strings.ToLower(key)] = truncateString(value, codexCLIOnlyHeaderValueMaxBytes) + } + return result +} + +func 
hashSensitiveValueForLog(raw string) string { + value := strings.TrimSpace(raw) + if value == "" { + return "" + } + sum := sha256.Sum256([]byte(value)) + return hex.EncodeToString(sum[:8]) +} + +func logOpenAIInstructionsRequiredDebug( + ctx context.Context, + c *gin.Context, + account *Account, + upstreamStatusCode int, + upstreamMsg string, + requestBody []byte, + upstreamBody []byte, +) { + msg := strings.TrimSpace(upstreamMsg) + if !isOpenAIInstructionsRequiredError(upstreamStatusCode, msg, upstreamBody) { + return + } + if ctx == nil { + ctx = context.Background() + } + + accountID := int64(0) + accountName := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + } + + userAgent := "" + if c != nil { + userAgent = strings.TrimSpace(c.GetHeader("User-Agent")) + } + + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.Int("upstream_status_code", upstreamStatusCode), + zap.String("upstream_error_message", msg), + zap.String("request_user_agent", userAgent), + zap.Bool("codex_official_client_match", openai.IsCodexCLIRequest(userAgent)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, requestBody) + + logger.FromContext(ctx).With(fields...).Warn("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查") +} + +func isOpenAIInstructionsRequiredError(upstreamStatusCode int, upstreamMsg string, upstreamBody []byte) bool { + if upstreamStatusCode != http.StatusBadRequest { + return false + } + + hasInstructionRequired := func(text string) bool { + lower := strings.ToLower(strings.TrimSpace(text)) + if lower == "" { + return false + } + if strings.Contains(lower, "instructions are required") { + return true + } + if strings.Contains(lower, "required parameter: 'instructions'") { + return true + } + if strings.Contains(lower, "required parameter: instructions") { + return true + } + if 
strings.Contains(lower, "missing required parameter") && strings.Contains(lower, "instructions") { + return true + } + return strings.Contains(lower, "instruction") && strings.Contains(lower, "required") + } + + if hasInstructionRequired(upstreamMsg) { + return true + } + if len(upstreamBody) == 0 { + return false + } + + errMsg := gjson.GetBytes(upstreamBody, "error.message").String() + errMsgLower := strings.ToLower(strings.TrimSpace(errMsg)) + errCode := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.code").String())) + errParam := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.param").String())) + errType := strings.ToLower(strings.TrimSpace(gjson.GetBytes(upstreamBody, "error.type").String())) + + if errParam == "instructions" { + return true + } + if hasInstructionRequired(errMsg) { + return true + } + if strings.Contains(errCode, "missing_required_parameter") && strings.Contains(errMsgLower, "instructions") { + return true + } + if strings.Contains(errType, "invalid_request") && strings.Contains(errMsgLower, "instructions") && strings.Contains(errMsgLower, "required") { + return true + } + + return false +} + // GenerateSessionHash generates a sticky-session hash for OpenAI requests. // // Priority: // 1. Header: session_id // 2. Header: conversation_id // 3. 
Body: prompt_cache_key (opencode) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string { if c == nil { return "" } @@ -237,10 +487,8 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s if sessionID == "" { sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) } - if sessionID == "" && reqBody != nil { - if v, ok := reqBody["prompt_cache_key"].(string); ok { - sessionID = strings.TrimSpace(v) - } + if sessionID == "" && len(body) > 0 { + sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) } if sessionID == "" { return "" @@ -744,30 +992,64 @@ func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, re func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*OpenAIForwardResult, error) { startTime := time.Now() - // Parse request body once (avoid multiple parse/serialize cycles) - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + restrictionResult := s.detectCodexClientRestriction(c, account) + apiKeyID := getAPIKeyIDFromContext(c) + logCodexCLIOnlyDetection(ctx, c, account, apiKeyID, restrictionResult, body) + if restrictionResult.Enabled && !restrictionResult.Matched { + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": "This account only allows Codex official clients", + }, + }) + return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed") } - // Extract model and stream from parsed body - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - promptCacheKey := "" - if v, ok := reqBody["prompt_cache_key"].(string); ok { - promptCacheKey = strings.TrimSpace(v) + originalBody := body + reqModel, 
reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + passthroughEnabled := account.IsOpenAIPassthroughEnabled() + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOpenAIPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) + } + + reqBody, err := getOpenAIRequestBodyMap(c, body) + if err != nil { + return nil, err + } + + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel + } + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v + } + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } } // Track if body needs re-serialization bodyModified := false - originalModel := reqModel - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 + if !isCodexCLI && isInstructionsEmpty(reqBody) { + if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { + reqBody["instructions"] = instructions + bodyModified = true + } + } // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) if mappedModel != reqModel { - log.Printf("[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true } @@ -776,7 +1058,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if model, ok := reqBody["model"].(string); ok { normalizedModel := 
normalizeCodexModel(model) if normalizedModel != "" && normalizedModel != model { - log.Printf("[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Codex model normalization: %s -> %s (account: %s, type: %s, isCodexCLI: %v)", model, normalizedModel, account.Name, account.Type, isCodexCLI) reqBody["model"] = normalizedModel mappedModel = normalizedModel @@ -789,7 +1071,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true - log.Printf("[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } @@ -880,12 +1162,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Capture upstream request body for ops retry of this attempt. - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } + setOpsUpstreamRequestBody(c, body) // Send request + upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). 
safeErr := sanitizeUpstreamErrorMessage(err.Error()) @@ -939,7 +1221,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco s.handleFailoverSideEffects(ctx, resp, account) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - return s.handleErrorResponse(ctx, resp, c, account) + return s.handleErrorResponse(ctx, resp, c, account, body) } // Handle normal response @@ -966,6 +1248,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } + if usage == nil { + usage = &OpenAIUsage{} + } + reasoningEffort := extractOpenAIReasoningEffort(reqBody, originalModel) return &OpenAIForwardResult{ @@ -979,6 +1265,579 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }, nil } +func (s *OpenAIGatewayService) forwardOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + reqModel string, + reasoningEffort *string, + reqStream bool, + startTime time.Time, +) (*OpenAIForwardResult, error) { + if account != nil && account.Type == AccountTypeOAuth { + if rejectReason := detectOpenAIPassthroughInstructionsRejectReason(reqModel, body); rejectReason != "" { + rejectMsg := "OpenAI codex passthrough requires a non-empty instructions field" + setOpsUpstreamError(c, http.StatusForbidden, rejectMsg, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusForbidden, + Passthrough: true, + Kind: "request_error", + Message: rejectMsg, + Detail: rejectReason, + }) + logOpenAIPassthroughInstructionsRejected(ctx, c, account, reqModel, rejectReason, body) + c.JSON(http.StatusForbidden, gin.H{ + "error": gin.H{ + "type": "forbidden_error", + "message": rejectMsg, + }, + }) + return nil, fmt.Errorf("openai passthrough rejected before upstream: %s", rejectReason) + } + + normalizedBody, normalized, err := 
normalizeOpenAIPassthroughOAuthBody(body) + if err != nil { + return nil, err + } + if normalized { + body = normalizedBody + reqStream = true + } + } + + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + account.ID, + account.Name, + account.Type, + reqModel, + reqStream, + ) + if reqStream && c != nil && c.Request != nil { + if timeoutHeaders := collectOpenAIPassthroughTimeoutHeaders(c.Request.Header); len(timeoutHeaders) > 0 { + streamWarnLogger := logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.Strings("timeout_headers", timeoutHeaders), + ) + if s.isOpenAIPassthroughTimeoutHeadersAllowed() { + streamWarnLogger.Warn("OpenAI passthrough 透传请求包含超时相关请求头,且当前配置为放行,可能导致上游提前断流") + } else { + streamWarnLogger.Warn("OpenAI passthrough 检测到超时相关请求头,将按配置过滤以降低断流风险") + } + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamReq, err := s.buildUpstreamRequestOpenAIPassthrough(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + setOpsUpstreamRequestBody(c, body) + if c != nil { + c.Set("openai_passthrough", true) + } + + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Passthrough: true, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + 
"type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + // 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。 + return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + result, err := s.handleStreamingResponsePassthrough(ctx, resp, c, account, startTime) + if err != nil { + return nil, err + } + usage = result.usage + firstTokenMs = result.firstTokenMs + } else { + usage, err = s.handleNonStreamingResponsePassthrough(ctx, resp, c) + if err != nil { + return nil, err + } + } + + if snapshot := ParseCodexRateLimitHeaders(resp.Header); snapshot != nil { + s.updateCodexUsageSnapshot(ctx, account.ID, snapshot) + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: reqModel, + ReasoningEffort: reasoningEffort, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func logOpenAIPassthroughInstructionsRejected( + ctx context.Context, + c *gin.Context, + account *Account, + reqModel string, + rejectReason string, + body []byte, +) { + if ctx == nil { + ctx = context.Background() + } + accountID := int64(0) + accountName := "" + accountType := "" + if account != nil { + accountID = account.ID + accountName = strings.TrimSpace(account.Name) + accountType = strings.TrimSpace(string(account.Type)) + } + fields := []zap.Field{ + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", accountID), + zap.String("account_name", accountName), + zap.String("account_type", accountType), + zap.String("request_model", strings.TrimSpace(reqModel)), + zap.String("reject_reason", strings.TrimSpace(rejectReason)), + } + fields = appendCodexCLIOnlyRejectedRequestFields(fields, c, 
body) + logger.FromContext(ctx).With(fields...).Warn("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions") +} + +func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + token string, +) (*http.Request, error) { + targetURL := openaiPlatformAPIURL + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL != "" { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + // 透传客户端请求头(安全白名单)。 + allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() + if c != nil && c.Request != nil { + for key, values := range c.Request.Header { + lower := strings.ToLower(strings.TrimSpace(key)) + if !isOpenAIPassthroughAllowedRequestHeader(lower, allowTimeoutHeaders) { + continue + } + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + // 覆盖入站鉴权残留,并注入上游认证 + req.Header.Del("authorization") + req.Header.Del("x-api-key") + req.Header.Del("x-goog-api-key") + req.Header.Set("authorization", "Bearer "+token) + + // OAuth 透传到 ChatGPT internal API 时补齐必要头。 + if account.Type == AccountTypeOAuth { + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + req.Host = "chatgpt.com" + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } + if req.Header.Get("accept") == "" { + req.Header.Set("accept", "text/event-stream") + } + if req.Header.Get("OpenAI-Beta") == "" { + req.Header.Set("OpenAI-Beta", "responses=experimental") + } + if req.Header.Get("originator") == "" { + req.Header.Set("originator", "codex_cli_rs") + } + if 
promptCacheKey != "" { + if req.Header.Get("conversation_id") == "" { + req.Header.Set("conversation_id", promptCacheKey) + } + if req.Header.Get("session_id") == "" { + req.Header.Set("session_id", promptCacheKey) + } + } + } + + // 透传模式也支持账户自定义 User-Agent 与 ForceCodexCLI 兜底。 + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // OAuth 安全透传:对非 Codex UA 统一兜底,降低被上游风控拦截概率。 + if account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(req.Header.Get("user-agent")) { + req.Header.Set("user-agent", codexCLIUserAgent) + } + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleErrorResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) error { + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(body), maxBytes) + } + setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Passthrough: true, + Kind: "http_error", + Message: upstreamMsg, + Detail: upstreamDetail, + UpstreamResponseBody: upstreamDetail, + }) + + 
writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + + if upstreamMsg == "" { + return fmt.Errorf("upstream error: %d", resp.StatusCode) + } + return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) +} + +func isOpenAIPassthroughAllowedRequestHeader(lowerKey string, allowTimeoutHeaders bool) bool { + if lowerKey == "" { + return false + } + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + return allowTimeoutHeaders + } + return openaiPassthroughAllowedHeaders[lowerKey] +} + +func isOpenAIPassthroughTimeoutHeader(lowerKey string) bool { + switch lowerKey { + case "x-stainless-timeout", "x-stainless-read-timeout", "x-stainless-connect-timeout", "x-request-timeout", "request-timeout", "grpc-timeout": + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) isOpenAIPassthroughTimeoutHeadersAllowed() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIPassthroughAllowTimeoutHeaders +} + +func collectOpenAIPassthroughTimeoutHeaders(h http.Header) []string { + if h == nil { + return nil + } + var matched []string + for key, values := range h { + lowerKey := strings.ToLower(strings.TrimSpace(key)) + if isOpenAIPassthroughTimeoutHeader(lowerKey) { + entry := lowerKey + if len(values) > 0 { + entry = fmt.Sprintf("%s=%s", lowerKey, strings.Join(values, "|")) + } + matched = append(matched, entry) + } + } + sort.Strings(matched) + return matched +} + +type openaiStreamingResultPassthrough struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + startTime time.Time, +) (*openaiStreamingResultPassthrough, error) { + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, 
s.cfg) + + // SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if v := resp.Header.Get("x-request-id"); v != "" { + c.Header("x-request-id", v) + } + + w := c.Writer + flusher, ok := w.(http.Flusher) + if !ok { + return nil, errors.New("streaming not supported") + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + clientDisconnected := false + sawDone := false + upstreamRequestID := strings.TrimSpace(resp.Header.Get("x-request-id")) + + scanner := bufio.NewScanner(resp.Body) + maxLineSize := defaultMaxLineSize + if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { + maxLineSize = s.cfg.Gateway.MaxLineSize + } + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) + defer putSSEScannerBuf64K(scanBuf) + + for scanner.Scan() { + line := scanner.Text() + if data, ok := extractOpenAISSEDataLine(line); ok { + trimmedData := strings.TrimSpace(data) + if trimmedData == "[DONE]" { + sawDone = true + } + if firstTokenMs == nil && trimmedData != "" && trimmedData != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsage(data, usage) + } + + if !clientDisconnected { + if _, err := fmt.Fprintln(w, line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Client disconnected during streaming, continue draining upstream for usage: account=%d", account.ID) + } else { + flusher.Flush() + } + } + } + if err := scanner.Err(); err != nil { + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] Upstream read error after client disconnect: account=%d err=%v", account.ID, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + 
logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取被取消,可能发生断流: account=%d request_id=%s err=%v ctx_err=%v", + account.ID, + upstreamRequestID, + err, + ctx.Err(), + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if errors.Is(err, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI passthrough] SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, err) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, err + } + logger.LegacyPrintf("service.openai_gateway", + "[OpenAI passthrough] 流读取异常中断: account=%d request_id=%s err=%v", + account.ID, + upstreamRequestID, + err, + ) + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", err) + } + if !clientDisconnected && !sawDone && ctx.Err() == nil { + logger.FromContext(ctx).With( + zap.String("component", "service.openai_gateway"), + zap.Int64("account_id", account.ID), + zap.String("upstream_request_id", upstreamRequestID), + ).Info("OpenAI passthrough 上游流在未收到 [DONE] 时结束,疑似断流") + } + + return &openaiStreamingResultPassthrough{usage: usage, firstTokenMs: firstTokenMs}, nil +} + +func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( + ctx context.Context, + resp *http.Response, + c *gin.Context, +) (*OpenAIUsage, error) { + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) + if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } + return nil, err + } + + usage := &OpenAIUsage{} + usageParsed := false + if len(body) > 0 { + var response struct { + Usage struct { + InputTokens int 
`json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if json.Unmarshal(body, &response) == nil { + usage.InputTokens = response.Usage.InputTokens + usage.OutputTokens = response.Usage.OutputTokens + usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + usageParsed = true + } + } + if !usageParsed { + // 兜底:尝试从 SSE 文本中解析 usage + usage = s.parseSSEUsageFromBody(string(body)) + } + + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/json" + } + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { + if dst == nil || src == nil { + return + } + if cfg != nil { + responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + } else { + // 兜底:尽量保留最基础的 content-type + if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { + dst.Set("Content-Type", v) + } + } + // 透传模式强制放行 x-codex-* 响应头(若上游返回)。 + // 注意:真实 http.Response.Header 的 key 一般会被 canonicalize;但为了兼容测试/自建响应, + // 这里用 EqualFold 做一次大小写不敏感的查找。 + getCaseInsensitiveValues := func(h http.Header, want string) []string { + if h == nil { + return nil + } + for k, vals := range h { + if strings.EqualFold(k, want) { + return vals + } + } + return nil + } + + for _, rawKey := range []string{ + "x-codex-primary-used-percent", + "x-codex-primary-reset-after-seconds", + "x-codex-primary-window-minutes", + "x-codex-secondary-used-percent", + "x-codex-secondary-reset-after-seconds", + "x-codex-secondary-window-minutes", + "x-codex-primary-over-secondary-limit-percent", + } { + vals := getCaseInsensitiveValues(src, rawKey) + if len(vals) == 0 { + continue + } + key := http.CanonicalHeaderKey(rawKey) + 
dst.Del(key) + for _, v := range vals { + dst.Add(key, v) + } + } +} + func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string, isStream bool, promptCacheKey string, isCodexCLI bool) (*http.Request, error) { // Determine target URL based on account type var targetURL string @@ -996,7 +1855,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } - targetURL = validatedURL + "/responses" + targetURL = buildOpenAIResponsesURL(validatedURL) } default: targetURL = openaiPlatformAPIURL @@ -1050,6 +1909,12 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. req.Header.Set("user-agent", customUA) } + // 若开启 ForceCodexCLI,则强制将上游 User-Agent 伪装为 Codex CLI。 + // 用于网关未透传/改写 User-Agent 时,仍能命中 Codex 侧识别逻辑。 + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + req.Header.Set("user-agent", codexCLIUserAgent) + } + // Ensure required headers exist if req.Header.Get("content-type") == "" { req.Header.Set("content-type", "application/json") @@ -1058,7 +1923,13 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. 
return req, nil } -func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*OpenAIForwardResult, error) { +func (s *OpenAIGatewayService) handleErrorResponse( + ctx context.Context, + resp *http.Response, + c *gin.Context, + account *Account, + requestBody []byte, +) (*OpenAIForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) @@ -1072,9 +1943,10 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht upstreamDetail = truncateString(string(body), maxBytes) } setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail) + logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body) if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { - log.Printf( + logger.LegacyPrintf("service.openai_gateway", "OpenAI upstream error %d (account=%d platform=%s type=%s): %s", resp.StatusCode, account.ID, @@ -1230,7 +2102,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 { maxLineSize = s.cfg.Gateway.MaxLineSize } - scanner.Buffer(make([]byte, 64*1024), maxLineSize) + scanBuf := getSSEScannerBuf64K() + scanner.Buffer(scanBuf[:0], maxLineSize) type scanEvent struct { line string @@ -1249,7 +2122,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } var lastReadAt int64 atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func() { + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) defer close(events) for scanner.Scan() { atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) @@ -1260,7 +2134,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if err := scanner.Err(); err != nil { _ = sendEvent(scanEvent{err: err}) } - }() + }(scanBuf) defer close(done) streamInterval 
:= time.Duration(0) @@ -1332,16 +2206,16 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - log.Printf("Context canceled during streaming, returning collected usage") + logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage if clientDisconnected { - log.Printf("Upstream read error after client disconnect: %v, returning collected usage", ev.err) + logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } if errors.Is(ev.err, bufio.ErrTooLong) { - log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) sendErrorEvent("response_too_large") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err } @@ -1353,8 +2227,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp lastDataAt = time.Now() // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") + if data, ok := extractOpenAISSEDataLine(line); ok { // Replace model in response if needed if needModelReplace { @@ -1371,7 +2244,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !clientDisconnected { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { clientDisconnected = true - 
log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") } else { flusher.Flush() } @@ -1388,7 +2261,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !clientDisconnected { if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") } else { flusher.Flush() } @@ -1401,10 +2274,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp continue } if clientDisconnected { - log.Printf("Upstream timeout after client disconnect, returning collected usage") + logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil } - log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 if s.rateLimitService != nil { s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) @@ -1421,7 +2294,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } if _, err := fmt.Fprint(w, ":\n\n"); err != nil { clientDisconnected = true - log.Printf("Client disconnected during streaming, continuing to drain upstream for billing") + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") continue } flusher.Flush() @@ -1430,40 +2303,47 @@ func 
(s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } +// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。 +// 兼容 `data: xxx` 与 `data:xxx` 两种格式。 +func extractOpenAISSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != ' ' { + break + } + start++ + } + return line[start:], true +} + func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { return line } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { return line } - var event map[string]any - if err := json.Unmarshal([]byte(data), &event); err != nil { - return line - } - - // Replace model in response - if m, ok := event["model"].(string); ok && m == fromModel { - event["model"] = toModel - newData, err := json.Marshal(event) + // 使用 gjson 精确检查 model 字段,避免全量 JSON 反序列化 + if m := gjson.Get(data, "model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "model", toModel) if err != nil { return line } - return "data: " + string(newData) + return "data: " + newData } - // Check nested response - if response, ok := event["response"].(map[string]any); ok { - if m, ok := response["model"].(string); ok && m == fromModel { - response["model"] = toModel - newData, err := json.Marshal(event) - if err != nil { - return line - } - return "data: " + string(newData) + // 检查嵌套的 response.model 字段 + if m := gjson.Get(data, "response.model"); m.Exists() && m.Str == fromModel { + newData, err := sjson.Set(data, "response.model", toModel) + if err != nil { + return line } + return "data: " + newData } return line @@ -1484,30 +2364,35 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt } func (s *OpenAIGatewayService) parseSSEUsage(data string, 
usage *OpenAIUsage) { - // Parse response.completed event for usage (OpenAI Responses format) - var event struct { - Type string `json:"type"` - Response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } `json:"response"` + if usage == nil || data == "" || data == "[DONE]" { + return + } + // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 + if !strings.Contains(data, `"response.completed"`) { + return + } + if gjson.Get(data, "type").String() != "response.completed" { + return } - if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { - usage.InputTokens = event.Response.Usage.InputTokens - usage.OutputTokens = event.Response.Usage.OutputTokens - usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens - } + usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int()) } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { - body, err := io.ReadAll(resp.Body) + maxBytes := resolveUpstreamResponseReadLimit(s.cfg) + body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes) if err != nil { + if errors.Is(err, ErrUpstreamResponseBodyTooLarge) { + setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "") + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream response too large", + }, + }) + } return nil, err } @@ -1613,10 +2498,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. 
func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } @@ -1640,10 +2525,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } @@ -1655,7 +2540,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { lines := strings.Split(body, "\n") for i, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + if _, ok := extractOpenAISSEDataLine(line); !ok { continue } lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) @@ -1682,24 +2567,31 @@ func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, erro return normalized, nil } +// buildOpenAIResponsesURL 组装 OpenAI Responses 端点。 +// - base 以 /v1 结尾:追加 /responses +// - base 已是 /responses:原样返回 +// - 其他情况:追加 /v1/responses +func buildOpenAIResponsesURL(base string) string { + normalized := strings.TrimRight(strings.TrimSpace(base), "/") + if strings.HasSuffix(normalized, "/responses") { + return normalized + } + if strings.HasSuffix(normalized, "/v1") { + return normalized + "/responses" + } + return normalized + "/v1/responses" +} + func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel, toModel string) []byte { - var resp map[string]any - if err := json.Unmarshal(body, &resp); err != nil { - return body + // 使用 gjson/sjson 精确替换 model 字段,避免全量 JSON 
反序列化 + if m := gjson.GetBytes(body, "model"); m.Exists() && m.Str == fromModel { + newBody, err := sjson.SetBytes(body, "model", toModel) + if err != nil { + return body + } + return newBody } - - model, ok := resp["model"].(string) - if !ok || model != fromModel { - return body - } - - resp["model"] = toModel - newBody, err := json.Marshal(resp) - if err != nil { - return body - } - - return newBody + return body } // OpenAIRecordUsageInput input for recording usage @@ -1803,7 +2695,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec inserted, err := s.usageLogRepo.Create(ctx, usageLog) if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { - log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } @@ -1826,7 +2718,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Update API key quota if applicable (only for balance mode with quota set) if shouldBill && cost.ActualCost > 0 && apiKey.Quota > 0 && input.APIKeyService != nil { if err := input.APIKeyService.UpdateQuotaUsed(ctx, apiKey.ID, cost.ActualCost); err != nil { - log.Printf("Update API key quota failed: %v", err) + logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) } } @@ -1904,16 +2796,41 @@ func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { return snapshot } -// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field -func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { +func codexSnapshotBaseTime(snapshot *OpenAICodexUsageSnapshot, fallback time.Time) time.Time { if snapshot == nil { - 
return + return fallback + } + if snapshot.UpdatedAt == "" { + return fallback + } + base, err := time.Parse(time.RFC3339, snapshot.UpdatedAt) + if err != nil { + return fallback + } + return base +} + +func codexResetAtRFC3339(base time.Time, resetAfterSeconds *int) *string { + if resetAfterSeconds == nil { + return nil + } + sec := *resetAfterSeconds + if sec < 0 { + sec = 0 + } + resetAt := base.Add(time.Duration(sec) * time.Second).Format(time.RFC3339) + return &resetAt +} + +func buildCodexUsageExtraUpdates(snapshot *OpenAICodexUsageSnapshot, fallbackNow time.Time) map[string]any { + if snapshot == nil { + return nil } - // Convert snapshot to map for merging into Extra + baseTime := codexSnapshotBaseTime(snapshot, fallbackNow) updates := make(map[string]any) - // Save raw primary/secondary fields for debugging/tracing + // 保存原始 primary/secondary 字段,便于排查问题 if snapshot.PrimaryUsedPercent != nil { updates["codex_primary_used_percent"] = *snapshot.PrimaryUsedPercent } @@ -1935,9 +2852,9 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if snapshot.PrimaryOverSecondaryPercent != nil { updates["codex_primary_over_secondary_percent"] = *snapshot.PrimaryOverSecondaryPercent } - updates["codex_usage_updated_at"] = snapshot.UpdatedAt + updates["codex_usage_updated_at"] = baseTime.Format(time.RFC3339) - // Normalize to canonical 5h/7d fields + // 归一化到 5h/7d 规范字段 if normalized := snapshot.Normalize(); normalized != nil { if normalized.Used5hPercent != nil { updates["codex_5h_used_percent"] = *normalized.Used5hPercent @@ -1957,6 +2874,29 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if normalized.Window7dMinutes != nil { updates["codex_7d_window_minutes"] = *normalized.Window7dMinutes } + if reset5hAt := codexResetAtRFC3339(baseTime, normalized.Reset5hSeconds); reset5hAt != nil { + updates["codex_5h_reset_at"] = *reset5hAt + } + if reset7dAt := codexResetAtRFC3339(baseTime, normalized.Reset7dSeconds); 
reset7dAt != nil { + updates["codex_7d_reset_at"] = *reset7dAt + } + } + + return updates +} + +// updateCodexUsageSnapshot saves the Codex usage snapshot to account's Extra field +func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, accountID int64, snapshot *OpenAICodexUsageSnapshot) { + if snapshot == nil { + return + } + if s == nil || s.accountRepo == nil { + return + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Now()) + if len(updates) == 0 { + return } // Update account's Extra field asynchronously @@ -2013,6 +2953,103 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { + if len(body) == 0 { + return "", false, "" + } + + model = strings.TrimSpace(gjson.GetBytes(body, "model").String()) + stream = gjson.GetBytes(body, "stream").Bool() + promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + return model, stream, promptCacheKey +} + +// normalizeOpenAIPassthroughOAuthBody 将透传 OAuth 请求体收敛为旧链路关键行为: +// 1) store=false 2) stream=true +func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) { + if len(body) == 0 { + return body, false, nil + } + + normalized := body + changed := false + + if store := gjson.GetBytes(normalized, "store"); !store.Exists() || store.Type != gjson.False { + next, err := sjson.SetBytes(normalized, "store", false) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body store=false: %w", err) + } + normalized = next + changed = true + } + + if stream := gjson.GetBytes(normalized, "stream"); !stream.Exists() || stream.Type != gjson.True { + next, err := sjson.SetBytes(normalized, "stream", true) + if err != nil { + return body, false, fmt.Errorf("normalize passthrough body stream=true: %w", err) + } + normalized = next + changed = true + } + + 
return normalized, changed, nil +} + +func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string { + model := strings.ToLower(strings.TrimSpace(reqModel)) + if !strings.Contains(model, "codex") { + return "" + } + + instructions := gjson.GetBytes(body, "instructions") + if !instructions.Exists() { + return "instructions_missing" + } + if instructions.Type != gjson.String { + return "instructions_not_string" + } + if strings.TrimSpace(instructions.String()) == "" { + return "instructions_empty" + } + return "" +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { + if c != nil { + if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { + if reqBody, ok := cached.(map[string]any); ok && reqBody != nil { + return reqBody, nil + } + } + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + return reqBody, nil +} + func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { diff --git a/backend/internal/service/openai_gateway_service_codex_cli_only_test.go b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go new file mode 100644 index 00000000..d7c95ada --- /dev/null +++ 
b/backend/internal/service/openai_gateway_service_codex_cli_only_test.go @@ -0,0 +1,266 @@ +package service + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type stubCodexRestrictionDetector struct { + result CodexClientRestrictionDetectionResult +} + +func (s *stubCodexRestrictionDetector) Detect(_ *gin.Context, _ *Account) CodexClientRestrictionDetectionResult { + return s.result +} + +func TestOpenAIGatewayService_GetCodexClientRestrictionDetector(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("使用注入的 detector", func(t *testing.T) { + expected := &stubCodexRestrictionDetector{ + result: CodexClientRestrictionDetectionResult{Enabled: true, Matched: true, Reason: "stub"}, + } + svc := &OpenAIGatewayService{codexDetector: expected} + + got := svc.getCodexClientRestrictionDetector() + require.Same(t, expected, got) + }) + + t.Run("service 为 nil 时返回默认 detector", func(t *testing.T) { + var svc *OpenAIGatewayService + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + }) + + t.Run("service 未注入 detector 时返回默认 detector", func(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: true}}} + got := svc.getCodexClientRestrictionDetector() + require.NotNil(t, got) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.Request.Header.Set("User-Agent", "curl/8.0") + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{"codex_cli_only": true}} + + result := got.Detect(c, account) + require.True(t, result.Enabled) + require.True(t, result.Matched) + require.Equal(t, CodexClientRestrictionReasonForceCodexCLI, result.Reason) + }) +} + +func TestGetAPIKeyIDFromContext(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + t.Run("context 为 nil", func(t *testing.T) { + require.Equal(t, int64(0), getAPIKeyIDFromContext(nil)) + }) + + t.Run("上下文没有 api_key", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 类型错误", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", "not-api-key") + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("api_key 指针为空", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + var k *APIKey + c.Set("api_key", k) + require.Equal(t, int64(0), getAPIKeyIDFromContext(c)) + }) + + t.Run("正常读取 api_key_id", func(t *testing.T) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Set("api_key", &APIKey{ID: 12345}) + require.Equal(t, int64(12345), getAPIKeyIDFromContext(c)) + }) +} + +func TestLogCodexCLIOnlyDetection_NilSafety(t *testing.T) { + // 不校验日志内容,仅保证在 nil 入参下不会 panic。 + require.NotPanics(t, func() { + logCodexCLIOnlyDetection(context.TODO(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: true, Matched: false, Reason: "test"}, nil) + logCodexCLIOnlyDetection(context.Background(), nil, nil, 0, CodexClientRestrictionDetectionResult{Enabled: false, Matched: false, Reason: "disabled"}, nil) + }) +} + +func TestLogCodexCLIOnlyDetection_OnlyLogsRejected(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: true, + Reason: CodexClientRestrictionReasonMatchedUA, + }, nil) + logCodexCLIOnlyDetection(context.Background(), nil, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, nil) + + 
require.False(t, logSink.ContainsMessage("OpenAI codex_cli_only 允许官方客户端请求")) + require.True(t, logSink.ContainsMessage("OpenAI codex_cli_only 拒绝非官方客户端请求")) +} + +func TestLogCodexCLIOnlyDetection_RejectedIncludesRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.2","stream":false,"prompt_cache_key":"pc-123","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001} + logCodexCLIOnlyDetection(context.Background(), c, account, 2002, CodexClientRestrictionDetectionResult{ + Enabled: true, + Matched: false, + Reason: CodexClientRestrictionReasonNotMatchedUA, + }, body) + + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.2")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("request_prompt_cache_key_sha256", hashSensitiveValueForLog("pc-123"))) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_LogsRequestDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"prompt_cache_key":"pc-abc","access_token":"secret-token","input":[{"type":"text","text":"hello"}]}`) + account := &Account{ID: 1001, Name: "codex max套餐"} + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + account, + http.StatusBadRequest, + "Instructions are required", + body, + []byte(`{"error":{"message":"Instructions are required","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`), + ) + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "curl/8.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_query", "trace=1")) + require.True(t, logSink.ContainsFieldValue("account_name", "codex max套餐")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} + +func TestLogOpenAIInstructionsRequiredDebug_NonTargetErrorSkipped(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + body := []byte(`{"model":"gpt-5.1-codex","stream":false}`) + + logOpenAIInstructionsRequiredDebug( + context.Background(), + c, + &Account{ID: 1001}, + http.StatusForbidden, + "forbidden", + body, + 
[]byte(`{"error":{"message":"forbidden"}}`), + ) + + require.False(t, logSink.ContainsMessage("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查")) +} + +func TestOpenAIGatewayService_Forward_LogsInstructionsRequiredDetails(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "assistants=v2") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{ + "Content-Type": []string{"application/json"}, + "x-request-id": []string{"rid-upstream"}, + }, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Missing required parameter: 'instructions'","type":"invalid_request_error","param":"instructions","code":"missing_required_parameter"}}`)), + }, + } + svc := &OpenAIGatewayService{ + cfg: &config.Config{ + Gateway: config.GatewayConfig{ForceCodexCLI: false}, + }, + httpUpstream: upstream, + } + account := &Account{ + ID: 1001, + Name: "codex max套餐", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-test"}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + body := []byte(`{"model":"gpt-5.1-codex","stream":false,"input":[{"type":"text","text":"hello"}],"prompt_cache_key":"pc-forward","access_token":"secret-token"}`) + + _, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Equal(t, http.StatusBadGateway, rec.Code) + require.Contains(t, err.Error(), "upstream error: 400") + + require.True(t, logSink.ContainsMessageAtLevel("OpenAI 上游返回 Instructions are required,已记录请求详情用于排查", "warn")) + 
require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.1.0")) + require.True(t, logSink.ContainsFieldValue("request_model", "gpt-5.1-codex")) + require.True(t, logSink.ContainsFieldValue("request_headers", "openai-beta")) + require.True(t, logSink.ContainsField("request_body_size")) + require.False(t, logSink.ContainsField("request_body_preview")) +} diff --git a/backend/internal/service/openai_gateway_service_codex_snapshot_test.go b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go new file mode 100644 index 00000000..654dd4ca --- /dev/null +++ b/backend/internal/service/openai_gateway_service_codex_snapshot_test.go @@ -0,0 +1,192 @@ +package service + +import ( + "testing" + "time" +) + +func TestCodexSnapshotBaseTime(t *testing.T) { + fallback := time.Date(2026, 2, 20, 9, 0, 0, 0, time.UTC) + + t.Run("nil snapshot uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(nil, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("empty updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) + + t.Run("valid updatedAt wins", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "2026-02-16T10:00:00Z"}, fallback) + want := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + if !got.Equal(want) { + t.Fatalf("got %v, want %v", got, want) + } + }) + + t.Run("invalid updatedAt uses fallback", func(t *testing.T) { + got := codexSnapshotBaseTime(&OpenAICodexUsageSnapshot{UpdatedAt: "invalid"}, fallback) + if !got.Equal(fallback) { + t.Fatalf("got %v, want fallback %v", got, fallback) + } + }) +} + +func TestCodexResetAtRFC3339(t *testing.T) { + base := time.Date(2026, 2, 16, 10, 0, 0, 0, time.UTC) + + t.Run("nil reset returns nil", func(t *testing.T) { + if got := 
codexResetAtRFC3339(base, nil); got != nil { + t.Fatalf("expected nil, got %v", *got) + } + }) + + t.Run("positive seconds", func(t *testing.T) { + sec := 90 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:01:30Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:01:30Z") + } + }) + + t.Run("negative seconds clamp to base", func(t *testing.T) { + sec := -3 + got := codexResetAtRFC3339(base, &sec) + if got == nil { + t.Fatal("expected non-nil") + } + if *got != "2026-02-16T10:00:00Z" { + t.Fatalf("got %s, want %s", *got, "2026-02-16T10:00:00Z") + } + }) +} + +func TestBuildCodexUsageExtraUpdates_UsesSnapshotUpdatedAt(t *testing.T) { + primaryUsed := 88.0 + primaryReset := 86400 + primaryWindow := 10080 + secondaryUsed := 12.0 + secondaryReset := 3600 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Date(2026, 2, 20, 8, 0, 0, 0, time.UTC)) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T11:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T11:00:00Z") + } + if got := updates["codex_7d_reset_at"]; got != "2026-02-17T10:00:00Z" { + t.Fatalf("codex_7d_reset_at = %v, want %s", got, "2026-02-17T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_FallbackToNowWhenUpdatedAtInvalid(t *testing.T) { + primaryUsed := 15.0 + primaryReset := 30 + primaryWindow := 300 + + 
fallbackNow := time.Date(2026, 2, 20, 8, 30, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + UpdatedAt: "invalid-time", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T08:30:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T08:30:00Z") + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-20T08:30:30Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-20T08:30:30Z") + } +} + +func TestBuildCodexUsageExtraUpdates_ClampNegativeResetSeconds(t *testing.T) { + primaryUsed := 90.0 + primaryReset := 7200 + primaryWindow := 10080 + secondaryUsed := 100.0 + secondaryReset := -15 + secondaryWindow := 300 + + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + PrimaryResetAfterSeconds: &primaryReset, + PrimaryWindowMinutes: &primaryWindow, + SecondaryUsedPercent: &secondaryUsed, + SecondaryResetAfterSeconds: &secondaryReset, + SecondaryWindowMinutes: &secondaryWindow, + UpdatedAt: "2026-02-16T10:00:00Z", + } + + updates := buildCodexUsageExtraUpdates(snapshot, time.Time{}) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_5h_reset_after_seconds"]; got != -15 { + t.Fatalf("codex_5h_reset_after_seconds = %v, want %d", got, -15) + } + if got := updates["codex_5h_reset_at"]; got != "2026-02-16T10:00:00Z" { + t.Fatalf("codex_5h_reset_at = %v, want %s", got, "2026-02-16T10:00:00Z") + } +} + +func TestBuildCodexUsageExtraUpdates_NilSnapshot(t *testing.T) { + if got := buildCodexUsageExtraUpdates(nil, time.Now()); got != nil { + t.Fatalf("expected nil updates, got %v", got) + } +} + +func TestBuildCodexUsageExtraUpdates_WithoutNormalizedWindowFields(t *testing.T) { + primaryUsed := 42.0 + 
fallbackNow := time.Date(2026, 2, 20, 9, 15, 0, 0, time.UTC) + snapshot := &OpenAICodexUsageSnapshot{ + PrimaryUsedPercent: &primaryUsed, + UpdatedAt: "", + } + + updates := buildCodexUsageExtraUpdates(snapshot, fallbackNow) + if updates == nil { + t.Fatal("expected non-nil updates") + } + + if got := updates["codex_usage_updated_at"]; got != "2026-02-20T09:15:00Z" { + t.Fatalf("codex_usage_updated_at = %v, want %s", got, "2026-02-20T09:15:00Z") + } + if _, ok := updates["codex_5h_reset_at"]; ok { + t.Fatalf("did not expect codex_5h_reset_at in updates: %v", updates["codex_5h_reset_at"]) + } + if _, ok := updates["codex_7d_reset_at"]; ok { + t.Fatalf("did not expect codex_7d_reset_at in updates: %v", updates["codex_7d_reset_at"]) + } +} diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go new file mode 100644 index 00000000..6b11831f --- /dev/null +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -0,0 +1,125 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractOpenAIRequestMetaFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + wantModel string + wantStream bool + wantPromptKey string + }{ + { + name: "完整字段", + body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`), + wantModel: "gpt-5", + wantStream: true, + wantPromptKey: "ses-1", + }, + { + name: "缺失可选字段", + body: []byte(`{"model":"gpt-4"}`), + wantModel: "gpt-4", + wantStream: false, + wantPromptKey: "", + }, + { + name: "空请求体", + body: nil, + wantModel: "", + wantStream: false, + wantPromptKey: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body) + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + require.Equal(t, 
tt.wantPromptKey, promptKey) + }) + } +} + +func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + model string + wantNil bool + wantValue string + }{ + { + name: "优先读取 reasoning.effort", + body: []byte(`{"reasoning":{"effort":"medium"}}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "medium", + }, + { + name: "兼容 reasoning_effort", + body: []byte(`{"reasoning_effort":"x-high"}`), + model: "", + wantNil: false, + wantValue: "xhigh", + }, + { + name: "minimal 归一化为空", + body: []byte(`{"reasoning":{"effort":"minimal"}}`), + model: "gpt-5-high", + wantNil: true, + }, + { + name: "缺失字段时从模型后缀推导", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "high", + }, + { + name: "未知后缀不返回", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-unknown", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model) + if tt.wantNil { + require.Nil(t, got) + return + } + require.NotNil(t, got) + require.Equal(t, tt.wantValue, *got) + }) + } +} + +func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + cached := map[string]any{"model": "cached-model", "stream": true} + c.Set(OpenAIParsedRequestBodyKey, cached) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`)) + require.NoError(t, err) + require.Equal(t, cached, got) +} + +func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { + _, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`)) + require.Error(t, err) + require.Contains(t, err.Error(), "parse request") +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index ae69a986..226648e4 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ 
b/backend/internal/service/openai_gateway_service_test.go @@ -14,8 +14,13 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" ) +// 编译期接口断言 +var _ AccountRepository = (*stubOpenAIAccountRepo)(nil) +var _ GatewayCache = (*stubGatewayCache)(nil) + type stubOpenAIAccountRepo struct { AccountRepository accounts []Account @@ -124,17 +129,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { svc := &OpenAIGatewayService{} + bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`) + // 1) session_id header wins c.Request.Header.Set("session_id", "sess-123") c.Request.Header.Set("conversation_id", "conv-456") - h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h1 := svc.GenerateSessionHash(c, bodyWithKey) if h1 == "" { t.Fatalf("expected non-empty hash") } // 2) conversation_id used when session_id absent c.Request.Header.Del("session_id") - h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h2 := svc.GenerateSessionHash(c, bodyWithKey) if h2 == "" { t.Fatalf("expected non-empty hash") } @@ -144,7 +151,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { // 3) prompt_cache_key used when both headers absent c.Request.Header.Del("conversation_id") - h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + h3 := svc.GenerateSessionHash(c, bodyWithKey) if h3 == "" { t.Fatalf("expected non-empty hash") } @@ -153,7 +160,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { } // 4) empty when no signals - h4 := svc.GenerateSessionHash(c, map[string]any{}) + h4 := svc.GenerateSessionHash(c, []byte(`{}`)) if h4 != "" { t.Fatalf("expected empty hash when no signals") } @@ -1066,6 +1073,43 @@ func TestOpenAIStreamingHeadersOverride(t *testing.T) { } } +func TestOpenAIStreamingReuseScannerBufferAndStillWorks(t *testing.T) { + 
gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + StreamDataIntervalTimeout: 0, + StreamKeepaliveInterval: 0, + MaxLineSize: defaultMaxLineSize, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + pr, pw := io.Pipe() + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: pr, + Header: http.Header{}, + } + + go func() { + defer func() { _ = pw.Close() }() + _, _ = pw.Write([]byte("data: {\"type\":\"response.completed\",\"response\":{\"usage\":{\"input_tokens\":1,\"output_tokens\":2,\"input_tokens_details\":{\"cached_tokens\":3}}}}\n\n")) + }() + + result, err := svc.handleStreamingResponse(c.Request.Context(), resp, c, &Account{ID: 1}, time.Now(), "model", "model") + _ = pr.Close() + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.usage) + require.Equal(t, 1, result.usage.InputTokens) + require.Equal(t, 2, result.usage.OutputTokens) + require.Equal(t, 3, result.usage.CacheReadInputTokens) +} + func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{ @@ -1149,3 +1193,332 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { t.Fatalf("expected non-allowlisted host to fail") } } + +// ==================== P1-08 修复:model 替换性能优化测试 ==================== + +func TestReplaceModelInSSELine(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + line string + from string + to string + expected string + }{ + { + name: "顶层 model 字段替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "my-custom-model", + expected: `data: {"id":"chatcmpl-123","model":"my-custom-model","choices":[]}`, + }, + { + name: "嵌套 response.model 替换", + line: `data: 
{"type":"response","response":{"id":"resp-1","model":"gpt-4o","output":[]}}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"type":"response","response":{"id":"resp-1","model":"my-model","output":[]}}`, + }, + { + name: "model 不匹配时不替换", + line: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段时不替换", + line: `data: {"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "空 data 行", + line: `data: `, + from: "gpt-4o", + to: "my-model", + expected: `data: `, + }, + { + name: "[DONE] 行", + line: `data: [DONE]`, + from: "gpt-4o", + to: "my-model", + expected: `data: [DONE]`, + }, + { + name: "非 data: 前缀行", + line: `event: message`, + from: "gpt-4o", + to: "my-model", + expected: `event: message`, + }, + { + name: "非法 JSON 不替换", + line: `data: {invalid json}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {invalid json}`, + }, + { + name: "无空格 data: 格式", + line: `data:{"id":"x","model":"gpt-4o"}`, + from: "gpt-4o", + to: "my-model", + expected: `data: {"id":"x","model":"my-model"}`, + }, + { + name: "model 名含特殊字符", + line: `data: {"model":"org/model-v2.1-beta"}`, + from: "org/model-v2.1-beta", + to: "custom/alias", + expected: `data: {"model":"custom/alias"}`, + }, + { + name: "空行", + line: "", + from: "gpt-4o", + to: "my-model", + expected: "", + }, + { + name: "保持其他字段不变", + line: `data: {"id":"abc","object":"chat.completion.chunk","model":"gpt-4o","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `data: {"id":"abc","object":"chat.completion.chunk","model":"alias","created":1234567890,"choices":[{"index":0,"delta":{"content":"hi"}}]}`, + }, + { + name: "顶层优先于嵌套:同时存在两个 model", + line: `data: 
{"model":"gpt-4o","response":{"model":"gpt-4o"}}`, + from: "gpt-4o", + to: "replaced", + expected: `data: {"model":"replaced","response":{"model":"gpt-4o"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSELine(tt.line, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInSSEBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "多行 SSE body 替换", + body: "data: {\"model\":\"gpt-4o\",\"choices\":[]}\n\ndata: {\"model\":\"gpt-4o\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"alias\",\"choices\":[]}\n\ndata: {\"model\":\"alias\",\"choices\":[{\"delta\":{\"content\":\"hi\"}}]}\n\ndata: [DONE]\n", + }, + { + name: "无需替换的 body", + body: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + from: "gpt-4o", + to: "alias", + expected: "data: {\"model\":\"gpt-3.5-turbo\"}\n\ndata: [DONE]\n", + }, + { + name: "混合 event 和 data 行", + body: "event: message\ndata: {\"model\":\"gpt-4o\"}\n\n", + from: "gpt-4o", + to: "alias", + expected: "event: message\ndata: {\"model\":\"alias\"}\n\n", + }, + { + name: "空 body", + body: "", + from: "gpt-4o", + to: "alias", + expected: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInSSEBody(tt.body, tt.from, tt.to) + require.Equal(t, tt.expected, got) + }) + } +} + +func TestReplaceModelInResponseBody(t *testing.T) { + svc := &OpenAIGatewayService{} + + tests := []struct { + name string + body string + from string + to string + expected string + }{ + { + name: "替换顶层 model", + body: `{"id":"chatcmpl-123","model":"gpt-4o","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"alias","choices":[]}`, + }, + { + name: "model 不匹配不替换", + body: 
`{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","model":"gpt-3.5-turbo","choices":[]}`, + }, + { + name: "无 model 字段不替换", + body: `{"id":"chatcmpl-123","choices":[]}`, + from: "gpt-4o", + to: "alias", + expected: `{"id":"chatcmpl-123","choices":[]}`, + }, + { + name: "非法 JSON 返回原值", + body: `not json`, + from: "gpt-4o", + to: "alias", + expected: `not json`, + }, + { + name: "空 body 返回原值", + body: ``, + from: "gpt-4o", + to: "alias", + expected: ``, + }, + { + name: "保持嵌套结构不变", + body: `{"model":"gpt-4o","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + from: "gpt-4o", + to: "alias", + expected: `{"model":"alias","usage":{"prompt_tokens":10,"completion_tokens":20},"choices":[{"message":{"role":"assistant","content":"hello"}}]}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := svc.replaceModelInResponseBody([]byte(tt.body), tt.from, tt.to) + require.Equal(t, tt.expected, string(got)) + }) + } +} + +func TestExtractOpenAISSEDataLine(t *testing.T) { + tests := []struct { + name string + line string + wantData string + wantOK bool + }{ + {name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "纯空数据", line: `data: `, wantData: ``, wantOK: true}, + {name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := extractOpenAISSEDataLine(tt.line) + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantData, got) + }) + } +} + +func TestParseSSEUsage_SelectiveParsing(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7} + + // 非 completed 事件,不应覆盖 usage + 
svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage) + require.Equal(t, 9, usage.InputTokens) + require.Equal(t, 8, usage.OutputTokens) + require.Equal(t, 7, usage.CacheReadInputTokens) + + // completed 事件,应提取 usage + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage) + require.Equal(t, 3, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) +} + +func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { + body := strings.Join([]string{ + `event: message`, + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + `data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`, + `data: [DONE]`, + }, "\n") + + finalResp, ok := extractCodexFinalResponse(body) + require.True(t, ok) + require.Contains(t, string(finalResp), `"id":"resp_1"`) + require.Contains(t, string(finalResp), `"input_tokens":11`) +} + +func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_2"}}`, + `data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, 
err) + require.NotNil(t, usage) + require.Equal(t, 7, usage.InputTokens) + require.Equal(t, 9, usage.OutputTokens) + require.Equal(t, 1, usage.CacheReadInputTokens) + // Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。 + require.NotContains(t, rec.Body.String(), "event:") + require.Contains(t, rec.Body.String(), `"id":"resp_2"`) + require.NotContains(t, rec.Body.String(), "data:") +} + +func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_3"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 0, usage.InputTokens) + require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") + require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go new file mode 100644 index 00000000..7a996c26 --- /dev/null +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -0,0 +1,928 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func f64p(v float64) *float64 { return &v } + +type httpUpstreamRecorder struct { + lastReq 
*http.Request + lastBody []byte + + resp *http.Response + err error +} + +func (u *httpUpstreamRecorder) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + u.lastReq = req + if req != nil && req.Body != nil { + b, _ := io.ReadAll(req.Body) + u.lastBody = b + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(b)) + } + if u.err != nil { + return nil, u.err + } + return u.resp, nil +} + +func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return u.Do(req, proxyURL, accountID, accountConcurrency) +} + +var structuredLogCaptureMu sync.Mutex + +type inMemoryLogSink struct { + mu sync.Mutex + events []*logger.LogEvent +} + +func (s *inMemoryLogSink) WriteLogEvent(event *logger.LogEvent) { + if event == nil { + return + } + cloned := *event + if event.Fields != nil { + cloned.Fields = make(map[string]any, len(event.Fields)) + for k, v := range event.Fields { + cloned.Fields[k] = v + } + } + s.mu.Lock() + s.events = append(s.events, &cloned) + s.mu.Unlock() +} + +func (s *inMemoryLogSink) ContainsMessage(substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev != nil && strings.Contains(ev.Message, substr) { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsMessageAtLevel(substr, level string) bool { + s.mu.Lock() + defer s.mu.Unlock() + wantLevel := strings.ToLower(strings.TrimSpace(level)) + for _, ev := range s.events { + if ev == nil { + continue + } + if strings.Contains(ev.Message, substr) && strings.ToLower(strings.TrimSpace(ev.Level)) == wantLevel { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsFieldValue(field, substr string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if v, ok := 
ev.Fields[field]; ok && strings.Contains(fmt.Sprint(v), substr) { + return true + } + } + return false +} + +func (s *inMemoryLogSink) ContainsField(field string) bool { + s.mu.Lock() + defer s.mu.Unlock() + for _, ev := range s.events { + if ev == nil || ev.Fields == nil { + continue + } + if _, ok := ev.Fields[field]; ok { + return true + } + } + return false +} + +func captureStructuredLog(t *testing.T) (*inMemoryLogSink, func()) { + t.Helper() + structuredLogCaptureMu.Lock() + + err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + Sampling: logger.SamplingOptions{Enabled: false}, + }) + require.NoError(t, err) + + sink := &inMemoryLogSink{} + logger.SetSink(sink) + return sink, func() { + logger.SetSink(nil) + structuredLogCaptureMu.Unlock() + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamKeepsToolNameAndBodyNormalized(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("Authorization", "Bearer inbound-should-not-forward") + c.Request.Header.Set("Cookie", "secret=1") + c.Request.Header.Set("X-Api-Key", "sk-inbound") + c.Request.Header.Set("X-Goog-Api-Key", "goog-inbound") + c.Request.Header.Set("Accept-Encoding", "gzip") + c.Request.Header.Set("Proxy-Authorization", "Basic abc") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"store":true,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_item.added","item":{"type":"tool_call","tool_calls":[{"function":{"name":"apply_patch"}}]}}`, + "", + "data: [DONE]", + "", + }, 
"\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + openAITokenProvider: &OpenAITokenProvider{ // minimal: will be bypassed by nil cache/service, but GetAccessToken uses provider only if non-nil + accountRepo: nil, + }, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + // Use the gateway method that reads token from credentials when provider is nil. 
+ svc.openAITokenProvider = nil + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + + // 1) 透传 OAuth 请求体与旧链路关键行为保持一致:store=false + stream=true。 + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "local-test-instructions", strings.TrimSpace(gjson.GetBytes(upstream.lastBody, "instructions").String())) + // 其余关键字段保持原值。 + require.Equal(t, "gpt-5.2", gjson.GetBytes(upstream.lastBody, "model").String()) + require.Equal(t, "hi", gjson.GetBytes(upstream.lastBody, "input.0.text").String()) + + // 2) only auth is replaced; inbound auth/cookie are not forwarded + require.Equal(t, "Bearer oauth-token", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "codex_cli_rs/0.1.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("Cookie")) + require.Empty(t, upstream.lastReq.Header.Get("X-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("X-Goog-Api-Key")) + require.Empty(t, upstream.lastReq.Header.Get("Accept-Encoding")) + require.Empty(t, upstream.lastReq.Header.Get("Proxy-Authorization")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) + + // 3) required OAuth headers are present + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "chatgpt-acc", upstream.lastReq.Header.Get("chatgpt-account-id")) + + // 4) downstream SSE keeps tool name (no toolCorrector) + body := rec.Body.String() + require.Contains(t, body, "apply_patch") + require.NotContains(t, body, "\"name\":\"edit\"") +} + +func TestOpenAIGatewayService_OAuthPassthrough_CodexMissingInstructionsRejectedBeforeUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + 
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses?trace=1", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown") + c.Request.Header.Set("Content-Type", "application/json") + c.Request.Header.Set("OpenAI-Beta", "responses=experimental") + + // Codex 模型且缺少 instructions,应在本地直接 403 拒绝,不触达上游。 + originalBody := []byte(`{"model":"gpt-5.1-codex-max","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + require.Nil(t, result) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "requires a non-empty instructions field") + require.Nil(t, upstream.lastReq) + + require.True(t, logSink.ContainsMessage("OpenAI passthrough 本地拦截:Codex 请求缺少有效 instructions")) + require.True(t, logSink.ContainsFieldValue("request_user_agent", "codex_cli_rs/0.98.0 (Windows 10.0.19045; x86_64) unknown")) + require.True(t, logSink.ContainsFieldValue("reject_reason", "instructions_missing")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DisabledUsesLegacyTransform(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + // store=true + stream=false should be forced to store=false + stream=true by applyCodexOAuthTransform (OAuth legacy path) + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + + // legacy path rewrites request body (not byte-equal) + require.NotEqual(t, inputBody, upstream.lastBody) + require.Contains(t, string(upstream.lastBody), `"store":false`) + require.Contains(t, string(upstream.lastBody), `"stream":true`) +} + +func TestOpenAIGatewayService_OAuthLegacy_CompositeCodexUAUsesCodexOriginator(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // 复合 UA(前缀不是 codex_cli_rs),历史实现会误判为非 Codex 并走 opencode。 + c.Request.Header.Set("User-Agent", "Mozilla/5.0 
codex_cli_rs/0.1.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":true,"store":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": false}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "codex_cli_rs", upstream.lastReq.Header.Get("originator")) + require.NotEqual(t, "opencode", upstream.lastReq.Header.Get("originator")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_ResponseHeadersAllowXCodex(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + headers := make(http.Header) + headers.Set("Content-Type", "application/json") + headers.Set("x-request-id", "rid") + headers.Set("x-codex-primary-used-percent", "12") + headers.Set("x-codex-secondary-used-percent", "34") + headers.Set("x-codex-primary-window-minutes", "300") + headers.Set("x-codex-secondary-window-minutes", "10080") + 
headers.Set("x-codex-primary-reset-after-seconds", "1") + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: headers, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + + require.Equal(t, "12", rec.Header().Get("x-codex-primary-used-percent")) + require.Equal(t, "34", rec.Header().Get("x-codex-secondary-used-percent")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughFlag(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"bad"}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, 
+ Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.Error(t, err) + + // should append an upstream error event with passthrough=true + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + arr, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.NotEmpty(t, arr) + require.True(t, arr[len(arr)-1].Passthrough) +} + +func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + // Non-Codex UA + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + 
require.NoError(t, err) + require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) + require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) +} + +func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.Error(t, err) + require.Equal(t, http.StatusForbidden, rec.Code) + require.Contains(t, rec.Body.String(), "Codex official clients") +} + +func TestOpenAIGatewayService_CodexCLIOnly_AllowOfficialClientFamilies(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + ua string + originator string + }{ + {name: "codex_cli_rs", ua: "codex_cli_rs/0.99.0", originator: ""}, + {name: "codex_vscode", ua: "codex_vscode/1.0.0", originator: ""}, + {name: "codex_app", ua: "codex_app/2.1.0", originator: ""}, + {name: "originator_codex_chatgpt_desktop", ua: "curl/8.0", originator: "codex_chatgpt_desktop"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", tt.ua) + if tt.originator != "" { + c.Request.Header.Set("originator", tt.originator) + } + + inputBody := []byte(`{"model":"gpt-5.2","stream":false,"store":true,"input":[{"type":"text","text":"hi"}]}`) + + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true, "codex_cli_only": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, inputBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + }) + } +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamingSetsFirstTokenMs(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: 
http.Header{"Content-Type": []string{"text/event-stream"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + start := time.Now() + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + // sanity: duration after start + require.GreaterOrEqual(t, time.Since(start), time.Duration(0)) + require.NotNil(t, result.FirstTokenMs) + require.GreaterOrEqual(t, *result.FirstTokenMs, 0) +} + +func TestOpenAIGatewayService_OAuthPassthrough_StreamClientDisconnectStillCollectsUsage(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + // 首次写入成功,后续写入失败,模拟客户端中途断开。 + c.Writer = &failingGinWriter{ResponseWriter: c.Writer, failAfter: 1} + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + + upstreamSSE := strings.Join([]string{ + `data: {"type":"response.output_text.delta","delta":"h"}`, + "", + `data: {"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`, + "", + "data: [DONE]", + "", + }, "\n") + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, 
"x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(upstreamSSE)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + + account := &Account{ + ID: 123, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + result, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.NotNil(t, result.FirstTokenMs) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) +} + +func TestOpenAIGatewayService_APIKeyPassthrough_PreservesBodyAndUsesResponsesEndpoint(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "curl/8.0") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"max_output_tokens":128,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "x-request-id": []string{"rid"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + 
} + + account := &Account{ + ID: 456, + Name: "apikey-acc", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{"api_key": "sk-api-key", "base_url": "https://api.openai.com"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, originalBody, upstream.lastBody) + require.Equal(t, "https://api.openai.com/v1/responses", upstream.lastReq.URL.String()) + require.Equal(t, "Bearer sk-api-key", upstream.lastReq.Header.Get("Authorization")) + require.Equal(t, "curl/8.0", upstream.lastReq.Header.Get("User-Agent")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_WarnOnTimeoutHeadersForStream(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "10000") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-timeout"}}, + Body: io.NopCloser(strings.NewReader("data: [DONE]\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 321, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: 
map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("检测到超时相关请求头,将按配置过滤以降低断流风险")) + require.True(t, logSink.ContainsFieldValue("timeout_headers", "x-stainless-timeout=10000")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_InfoWhenStreamEndsWithoutDone(t *testing.T) { + gin.SetMode(gin.TestMode) + logSink, restore := captureStructuredLog(t) + defer restore() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + + originalBody := []byte(`{"model":"gpt-5.2","stream":true,"input":[{"type":"text","text":"hi"}]}`) + // 注意:刻意不发送 [DONE],模拟上游中途断流。 + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}, "X-Request-Id": []string{"rid-truncate"}}, + Body: io.NopCloser(strings.NewReader("data: {\"type\":\"response.output_text.delta\",\"delta\":\"h\"}\n\n")), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 654, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.True(t, logSink.ContainsMessage("上游流在未收到 [DONE] 
时结束,疑似断流")) + require.True(t, logSink.ContainsMessageAtLevel("上游流在未收到 [DONE] 时结束,疑似断流", "info")) + require.True(t, logSink.ContainsFieldValue("upstream_request_id", "rid-truncate")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_DefaultFiltersTimeoutHeaders(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-default"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}}, + httpUpstream: upstream, + } + account := &Account{ + ID: 111, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Empty(t, upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} + +func TestOpenAIGatewayService_OAuthPassthrough_AllowTimeoutHeadersWhenConfigured(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := 
httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil)) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0") + c.Request.Header.Set("x-stainless-timeout", "120000") + c.Request.Header.Set("X-Test", "keep") + + originalBody := []byte(`{"model":"gpt-5.2","stream":false,"input":[{"type":"text","text":"hi"}]}`) + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}, "X-Request-Id": []string{"rid-filter-allow"}}, + Body: io.NopCloser(strings.NewReader(`{"output":[],"usage":{"input_tokens":1,"output_tokens":1,"input_tokens_details":{"cached_tokens":0}}}`)), + } + upstream := &httpUpstreamRecorder{resp: resp} + svc := &OpenAIGatewayService{ + cfg: &config.Config{Gateway: config.GatewayConfig{ + ForceCodexCLI: false, + OpenAIPassthroughAllowTimeoutHeaders: true, + }}, + httpUpstream: upstream, + } + account := &Account{ + ID: 222, + Name: "acc", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}, + Extra: map[string]any{"openai_passthrough": true}, + Status: StatusActive, + Schedulable: true, + RateMultiplier: f64p(1), + } + + _, err := svc.Forward(context.Background(), c, account, originalBody) + require.NoError(t, err) + require.NotNil(t, upstream.lastReq) + require.Equal(t, "120000", upstream.lastReq.Header.Get("x-stainless-timeout")) + require.Empty(t, upstream.lastReq.Header.Get("X-Test")) +} diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index ca7470b9..087ad4ec 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -2,13 +2,20 @@ package service import ( "context" + "crypto/subtle" + "encoding/json" + "io" "net/http" + "net/url" + "strings" "time" 
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) +var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -92,6 +99,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 type OpenAIExchangeCodeInput struct { SessionID string Code string + State string RedirectURI string ProxyID *int64 } @@ -116,6 +124,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if !ok { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_SESSION_NOT_FOUND", "session not found or expired") } + if input.State == "" { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_STATE_REQUIRED", "oauth state is required") + } + if subtle.ConstantTimeCompare([]byte(input.State), []byte(session.State)) != 1 { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_STATE", "invalid oauth state") + } // Get proxy URL: prefer input.ProxyID, fallback to session.ProxyURL proxyURL := session.ProxyURL @@ -173,7 +187,12 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // RefreshToken refreshes an OpenAI OAuth token func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken string, proxyURL string) (*OpenAITokenInfo, error) { - tokenResp, err := s.oauthClient.RefreshToken(ctx, refreshToken, proxyURL) + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +// RefreshTokenWithClientID refreshes an OpenAI/Sora OAuth token with optional client_id. 
+func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken string, proxyURL string, clientID string) (*OpenAITokenInfo, error) { + tokenResp, err := s.oauthClient.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) if err != nil { return nil, err } @@ -205,13 +224,83 @@ func (s *OpenAIOAuthService) RefreshToken(ctx context.Context, refreshToken stri return tokenInfo, nil } -// RefreshAccountToken refreshes token for an OpenAI account -func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { - if !account.IsOpenAI() { - return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI account") +// ExchangeSoraSessionToken exchanges Sora session_token to access_token. +func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + if strings.TrimSpace(sessionToken) == "" { + return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } - refreshToken := account.GetOpenAIRefreshToken() + proxyURL, err := s.resolveProxyURL(ctx, proxyID) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, openAISoraSessionAuthURL, nil) + if err != nil { + return nil, infraerrors.Newf(http.StatusInternalServerError, "SORA_SESSION_REQUEST_BUILD_FAILED", "failed to build request: %v", err) + } + req.Header.Set("Cookie", "__Secure-next-auth.session-token="+strings.TrimSpace(sessionToken)) + req.Header.Set("Accept", "application/json") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + + client := newOpenAIOAuthHTTPClient(proxyURL) + resp, err := client.Do(req) + if err != nil { + return nil, 
infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + if resp.StatusCode != http.StatusOK { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_EXCHANGE_FAILED", "status %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + + var sessionResp struct { + AccessToken string `json:"accessToken"` + Expires string `json:"expires"` + User struct { + Email string `json:"email"` + Name string `json:"name"` + } `json:"user"` + } + if err := json.Unmarshal(body, &sessionResp); err != nil { + return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_PARSE_FAILED", "failed to parse response: %v", err) + } + if strings.TrimSpace(sessionResp.AccessToken) == "" { + return nil, infraerrors.New(http.StatusBadGateway, "SORA_SESSION_ACCESS_TOKEN_MISSING", "session exchange response missing access token") + } + + expiresAt := time.Now().Add(time.Hour).Unix() + if strings.TrimSpace(sessionResp.Expires) != "" { + if parsed, parseErr := time.Parse(time.RFC3339, sessionResp.Expires); parseErr == nil { + expiresAt = parsed.Unix() + } + } + expiresIn := expiresAt - time.Now().Unix() + if expiresIn < 0 { + expiresIn = 0 + } + + return &OpenAITokenInfo{ + AccessToken: strings.TrimSpace(sessionResp.AccessToken), + ExpiresIn: expiresIn, + ExpiresAt: expiresAt, + Email: strings.TrimSpace(sessionResp.User.Email), + }, nil +} + +// RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account +func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { + if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { + return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_INVALID_ACCOUNT", "account is not an OpenAI/Sora account") + } + if account.Type != AccountTypeOAuth { + return nil, infraerrors.New(http.StatusBadRequest, 
"OPENAI_OAUTH_INVALID_ACCOUNT_TYPE", "account is not an OAuth account") + } + + refreshToken := account.GetCredential("refresh_token") if refreshToken == "" { return nil, infraerrors.New(http.StatusBadRequest, "OPENAI_OAUTH_NO_REFRESH_TOKEN", "no refresh token available") } @@ -224,7 +313,8 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } } - return s.RefreshToken(ctx, refreshToken, proxyURL) + clientID := account.GetCredential("client_id") + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } // BuildAccountCredentials builds credentials map from token info @@ -260,3 +350,30 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) func (s *OpenAIOAuthService) Stop() { s.sessionStore.Stop() } + +func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64) (string, error) { + if proxyID == nil { + return "", nil + } + proxy, err := s.proxyRepo.GetByID(ctx, *proxyID) + if err != nil { + return "", infraerrors.Newf(http.StatusBadRequest, "OPENAI_OAUTH_PROXY_NOT_FOUND", "proxy not found: %v", err) + } + if proxy == nil { + return "", nil + } + return proxy.URL(), nil +} + +func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { + transport := &http.Transport{} + if strings.TrimSpace(proxyURL) != "" { + if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + } + } + return &http.Client{ + Timeout: 120 * time.Second, + Transport: transport, + } +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go new file mode 100644 index 00000000..fb76f6c1 --- /dev/null +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -0,0 +1,69 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + 
"github.com/stretchr/testify/require" +) + +type openaiOAuthClientNoopStub struct{} + +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientNoopStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_Success(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-token") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + info, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at-token", info.AccessToken) + require.Equal(t, "demo@example.com", info.Email) + require.Greater(t, info.ExpiresAt, int64(0)) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = 
w.Write([]byte(`{"expires":"2099-01-01T00:00:00Z"}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + _, err := svc.ExchangeSoraSessionToken(context.Background(), "st-token", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "missing access token") +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go new file mode 100644 index 00000000..0a2a195f --- /dev/null +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -0,0 +1,102 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientStateStub struct { + exchangeCalled int32 +} + +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.exchangeCalled, 1) + return &openai.TokenResponse{ + AccessToken: "at", + RefreshToken: "rt", + ExpiresIn: 3600, + }, nil +} + +func (s *openaiOAuthClientStateStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return s.RefreshToken(ctx, refreshToken, proxyURL) +} + +func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: 
"verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "oauth state is required") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + _, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "wrong-state", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid oauth state") + require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) +} + +func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { + client := &openaiOAuthClientStateStub{} + svc := NewOpenAIOAuthService(nil, client) + defer svc.Stop() + + svc.sessionStore.Set("sid", &openai.OAuthSession{ + State: "expected-state", + CodeVerifier: "verifier", + RedirectURI: openai.DefaultRedirectURI, + CreatedAt: time.Now(), + }) + + info, err := svc.ExchangeCode(context.Background(), &OpenAIExchangeCodeInput{ + SessionID: "sid", + Code: "auth-code", + State: "expected-state", + }) + require.NoError(t, err) + require.NotNil(t, info) + require.Equal(t, "at", info.AccessToken) + require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) + + _, ok := svc.sessionStore.Get("sid") + require.False(t, ok) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 87a7713b..a8a6b96c 100644 --- a/backend/internal/service/openai_token_provider.go +++ 
b/backend/internal/service/openai_token_provider.go @@ -4,16 +4,74 @@ import ( "context" "errors" "log/slog" + "math/rand/v2" "strings" + "sync/atomic" "time" ) const ( - openAITokenRefreshSkew = 3 * time.Minute - openAITokenCacheSkew = 5 * time.Minute - openAILockWaitTime = 200 * time.Millisecond + openAITokenRefreshSkew = 3 * time.Minute + openAITokenCacheSkew = 5 * time.Minute + openAILockInitialWait = 20 * time.Millisecond + openAILockMaxWait = 120 * time.Millisecond + openAILockMaxAttempts = 5 + openAILockJitterRatio = 0.2 + openAILockWarnThresholdMs = 250 ) +// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +type OpenAITokenRuntimeMetrics struct { + RefreshRequests int64 + RefreshSuccess int64 + RefreshFailure int64 + LockAcquireFailure int64 + LockContention int64 + LockWaitSamples int64 + LockWaitTotalMs int64 + LockWaitHit int64 + LockWaitMiss int64 + LastObservedUnixMs int64 +} + +type openAITokenRuntimeMetricsStore struct { + refreshRequests atomic.Int64 + refreshSuccess atomic.Int64 + refreshFailure atomic.Int64 + lockAcquireFailure atomic.Int64 + lockContention atomic.Int64 + lockWaitSamples atomic.Int64 + lockWaitTotalMs atomic.Int64 + lockWaitHit atomic.Int64 + lockWaitMiss atomic.Int64 + lastObservedUnixMs atomic.Int64 +} + +func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics { + if m == nil { + return OpenAITokenRuntimeMetrics{} + } + return OpenAITokenRuntimeMetrics{ + RefreshRequests: m.refreshRequests.Load(), + RefreshSuccess: m.refreshSuccess.Load(), + RefreshFailure: m.refreshFailure.Load(), + LockAcquireFailure: m.lockAcquireFailure.Load(), + LockContention: m.lockContention.Load(), + LockWaitSamples: m.lockWaitSamples.Load(), + LockWaitTotalMs: m.lockWaitTotalMs.Load(), + LockWaitHit: m.lockWaitHit.Load(), + LockWaitMiss: m.lockWaitMiss.Load(), + LastObservedUnixMs: m.lastObservedUnixMs.Load(), + } +} + +func (m *openAITokenRuntimeMetricsStore) touchNow() { + if m == nil { + return + } + 
m.lastObservedUnixMs.Store(time.Now().UnixMilli()) +} + // OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type OpenAITokenCache = GeminiTokenCache @@ -22,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + metrics *openAITokenRuntimeMetricsStore } func NewOpenAITokenProvider( @@ -33,16 +92,32 @@ func NewOpenAITokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, + metrics: &openAITokenRuntimeMetricsStore{}, + } +} + +func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { + if p == nil { + return OpenAITokenRuntimeMetrics{} + } + p.ensureMetrics() + return p.metrics.snapshot() +} + +func (p *OpenAITokenProvider) ensureMetrics() { + if p != nil && p.metrics == nil { + p.metrics = &openAITokenRuntimeMetricsStore{} } } // GetAccessToken 获取有效的 access_token func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + p.ensureMetrics() if account == nil { return "", errors.New("account is nil") } - if account.Platform != PlatformOpenAI || account.Type != AccountTypeOAuth { - return "", errors.New("not an openai oauth account") + if (account.Platform != PlatformOpenAI && account.Platform != PlatformSora) || account.Type != AccountTypeOAuth { + return "", errors.New("not an openai/sora oauth account") } cacheKey := OpenAITokenCacheKey(account) @@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() @@ -80,16 +157,23 @@ func (p 
*OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 刷新失败,标记以使用短 TTL } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -106,6 +190,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } else if lockErr != nil { // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + p.metrics.lockAcquireFailure.Add(1) + p.metrics.touchNow() slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) // 检查 ctx 是否已取消 @@ -124,15 +210,22 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou // 仅在 expires_at 已过期/接近过期时才执行无锁刷新 if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { - if p.openAIOAuthService == nil { + if account.Platform == PlatformSora { + slog.Debug("openai_token_refresh_skipped_for_sora_degraded", "account_id", account.ID) + // Sora 账号不走 OpenAI OAuth 刷新,交由 Sora 客户端的 ST/RT 恢复链路处理。 + refreshFailed = true + } else if p.openAIOAuthService == nil { 
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -148,16 +241,21 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(openAILockWaitTime) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) return token, nil } } } - accessToken := account.GetOpenAIAccessToken() + accessToken := account.GetCredential("access_token") if strings.TrimSpace(accessToken) == "" { return "", errors.New("access_token not found in credentials") } @@ -198,3 +296,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) { + wait := openAILockInitialWait + totalWaitMs := int64(0) + for i := 0; i < openAILockMaxAttempts; i++ { + actualWait := jitterLockWait(wait) + timer := time.NewTimer(actualWait) + select { + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return "", ctx.Err() + case 
<-timer.C: + } + + waitMs := actualWait.Milliseconds() + if waitMs < 0 { + waitMs = 0 + } + totalWaitMs += waitMs + p.metrics.lockWaitSamples.Add(1) + p.metrics.lockWaitTotalMs.Add(waitMs) + p.metrics.touchNow() + + token, err := p.tokenCache.GetAccessToken(ctx, cacheKey) + if err == nil && strings.TrimSpace(token) != "" { + p.metrics.lockWaitHit.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1) + } + return token, nil + } + + if wait < openAILockMaxWait { + wait *= 2 + if wait > openAILockMaxWait { + wait = openAILockMaxWait + } + } + } + + p.metrics.lockWaitMiss.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts) + } + return "", nil +} + +func jitterLockWait(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + minFactor := 1 - openAILockJitterRatio + maxFactor := 1 + openAILockJitterRatio + factor := minFactor + rand.Float64()*(maxFactor-minFactor) + return time.Duration(float64(base) * factor) +} diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index c2e3dbb0..1cd92367 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -375,7 +375,7 @@ func TestOpenAITokenProvider_WrongPlatform(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -389,7 +389,7 @@ func TestOpenAITokenProvider_WrongAccountType(t *testing.T) { token, err := provider.GetAccessToken(context.Background(), account) require.Error(t, err) - require.Contains(t, err.Error(), "not an openai oauth account") + require.Contains(t, 
err.Error(), "not an openai/sora oauth account") require.Empty(t, token) } @@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) { require.Contains(t, err.Error(), "access_token not found") require.Empty(t, token) } + +func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 207, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) +} + +func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 208, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewOpenAITokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, token) + require.Less(t, time.Since(start), 50*time.Millisecond) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false + + expiresAt 
:= time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 209, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(10 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) + require.GreaterOrEqual(t, metrics.LockContention, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0)) + require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1)) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockErr = errors.New("redis lock error") + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 210, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1)) + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) +} diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index f4719275..deec80fa 100644 --- 
a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -3,8 +3,9 @@ package service import ( "encoding/json" "fmt" - "log" "sync" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 @@ -140,7 +141,7 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo // 序列化回 JSON correctedBytes, err := json.Marshal(payload) if err != nil { - log.Printf("[CodexToolCorrector] Failed to marshal corrected data: %v", err) + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err) return data, false } @@ -219,13 +220,13 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall argsMap["workdir"] = workDir delete(argsMap, "work_dir") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") } } else { if _, exists := argsMap["work_dir"]; exists { delete(argsMap, "work_dir") corrected = true - log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") } } @@ -236,17 +237,17 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall argsMap["filePath"] = filePath delete(argsMap, "file_path") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") } else if filePath, exists := argsMap["path"]; exists { argsMap["filePath"] = filePath delete(argsMap, "path") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit 
tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") } else if filePath, exists := argsMap["file"]; exists { argsMap["filePath"] = filePath delete(argsMap, "file") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") } } @@ -255,7 +256,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall argsMap["oldString"] = oldString delete(argsMap, "old_string") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") } } @@ -264,7 +265,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall argsMap["newString"] = newString delete(argsMap, "new_string") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") } } @@ -273,7 +274,7 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall argsMap["replaceAll"] = replaceAll delete(argsMap, "replace_all") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } } @@ -303,7 +304,7 @@ func (c *CodexToolCorrector) recordCorrection(from, to string) { key := fmt.Sprintf("%s->%s", from, to) c.stats.CorrectionsByTool[key]++ - log.Printf("[CodexToolCorrector] Corrected tool call: %s -> %s (total: %d)", + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Corrected tool call: 
%s -> %s (total: %d)", from, to, c.stats.TotalCorrected) } diff --git a/backend/internal/service/ops_aggregation_service.go b/backend/internal/service/ops_aggregation_service.go index 972462ec..ec77fe12 100644 --- a/backend/internal/service/ops_aggregation_service.go +++ b/backend/internal/service/ops_aggregation_service.go @@ -5,12 +5,12 @@ import ( "database/sql" "errors" "fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -190,7 +190,7 @@ func (s *OpsAggregationService) aggregateHourly() { latest, ok, err := s.opsRepo.GetLatestHourlyBucketStart(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][hourly] failed to read latest bucket: %v", err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggHourlyOverlap) if candidate.After(start) { @@ -209,7 +209,7 @@ func (s *OpsAggregationService) aggregateHourly() { chunkEnd := minTime(cursor.Add(opsAggHourlyChunk), end) if err := s.opsRepo.UpsertHourlyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][hourly] upsert failed (%s..%s): %v", cursor.Format(time.RFC3339), chunkEnd.Format(time.RFC3339), err) break } } @@ -288,7 +288,7 @@ func (s *OpsAggregationService) aggregateDaily() { latest, ok, err := s.opsRepo.GetLatestDailyBucketDate(ctxMax) cancelMax() if err != nil { - log.Printf("[OpsAggregation][daily] failed to read latest bucket: %v", err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] failed to read latest bucket: %v", err) } else if ok { candidate := latest.Add(-opsAggDailyOverlap) if candidate.After(start) { @@ -307,7 
+307,7 @@ func (s *OpsAggregationService) aggregateDaily() { chunkEnd := minTime(cursor.Add(opsAggDailyChunk), end) if err := s.opsRepo.UpsertDailyMetrics(ctx, cursor, chunkEnd); err != nil { aggErr = err - log.Printf("[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) + logger.LegacyPrintf("service.ops_aggregation", "[OpsAggregation][daily] upsert failed (%s..%s): %v", cursor.Format("2006-01-02"), chunkEnd.Format("2006-01-02"), err) break } } @@ -427,7 +427,7 @@ func (s *OpsAggregationService) maybeLogSkip(prefix string) { if prefix == "" { prefix = "[OpsAggregation]" } - log.Printf("%s leader lock held by another instance; skipping", prefix) + logger.LegacyPrintf("service.ops_aggregation", "%s leader lock held by another instance; skipping", prefix) } func utcFloorToHour(t time.Time) time.Time { diff --git a/backend/internal/service/ops_alert_evaluator_service.go b/backend/internal/service/ops_alert_evaluator_service.go index 7c62e247..169a5e32 100644 --- a/backend/internal/service/ops_alert_evaluator_service.go +++ b/backend/internal/service/ops_alert_evaluator_service.go @@ -3,7 +3,6 @@ package service import ( "context" "fmt" - "log" "math" "strconv" "strings" @@ -11,6 +10,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -186,7 +186,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { rules, err := s.opsRepo.ListAlertRules(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsAlertEvaluator] list rules failed: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] list rules failed: %v", err) return } @@ -236,7 +236,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { activeEvent, err := s.opsRepo.GetActiveAlertEvent(ctx, rule.ID) if err != nil 
{ - log.Printf("[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get active event failed (rule=%d): %v", rule.ID, err) continue } @@ -258,7 +258,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { latestEvent, err := s.opsRepo.GetLatestAlertEvent(ctx, rule.ID) if err != nil { - log.Printf("[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] get latest event failed (rule=%d): %v", rule.ID, err) continue } if latestEvent != nil && rule.CooldownMinutes > 0 { @@ -283,7 +283,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { created, err := s.opsRepo.CreateAlertEvent(ctx, firedEvent) if err != nil { - log.Printf("[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] create event failed (rule=%d): %v", rule.ID, err) continue } @@ -300,7 +300,7 @@ func (s *OpsAlertEvaluatorService) evaluateOnce(interval time.Duration) { if activeEvent != nil { resolvedAt := now if err := s.opsRepo.UpdateAlertEventStatus(ctx, activeEvent.ID, OpsAlertStatusResolved, &resolvedAt); err != nil { - log.Printf("[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] resolve event failed (event=%d): %v", activeEvent.ID, err) } else { eventsResolved++ } @@ -779,7 +779,7 @@ func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc } if s.redisClient == nil { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] redis not configured; running without distributed lock") + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] redis not configured; running without distributed lock") }) return nil, true } @@ -797,7 +797,7 @@ 
func (s *OpsAlertEvaluatorService) tryAcquireLeaderLock(ctx context.Context, loc // Prefer fail-closed to avoid duplicate evaluators stampeding the DB when Redis is flaky. // Single-node deployments can disable the distributed lock via runtime settings. s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock SetNX failed; skipping this cycle: %v", err) }) return nil, false } @@ -819,7 +819,7 @@ func (s *OpsAlertEvaluatorService) maybeLogSkip(key string) { return } s.skipLogAt = now - log.Printf("[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) + logger.LegacyPrintf("service.ops_alert_evaluator", "[OpsAlertEvaluator] leader lock held by another instance; skipping (key=%q)", key) } func (s *OpsAlertEvaluatorService) recordHeartbeatSuccess(runAt time.Time, duration time.Duration, result string) { diff --git a/backend/internal/service/ops_alert_evaluator_service_test.go b/backend/internal/service/ops_alert_evaluator_service_test.go index 068ab6bb..83d358a3 100644 --- a/backend/internal/service/ops_alert_evaluator_service_test.go +++ b/backend/internal/service/ops_alert_evaluator_service_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" ) +var _ OpsRepository = (*stubOpsRepo)(nil) + type stubOpsRepo struct { OpsRepository overview *OpsDashboardOverview diff --git a/backend/internal/service/ops_cleanup_service.go b/backend/internal/service/ops_cleanup_service.go index 1ade7176..1cae6fe5 100644 --- a/backend/internal/service/ops_cleanup_service.go +++ b/backend/internal/service/ops_cleanup_service.go @@ -4,12 +4,12 @@ import ( "context" "database/sql" "fmt" - "log" "strings" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/google/uuid" "github.com/redis/go-redis/v9" "github.com/robfig/cron/v3" 
@@ -75,11 +75,11 @@ func (s *OpsCleanupService) Start() { return } if s.cfg != nil && !s.cfg.Ops.Cleanup.Enabled { - log.Printf("[OpsCleanup] not started (disabled)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (disabled)") return } if s.opsRepo == nil || s.db == nil { - log.Printf("[OpsCleanup] not started (missing deps)") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (missing deps)") return } @@ -99,12 +99,12 @@ func (s *OpsCleanupService) Start() { c := cron.New(cron.WithParser(opsCleanupCronParser), cron.WithLocation(loc)) _, err := c.AddFunc(schedule, func() { s.runScheduled() }) if err != nil { - log.Printf("[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] not started (invalid schedule=%q): %v", schedule, err) return } s.cron = c s.cron.Start() - log.Printf("[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] started (schedule=%q tz=%s)", schedule, loc.String()) }) } @@ -118,7 +118,7 @@ func (s *OpsCleanupService) Stop() { select { case <-ctx.Done(): case <-time.After(3 * time.Second): - log.Printf("[OpsCleanup] cron stop timed out") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cron stop timed out") } } }) @@ -146,17 +146,19 @@ func (s *OpsCleanupService) runScheduled() { counts, err := s.runCleanupOnce(ctx) if err != nil { s.recordHeartbeatError(runAt, time.Since(startedAt), err) - log.Printf("[OpsCleanup] cleanup failed: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup failed: %v", err) return } s.recordHeartbeatSuccess(runAt, time.Since(startedAt), counts) - log.Printf("[OpsCleanup] cleanup complete: %s", counts) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] cleanup complete: %s", counts) } type opsCleanupDeletedCounts struct { errorLogs int64 retryAttempts int64 alertEvents int64 + systemLogs int64 + 
logAudits int64 systemMetrics int64 hourlyPreagg int64 dailyPreagg int64 @@ -164,10 +166,12 @@ type opsCleanupDeletedCounts struct { func (c opsCleanupDeletedCounts) String() string { return fmt.Sprintf( - "error_logs=%d retry_attempts=%d alert_events=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", + "error_logs=%d retry_attempts=%d alert_events=%d system_logs=%d log_audits=%d system_metrics=%d hourly_preagg=%d daily_preagg=%d", c.errorLogs, c.retryAttempts, c.alertEvents, + c.systemLogs, + c.logAudits, c.systemMetrics, c.hourlyPreagg, c.dailyPreagg, @@ -204,6 +208,18 @@ func (s *OpsCleanupService) runCleanupOnce(ctx context.Context) (opsCleanupDelet return out, err } out.alertEvents = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_logs", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.systemLogs = n + + n, err = deleteOldRowsByID(ctx, s.db, "ops_system_log_cleanup_audits", "created_at", cutoff, batchSize, false) + if err != nil { + return out, err + } + out.logAudits = n } // Minute-level metrics snapshots. @@ -315,11 +331,11 @@ func (s *OpsCleanupService) tryAcquireLeaderLock(ctx context.Context) (func(), b } // Redis error: fall back to DB advisory lock. 
s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] leader lock SetNX failed; falling back to DB advisory lock: %v", err) }) } else { s.warnNoRedisOnce.Do(func() { - log.Printf("[OpsCleanup] redis not configured; using DB advisory lock") + logger.LegacyPrintf("service.ops_cleanup", "[OpsCleanup] redis not configured; using DB advisory lock") }) } diff --git a/backend/internal/service/ops_log_runtime.go b/backend/internal/service/ops_log_runtime.go new file mode 100644 index 00000000..ed8aefa9 --- /dev/null +++ b/backend/internal/service/ops_log_runtime.go @@ -0,0 +1,267 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "go.uber.org/zap" +) + +func defaultOpsRuntimeLogConfig(cfg *config.Config) *OpsRuntimeLogConfig { + out := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + if cfg == nil { + return out + } + out.Level = strings.ToLower(strings.TrimSpace(cfg.Log.Level)) + out.EnableSampling = cfg.Log.Sampling.Enabled + out.SamplingInitial = cfg.Log.Sampling.Initial + out.SamplingNext = cfg.Log.Sampling.Thereafter + out.Caller = cfg.Log.Caller + out.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) + if cfg.Ops.Cleanup.ErrorLogRetentionDays > 0 { + out.RetentionDays = cfg.Ops.Cleanup.ErrorLogRetentionDays + } + return out +} + +func normalizeOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig, defaults *OpsRuntimeLogConfig) { + if cfg == nil || defaults == nil { + return + } + cfg.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + if cfg.Level == "" { + cfg.Level = defaults.Level + } + cfg.StacktraceLevel = 
strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + if cfg.StacktraceLevel == "" { + cfg.StacktraceLevel = defaults.StacktraceLevel + } + if cfg.SamplingInitial <= 0 { + cfg.SamplingInitial = defaults.SamplingInitial + } + if cfg.SamplingNext <= 0 { + cfg.SamplingNext = defaults.SamplingNext + } + if cfg.RetentionDays <= 0 { + cfg.RetentionDays = defaults.RetentionDays + } +} + +func validateOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return errors.New("invalid config") + } + switch strings.ToLower(strings.TrimSpace(cfg.Level)) { + case "debug", "info", "warn", "error": + default: + return errors.New("level must be one of: debug/info/warn/error") + } + switch strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) { + case "none", "error", "fatal": + default: + return errors.New("stacktrace_level must be one of: none/error/fatal") + } + if cfg.SamplingInitial <= 0 { + return errors.New("sampling_initial must be positive") + } + if cfg.SamplingNext <= 0 { + return errors.New("sampling_thereafter must be positive") + } + if cfg.RetentionDays < 1 || cfg.RetentionDays > 3650 { + return errors.New("retention_days must be between 1 and 3650") + } + return nil +} + +func (s *OpsService) GetRuntimeLogConfig(ctx context.Context) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + defaultCfg := defaultOpsRuntimeLogConfig(cfg) + return defaultCfg, nil + } + defaultCfg := defaultOpsRuntimeLogConfig(s.cfg) + if ctx == nil { + ctx = context.Background() + } + + raw, err := s.settingRepo.GetValue(ctx, SettingKeyOpsRuntimeLogConfig) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + b, _ := json.Marshal(defaultCfg) + _ = s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(b)) + return defaultCfg, nil + } + return nil, err + } + + cfg := &OpsRuntimeLogConfig{} + if err := json.Unmarshal([]byte(raw), cfg); err != nil { + return defaultCfg, nil + } 
+ normalizeOpsRuntimeLogConfig(cfg, defaultCfg) + return cfg, nil +} + +func (s *OpsService) UpdateRuntimeLogConfig(ctx context.Context, req *OpsRuntimeLogConfig, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if req == nil { + return nil, errors.New("invalid config") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + next := *req + normalizeOpsRuntimeLogConfig(&next, defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "validation_failed: "+err.Error()) + return nil, err + } + + if err := applyOpsRuntimeLogConfig(&next); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "apply_failed: "+err.Error()) + return nil, err + } + + next.Source = "runtime_setting" + next.UpdatedAt = time.Now().UTC().Format(time.RFC3339Nano) + next.UpdatedByUserID = operatorID + + encoded, err := json.Marshal(&next) + if err != nil { + return nil, err + } + if err := s.settingRepo.Set(ctx, SettingKeyOpsRuntimeLogConfig, string(encoded)); err != nil { + // 存储失败时回滚到旧配置,避免内存状态与持久化状态不一致。 + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, &next, "persist_failed: "+err.Error()) + return nil, err + } + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, &next, "updated") + + return &next, nil +} + +func (s *OpsService) ResetRuntimeLogConfig(ctx context.Context, operatorID int64) (*OpsRuntimeLogConfig, error) { + if s == nil || s.settingRepo == nil { + return nil, errors.New("setting repository not initialized") + } + if ctx == nil { + ctx = context.Background() + } + if operatorID <= 0 { + return nil, errors.New("invalid operator id") + } + + oldCfg, err := 
s.GetRuntimeLogConfig(ctx) + if err != nil { + return nil, err + } + + resetCfg := defaultOpsRuntimeLogConfig(s.cfg) + normalizeOpsRuntimeLogConfig(resetCfg, defaultOpsRuntimeLogConfig(s.cfg)) + if err := validateOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_validation_failed: "+err.Error()) + return nil, err + } + if err := applyOpsRuntimeLogConfig(resetCfg); err != nil { + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_apply_failed: "+err.Error()) + return nil, err + } + + // 清理 runtime 覆盖配置,回退到 env/yaml baseline。 + if err := s.settingRepo.Delete(ctx, SettingKeyOpsRuntimeLogConfig); err != nil && !errors.Is(err, ErrSettingNotFound) { + _ = applyOpsRuntimeLogConfig(oldCfg) + s.auditRuntimeLogConfigFailure(operatorID, oldCfg, resetCfg, "reset_persist_failed: "+err.Error()) + return nil, err + } + + now := time.Now().UTC().Format(time.RFC3339Nano) + resetCfg.Source = "baseline" + resetCfg.UpdatedAt = now + resetCfg.UpdatedByUserID = operatorID + + s.auditRuntimeLogConfigChange(operatorID, oldCfg, resetCfg, "reset") + return resetCfg, nil +} + +func applyOpsRuntimeLogConfig(cfg *OpsRuntimeLogConfig) error { + if cfg == nil { + return fmt.Errorf("nil runtime log config") + } + if err := logger.Reconfigure(func(opts *logger.InitOptions) error { + opts.Level = strings.ToLower(strings.TrimSpace(cfg.Level)) + opts.Caller = cfg.Caller + opts.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.StacktraceLevel)) + opts.Sampling.Enabled = cfg.EnableSampling + opts.Sampling.Initial = cfg.SamplingInitial + opts.Sampling.Thereafter = cfg.SamplingNext + return nil + }); err != nil { + return err + } + return nil +} + +func (s *OpsService) applyRuntimeLogConfigOnStartup(ctx context.Context) { + if s == nil { + return + } + cfg, err := s.GetRuntimeLogConfig(ctx) + if err != nil { + return + } + _ = applyOpsRuntimeLogConfig(cfg) +} + +func (s *OpsService) 
auditRuntimeLogConfigChange(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, action string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", strings.TrimSpace(action)), + zap.Int64("operator_id", operatorID), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Info("runtime log config changed") +} + +func (s *OpsService) auditRuntimeLogConfigFailure(operatorID int64, oldCfg *OpsRuntimeLogConfig, newCfg *OpsRuntimeLogConfig, reason string) { + oldRaw, _ := json.Marshal(oldCfg) + newRaw, _ := json.Marshal(newCfg) + logger.With( + zap.String("component", "audit.log_config_change"), + zap.String("action", "failed"), + zap.Int64("operator_id", operatorID), + zap.String("reason", strings.TrimSpace(reason)), + zap.String("old", string(oldRaw)), + zap.String("new", string(newRaw)), + ).Warn("runtime log config change failed") +} diff --git a/backend/internal/service/ops_log_runtime_test.go b/backend/internal/service/ops_log_runtime_test.go new file mode 100644 index 00000000..658b4812 --- /dev/null +++ b/backend/internal/service/ops_log_runtime_test.go @@ -0,0 +1,570 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +type runtimeSettingRepoStub struct { + values map[string]string + deleted map[string]bool + setCalls int + getValueFn func(key string) (string, error) + setFn func(key, value string) error + deleteFn func(key string) error +} + +func newRuntimeSettingRepoStub() *runtimeSettingRepoStub { + return &runtimeSettingRepoStub{ + values: map[string]string{}, + deleted: map[string]bool{}, + } +} + +func (s *runtimeSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } 
+ return &Setting{Key: key, Value: value}, nil +} + +func (s *runtimeSettingRepoStub) GetValue(_ context.Context, key string) (string, error) { + if s.getValueFn != nil { + return s.getValueFn(key) + } + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *runtimeSettingRepoStub) Set(_ context.Context, key, value string) error { + if s.setFn != nil { + if err := s.setFn(key, value); err != nil { + return err + } + } + s.values[key] = value + s.setCalls++ + return nil +} + +func (s *runtimeSettingRepoStub) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *runtimeSettingRepoStub) SetMultiple(_ context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *runtimeSettingRepoStub) GetAll(_ context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *runtimeSettingRepoStub) Delete(_ context.Context, key string) error { + if s.deleteFn != nil { + if err := s.deleteFn(key); err != nil { + return err + } + } + if _, ok := s.values[key]; !ok { + return ErrSettingNotFound + } + delete(s.values, key) + s.deleted[key] = true + return nil +} + +func TestUpdateRuntimeLogConfig_InvalidConfigShouldNotApply(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + 
Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "trace", + EnableSampling: true, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 1) + if err == nil { + t.Fatalf("expected validation error") + } + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level changed unexpectedly: %s", logger.CurrentLevel()) + } + if repo.setCalls != 1 { + // GetRuntimeLogConfig() 会在 key 缺失时写入默认值,此处应只有这一次持久化。 + t.Fatalf("unexpected set calls: %d", repo.setCalls) + } +} + +func TestResetRuntimeLogConfig_ShouldFallbackToBaseline(t *testing.T) { + repo := newRuntimeSettingRepoStub() + existing := &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: true, + SamplingInitial: 50, + SamplingNext: 50, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 60, + Source: "runtime_setting", + } + raw, _ := json.Marshal(existing) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: false, + StacktraceLevel: "fatal", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 45, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "debug", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + resetCfg, err := svc.ResetRuntimeLogConfig(context.Background(), 9) + if err != nil { + t.Fatalf("ResetRuntimeLogConfig() error: %v", err) + } + if resetCfg.Source != "baseline" { + t.Fatalf("source = %q, want baseline", 
resetCfg.Source) + } + if resetCfg.Level != "warn" { + t.Fatalf("level = %q, want warn", resetCfg.Level) + } + if resetCfg.RetentionDays != 45 { + t.Fatalf("retention_days = %d, want 45", resetCfg.RetentionDays) + } + if logger.CurrentLevel() != "warn" { + t.Fatalf("logger level = %q, want warn", logger.CurrentLevel()) + } + if !repo.deleted[SettingKeyOpsRuntimeLogConfig] { + t.Fatalf("runtime setting key should be deleted") + } +} + +func TestResetRuntimeLogConfig_InvalidOperator(t *testing.T) { + svc := &OpsService{settingRepo: newRuntimeSettingRepoStub()} + _, err := svc.ResetRuntimeLogConfig(context.Background(), 0) + if err == nil { + t.Fatalf("expected invalid operator error") + } + if err.Error() != "invalid operator id" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestGetRuntimeLogConfig_InvalidJSONFallback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.values[SettingKeyOpsRuntimeLogConfig] = `{invalid-json}` + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + got, err := svc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("GetRuntimeLogConfig() error: %v", err) + } + if got.Level != "warn" { + t.Fatalf("level = %q, want warn", got.Level) + } +} + +func TestUpdateRuntimeLogConfig_PersistFailureRollback(t *testing.T) { + repo := newRuntimeSettingRepoStub() + oldCfg := &OpsRuntimeLogConfig{ + Level: "info", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + } + raw, _ := json.Marshal(oldCfg) + repo.values[SettingKeyOpsRuntimeLogConfig] = string(raw) + repo.setFn = func(key, value string) error { + if key == SettingKeyOpsRuntimeLogConfig { + return errors.New("db down") + } + return nil + } + + svc := &OpsService{ + 
settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 5) + if err == nil { + t.Fatalf("expected persist error") + } + // Persist failure should rollback runtime level back to old effective level. + if logger.CurrentLevel() != "info" { + t.Fatalf("logger level should rollback to info, got %s", logger.CurrentLevel()) + } +} + +func TestApplyRuntimeLogConfigOnStartup(t *testing.T) { + repo := newRuntimeSettingRepoStub() + cfgRaw := `{"level":"debug","enable_sampling":false,"sampling_initial":100,"sampling_thereafter":100,"caller":true,"stacktrace_level":"error","retention_days":30}` + repo.values[SettingKeyOpsRuntimeLogConfig] = cfgRaw + + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + svc.applyRuntimeLogConfigOnStartup(context.Background()) + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected startup apply debug, got 
%s", logger.CurrentLevel()) + } +} + +func TestDefaultNormalizeAndValidateRuntimeLogConfig(t *testing.T) { + defaults := defaultOpsRuntimeLogConfig(&config.Config{ + Log: config.LogConfig{ + Level: "DEBUG", + Caller: false, + StacktraceLevel: "FATAL", + Sampling: config.LogSamplingConfig{ + Enabled: true, + Initial: 50, + Thereafter: 20, + }, + }, + Ops: config.OpsConfig{ + Cleanup: config.OpsCleanupConfig{ + ErrorLogRetentionDays: 7, + }, + }, + }) + if defaults.Level != "debug" || defaults.StacktraceLevel != "fatal" || defaults.RetentionDays != 7 { + t.Fatalf("unexpected defaults: %+v", defaults) + } + + cfg := &OpsRuntimeLogConfig{ + Level: " ", + EnableSampling: true, + SamplingInitial: 0, + SamplingNext: -1, + Caller: true, + StacktraceLevel: "", + RetentionDays: 0, + } + normalizeOpsRuntimeLogConfig(cfg, defaults) + if cfg.Level != "debug" || cfg.StacktraceLevel != "fatal" { + t.Fatalf("normalize level/stacktrace failed: %+v", cfg) + } + if cfg.SamplingInitial != 50 || cfg.SamplingNext != 20 || cfg.RetentionDays != 7 { + t.Fatalf("normalize numeric defaults failed: %+v", cfg) + } + if err := validateOpsRuntimeLogConfig(cfg); err != nil { + t.Fatalf("validate normalized config should pass: %v", err) + } +} + +func TestValidateRuntimeLogConfigErrors(t *testing.T) { + cases := []struct { + name string + cfg *OpsRuntimeLogConfig + }{ + {name: "nil", cfg: nil}, + {name: "bad level", cfg: &OpsRuntimeLogConfig{Level: "trace", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad stack", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "warn", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad initial", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 0, SamplingNext: 1, RetentionDays: 1}}, + {name: "bad next", cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 0, RetentionDays: 1}}, + {name: "bad retention", 
cfg: &OpsRuntimeLogConfig{Level: "info", StacktraceLevel: "error", SamplingInitial: 1, SamplingNext: 1, RetentionDays: 0}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + if err := validateOpsRuntimeLogConfig(tc.cfg); err == nil { + t.Fatalf("expected validation error") + } + }) + } +} + +func TestGetRuntimeLogConfigFallbackAndErrors(t *testing.T) { + var nilSvc *OpsService + cfg, err := nilSvc.GetRuntimeLogConfig(context.Background()) + if err != nil { + t.Fatalf("nil svc should fallback default: %v", err) + } + if cfg.Level != "info" { + t.Fatalf("unexpected nil svc default level: %s", cfg.Level) + } + + repo := newRuntimeSettingRepoStub() + repo.getValueFn = func(key string) (string, error) { + return "", errors.New("boom") + } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "warn", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.GetRuntimeLogConfig(context.Background()); err == nil { + t.Fatalf("expected get value error") + } +} + +func TestUpdateRuntimeLogConfig_PreconditionErrors(t *testing.T) { + svc := &OpsService{} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{}, 1); err == nil { + t.Fatalf("expected setting repo not initialized") + } + + svc = &OpsService{settingRepo: newRuntimeSettingRepoStub()} + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), nil, 1); err == nil { + t.Fatalf("expected invalid config") + } + if _, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "info", + StacktraceLevel: "error", + SamplingInitial: 1, + SamplingNext: 1, + RetentionDays: 1, + }, 0); err == nil { + t.Fatalf("expected invalid operator") + } +} + +func TestUpdateRuntimeLogConfig_Success(t *testing.T) { + repo := newRuntimeSettingRepoStub() + svc := &OpsService{ + settingRepo: repo, + 
cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + + if err := logger.Init(logger.InitOptions{ + Level: "info", + Format: "json", + ServiceName: "sub2api", + Environment: "test", + Output: logger.OutputOptions{ + ToStdout: true, + ToFile: false, + }, + }); err != nil { + t.Fatalf("init logger: %v", err) + } + + next, err := svc.UpdateRuntimeLogConfig(context.Background(), &OpsRuntimeLogConfig{ + Level: "debug", + EnableSampling: false, + SamplingInitial: 100, + SamplingNext: 100, + Caller: true, + StacktraceLevel: "error", + RetentionDays: 30, + }, 2) + if err != nil { + t.Fatalf("UpdateRuntimeLogConfig() error: %v", err) + } + if next.Source != "runtime_setting" || next.UpdatedByUserID != 2 || next.UpdatedAt == "" { + t.Fatalf("unexpected metadata: %+v", next) + } + if logger.CurrentLevel() != "debug" { + t.Fatalf("expected applied level debug, got %s", logger.CurrentLevel()) + } +} + +func TestResetRuntimeLogConfig_IgnoreNotFoundDelete(t *testing.T) { + repo := newRuntimeSettingRepoStub() + repo.deleteFn = func(key string) error { return ErrSettingNotFound } + svc := &OpsService{ + settingRepo: repo, + cfg: &config.Config{ + Log: config.LogConfig{ + Level: "info", + Caller: true, + StacktraceLevel: "error", + Sampling: config.LogSamplingConfig{ + Enabled: false, + Initial: 100, + Thereafter: 100, + }, + }, + }, + } + if _, err := svc.ResetRuntimeLogConfig(context.Background(), 1); err != nil { + t.Fatalf("reset should ignore ErrSettingNotFound: %v", err) + } +} + +func TestApplyRuntimeLogConfigHelpers(t *testing.T) { + if err := applyOpsRuntimeLogConfig(nil); err == nil { + t.Fatalf("expected nil config error") + } + + normalizeOpsRuntimeLogConfig(nil, &OpsRuntimeLogConfig{Level: "info"}) + normalizeOpsRuntimeLogConfig(&OpsRuntimeLogConfig{Level: "debug"}, nil) + + var nilSvc *OpsService + 
nilSvc.applyRuntimeLogConfigOnStartup(context.Background()) +} diff --git a/backend/internal/service/ops_models.go b/backend/internal/service/ops_models.go index 347cd52b..2ed06d90 100644 --- a/backend/internal/service/ops_models.go +++ b/backend/internal/service/ops_models.go @@ -2,6 +2,21 @@ package service import "time" +type OpsSystemLog struct { + ID int64 `json:"id"` + CreatedAt time.Time `json:"created_at"` + Level string `json:"level"` + Component string `json:"component"` + Message string `json:"message"` + RequestID string `json:"request_id"` + ClientRequestID string `json:"client_request_id"` + UserID *int64 `json:"user_id"` + AccountID *int64 `json:"account_id"` + Platform string `json:"platform"` + Model string `json:"model"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsErrorLog struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` diff --git a/backend/internal/service/ops_openai_token_stats.go b/backend/internal/service/ops_openai_token_stats.go new file mode 100644 index 00000000..63f88ba0 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats.go @@ -0,0 +1,55 @@ +package service + +import ( + "context" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return nil, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if filter == nil { + return nil, infraerrors.BadRequest("OPS_FILTER_REQUIRED", "filter is required") + } + if filter.StartTime.IsZero() || filter.EndTime.IsZero() { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_REQUIRED", "start_time/end_time are required") + } + if filter.StartTime.After(filter.EndTime) { + return nil, infraerrors.BadRequest("OPS_TIME_RANGE_INVALID", "start_time must be 
<= end_time") + } + + if filter.GroupID != nil && *filter.GroupID <= 0 { + return nil, infraerrors.BadRequest("OPS_GROUP_ID_INVALID", "group_id must be > 0") + } + + // top_n cannot be mixed with page/page_size params. + if filter.TopN > 0 && (filter.Page > 0 || filter.PageSize > 0) { + return nil, infraerrors.BadRequest("OPS_PAGINATION_CONFLICT", "top_n cannot be used with page/page_size") + } + + if filter.TopN > 0 { + if filter.TopN < 1 || filter.TopN > 100 { + return nil, infraerrors.BadRequest("OPS_TOPN_INVALID", "top_n must be between 1 and 100") + } + } else { + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 20 + } + if filter.Page < 1 { + return nil, infraerrors.BadRequest("OPS_PAGE_INVALID", "page must be >= 1") + } + if filter.PageSize < 1 || filter.PageSize > 100 { + return nil, infraerrors.BadRequest("OPS_PAGE_SIZE_INVALID", "page_size must be between 1 and 100") + } + } + + return s.opsRepo.GetOpenAITokenStats(ctx, filter) +} diff --git a/backend/internal/service/ops_openai_token_stats_models.go b/backend/internal/service/ops_openai_token_stats_models.go new file mode 100644 index 00000000..ef40fa1f --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_models.go @@ -0,0 +1,54 @@ +package service + +import "time" + +type OpsOpenAITokenStatsFilter struct { + TimeRange string + StartTime time.Time + EndTime time.Time + + Platform string + GroupID *int64 + + // Pagination mode (default): page/page_size + Page int + PageSize int + + // TopN mode: top_n + TopN int +} + +func (f *OpsOpenAITokenStatsFilter) IsTopNMode() bool { + return f != nil && f.TopN > 0 +} + +type OpsOpenAITokenStatsItem struct { + Model string `json:"model"` + RequestCount int64 `json:"request_count"` + AvgTokensPerSec *float64 `json:"avg_tokens_per_sec"` + AvgFirstTokenMs *float64 `json:"avg_first_token_ms"` + TotalOutputTokens int64 `json:"total_output_tokens"` + AvgDurationMs int64 `json:"avg_duration_ms"` + 
RequestsWithFirstToken int64 `json:"requests_with_first_token"` +} + +type OpsOpenAITokenStatsResponse struct { + TimeRange string `json:"time_range"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + + Platform string `json:"platform,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + + Items []*OpsOpenAITokenStatsItem `json:"items"` + + // Total model rows before pagination/topN trimming. + Total int64 `json:"total"` + + // Pagination mode metadata. + Page int `json:"page,omitempty"` + PageSize int `json:"page_size,omitempty"` + + // TopN mode metadata. + TopN *int `json:"top_n,omitempty"` +} diff --git a/backend/internal/service/ops_openai_token_stats_test.go b/backend/internal/service/ops_openai_token_stats_test.go new file mode 100644 index 00000000..ee332f91 --- /dev/null +++ b/backend/internal/service/ops_openai_token_stats_test.go @@ -0,0 +1,162 @@ +package service + +import ( + "context" + "testing" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type openAITokenStatsRepoStub struct { + OpsRepository + resp *OpsOpenAITokenStatsResponse + err error + captured *OpsOpenAITokenStatsFilter +} + +func (s *openAITokenStatsRepoStub) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + s.captured = filter + if s.err != nil { + return nil, s.err + } + if s.resp != nil { + return s.resp, nil + } + return &OpsOpenAITokenStatsResponse{}, nil +} + +func TestOpsServiceGetOpenAITokenStats_Validation(t *testing.T) { + now := time.Now().UTC() + + tests := []struct { + name string + filter *OpsOpenAITokenStatsFilter + wantCode int + wantReason string + }{ + { + name: "filter 不能为空", + filter: nil, + wantCode: 400, + wantReason: "OPS_FILTER_REQUIRED", + }, + { + name: "start_time/end_time 必填", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: time.Time{}, + EndTime: now, + }, + wantCode: 400, + 
wantReason: "OPS_TIME_RANGE_REQUIRED", + }, + { + name: "start_time 不能晚于 end_time", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now, + EndTime: now.Add(-1 * time.Minute), + }, + wantCode: 400, + wantReason: "OPS_TIME_RANGE_INVALID", + }, + { + name: "group_id 必须大于 0", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + GroupID: int64Ptr(0), + }, + wantCode: 400, + wantReason: "OPS_GROUP_ID_INVALID", + }, + { + name: "top_n 与分页参数互斥", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + Page: 1, + }, + wantCode: 400, + wantReason: "OPS_PAGINATION_CONFLICT", + }, + { + name: "top_n 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 101, + }, + wantCode: 400, + wantReason: "OPS_TOPN_INVALID", + }, + { + name: "page_size 参数越界", + filter: &OpsOpenAITokenStatsFilter{ + StartTime: now.Add(-time.Hour), + EndTime: now, + Page: 1, + PageSize: 101, + }, + wantCode: 400, + wantReason: "OPS_PAGE_SIZE_INVALID", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + svc := &OpsService{ + opsRepo: &openAITokenStatsRepoStub{}, + } + + _, err := svc.GetOpenAITokenStats(context.Background(), tt.filter) + require.Error(t, err) + require.Equal(t, tt.wantCode, infraerrors.Code(err)) + require.Equal(t, tt.wantReason, infraerrors.Reason(err)) + }) + } +} + +func TestOpsServiceGetOpenAITokenStats_DefaultPagination(t *testing.T) { + now := time.Now().UTC() + repo := &openAITokenStatsRepoStub{ + resp: &OpsOpenAITokenStatsResponse{ + Items: []*OpsOpenAITokenStatsItem{ + {Model: "gpt-4o-mini", RequestCount: 10}, + }, + Total: 1, + }, + } + svc := &OpsService{opsRepo: repo} + + filter := &OpsOpenAITokenStatsFilter{ + TimeRange: "30d", + StartTime: now.Add(-30 * 24 * time.Hour), + EndTime: now, + } + resp, err := svc.GetOpenAITokenStats(context.Background(), filter) + require.NoError(t, err) + require.NotNil(t, resp) + 
require.NotNil(t, repo.captured) + require.Equal(t, 1, repo.captured.Page) + require.Equal(t, 20, repo.captured.PageSize) + require.Equal(t, 0, repo.captured.TopN) +} + +func TestOpsServiceGetOpenAITokenStats_RepoUnavailable(t *testing.T) { + now := time.Now().UTC() + svc := &OpsService{} + + _, err := svc.GetOpenAITokenStats(context.Background(), &OpsOpenAITokenStatsFilter{ + TimeRange: "1h", + StartTime: now.Add(-time.Hour), + EndTime: now, + TopN: 10, + }) + require.Error(t, err) + require.Equal(t, 503, infraerrors.Code(err)) + require.Equal(t, "OPS_REPO_UNAVAILABLE", infraerrors.Reason(err)) +} + +func int64Ptr(v int64) *int64 { return &v } diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 347b06b5..f3633eae 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -10,6 +10,10 @@ type OpsRepository interface { ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) + BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error @@ -27,6 +31,7 @@ type OpsRepository interface { GetLatencyHistogram(ctx context.Context, filter *OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) 
GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) + GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) @@ -98,6 +103,10 @@ type OpsInsertErrorLogInput struct { // It is set by OpsService.RecordError before persisting. UpstreamErrorsJSON *string + AuthLatencyMs *int64 + RoutingLatencyMs *int64 + UpstreamLatencyMs *int64 + ResponseLatencyMs *int64 TimeToFirstTokenMs *int64 RequestBodyJSON *string // sanitized json string (not raw bytes) @@ -200,6 +209,69 @@ type OpsInsertSystemMetricsInput struct { ConcurrencyQueueDepth *int } +type OpsInsertSystemLogInput struct { + CreatedAt time.Time + Level string + Component string + Message string + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + ExtraJSON string +} + +type OpsSystemLogFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string + + Page int + PageSize int +} + +type OpsSystemLogCleanupFilter struct { + StartTime *time.Time + EndTime *time.Time + + Level string + Component string + + RequestID string + ClientRequestID string + UserID *int64 + AccountID *int64 + Platform string + Model string + Query string +} + +type OpsSystemLogList struct { + Logs []*OpsSystemLog `json:"logs"` + Total int `json:"total"` + Page int `json:"page"` + PageSize int `json:"page_size"` +} + +type OpsSystemLogCleanupAudit struct { + CreatedAt time.Time + OperatorID int64 + Conditions string + DeletedRows int64 +} + type OpsSystemMetricsSnapshot struct { ID int64 `json:"id"` CreatedAt time.Time `json:"created_at"` 
diff --git a/backend/internal/service/ops_repo_mock_test.go b/backend/internal/service/ops_repo_mock_test.go new file mode 100644 index 00000000..e250dea3 --- /dev/null +++ b/backend/internal/service/ops_repo_mock_test.go @@ -0,0 +1,196 @@ +package service + +import ( + "context" + "time" +) + +// opsRepoMock is a test-only OpsRepository implementation with optional function hooks. +type opsRepoMock struct { + BatchInsertSystemLogsFn func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) + ListSystemLogsFn func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) + DeleteSystemLogsFn func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) + InsertSystemLogCleanupAuditFn func(ctx context.Context, input *OpsSystemLogCleanupAudit) error +} + +func (m *opsRepoMock) InsertErrorLog(ctx context.Context, input *OpsInsertErrorLogInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) ListErrorLogs(ctx context.Context, filter *OpsErrorLogFilter) (*OpsErrorLogList, error) { + return &OpsErrorLogList{Errors: []*OpsErrorLog{}, Page: 1, PageSize: 20}, nil +} + +func (m *opsRepoMock) GetErrorLogByID(ctx context.Context, id int64) (*OpsErrorLogDetail, error) { + return &OpsErrorLogDetail{}, nil +} + +func (m *opsRepoMock) ListRequestDetails(ctx context.Context, filter *OpsRequestDetailFilter) ([]*OpsRequestDetail, int64, error) { + return []*OpsRequestDetail{}, 0, nil +} + +func (m *opsRepoMock) BatchInsertSystemLogs(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if m.BatchInsertSystemLogsFn != nil { + return m.BatchInsertSystemLogsFn(ctx, inputs) + } + return int64(len(inputs)), nil +} + +func (m *opsRepoMock) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if m.ListSystemLogsFn != nil { + return m.ListSystemLogsFn(ctx, filter) + } + return &OpsSystemLogList{Logs: []*OpsSystemLog{}, Total: 0, Page: 1, PageSize: 50}, nil +} + +func 
(m *opsRepoMock) DeleteSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + if m.DeleteSystemLogsFn != nil { + return m.DeleteSystemLogsFn(ctx, filter) + } + return 0, nil +} + +func (m *opsRepoMock) InsertSystemLogCleanupAudit(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + if m.InsertSystemLogCleanupAuditFn != nil { + return m.InsertSystemLogCleanupAuditFn(ctx, input) + } + return nil +} + +func (m *opsRepoMock) InsertRetryAttempt(ctx context.Context, input *OpsInsertRetryAttemptInput) (int64, error) { + return 0, nil +} + +func (m *opsRepoMock) UpdateRetryAttempt(ctx context.Context, input *OpsUpdateRetryAttemptInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestRetryAttemptForError(ctx context.Context, sourceErrorID int64) (*OpsRetryAttempt, error) { + return nil, nil +} + +func (m *opsRepoMock) ListRetryAttemptsByErrorID(ctx context.Context, sourceErrorID int64, limit int) ([]*OpsRetryAttempt, error) { + return []*OpsRetryAttempt{}, nil +} + +func (m *opsRepoMock) UpdateErrorResolution(ctx context.Context, errorID int64, resolved bool, resolvedByUserID *int64, resolvedRetryID *int64, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) GetWindowStats(ctx context.Context, filter *OpsDashboardFilter) (*OpsWindowStats, error) { + return &OpsWindowStats{}, nil +} + +func (m *opsRepoMock) GetRealtimeTrafficSummary(ctx context.Context, filter *OpsDashboardFilter) (*OpsRealtimeTrafficSummary, error) { + return &OpsRealtimeTrafficSummary{}, nil +} + +func (m *opsRepoMock) GetDashboardOverview(ctx context.Context, filter *OpsDashboardFilter) (*OpsDashboardOverview, error) { + return &OpsDashboardOverview{}, nil +} + +func (m *opsRepoMock) GetThroughputTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsThroughputTrendResponse, error) { + return &OpsThroughputTrendResponse{}, nil +} + +func (m *opsRepoMock) GetLatencyHistogram(ctx context.Context, filter 
*OpsDashboardFilter) (*OpsLatencyHistogramResponse, error) { + return &OpsLatencyHistogramResponse{}, nil +} + +func (m *opsRepoMock) GetErrorTrend(ctx context.Context, filter *OpsDashboardFilter, bucketSeconds int) (*OpsErrorTrendResponse, error) { + return &OpsErrorTrendResponse{}, nil +} + +func (m *opsRepoMock) GetErrorDistribution(ctx context.Context, filter *OpsDashboardFilter) (*OpsErrorDistributionResponse, error) { + return &OpsErrorDistributionResponse{}, nil +} + +func (m *opsRepoMock) GetOpenAITokenStats(ctx context.Context, filter *OpsOpenAITokenStatsFilter) (*OpsOpenAITokenStatsResponse, error) { + return &OpsOpenAITokenStatsResponse{}, nil +} + +func (m *opsRepoMock) InsertSystemMetrics(ctx context.Context, input *OpsInsertSystemMetricsInput) error { + return nil +} + +func (m *opsRepoMock) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*OpsSystemMetricsSnapshot, error) { + return &OpsSystemMetricsSnapshot{}, nil +} + +func (m *opsRepoMock) UpsertJobHeartbeat(ctx context.Context, input *OpsUpsertJobHeartbeatInput) error { + return nil +} + +func (m *opsRepoMock) ListJobHeartbeats(ctx context.Context) ([]*OpsJobHeartbeat, error) { + return []*OpsJobHeartbeat{}, nil +} + +func (m *opsRepoMock) ListAlertRules(ctx context.Context) ([]*OpsAlertRule, error) { + return []*OpsAlertRule{}, nil +} + +func (m *opsRepoMock) CreateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) UpdateAlertRule(ctx context.Context, input *OpsAlertRule) (*OpsAlertRule, error) { + return input, nil +} + +func (m *opsRepoMock) DeleteAlertRule(ctx context.Context, id int64) error { + return nil +} + +func (m *opsRepoMock) ListAlertEvents(ctx context.Context, filter *OpsAlertEventFilter) ([]*OpsAlertEvent, error) { + return []*OpsAlertEvent{}, nil +} + +func (m *opsRepoMock) GetAlertEventByID(ctx context.Context, eventID int64) (*OpsAlertEvent, error) { + return &OpsAlertEvent{}, nil +} + 
+func (m *opsRepoMock) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*OpsAlertEvent, error) { + return nil, nil +} + +func (m *opsRepoMock) CreateAlertEvent(ctx context.Context, event *OpsAlertEvent) (*OpsAlertEvent, error) { + return event, nil +} + +func (m *opsRepoMock) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error { + return nil +} + +func (m *opsRepoMock) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error { + return nil +} + +func (m *opsRepoMock) CreateAlertSilence(ctx context.Context, input *OpsAlertSilence) (*OpsAlertSilence, error) { + return input, nil +} + +func (m *opsRepoMock) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) { + return false, nil +} + +func (m *opsRepoMock) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error { + return nil +} + +func (m *opsRepoMock) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +func (m *opsRepoMock) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) { + return time.Time{}, false, nil +} + +var _ OpsRepository = (*opsRepoMock)(nil) diff --git a/backend/internal/service/ops_service.go b/backend/internal/service/ops_service.go index 9c121b8b..767d1704 100644 --- a/backend/internal/service/ops_service.go +++ b/backend/internal/service/ops_service.go @@ -20,6 +20,22 @@ const ( opsMaxStoredErrorBodyBytes = 20 * 1024 ) +// PrepareOpsRequestBodyForQueue 在入队前对请求体执行脱敏与裁剪,返回可直接写入 OpsInsertErrorLogInput 的字段。 +// 该方法用于避免异步队列持有大块原始请求体,减少错误风暴下的内存放大风险。 +func PrepareOpsRequestBodyForQueue(raw []byte) 
(requestBodyJSON *string, truncated bool, requestBodyBytes *int) { + if len(raw) == 0 { + return nil, false, nil + } + sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(raw, opsMaxStoredRequestBodyBytes) + if sanitized != "" { + out := sanitized + requestBodyJSON = &out + } + n := bytesLen + requestBodyBytes = &n + return requestBodyJSON, truncated, requestBodyBytes +} + // OpsService provides ingestion and query APIs for the Ops monitoring module. type OpsService struct { opsRepo OpsRepository @@ -37,6 +53,7 @@ type OpsService struct { openAIGatewayService *OpenAIGatewayService geminiCompatService *GeminiMessagesCompatService antigravityGatewayService *AntigravityGatewayService + systemLogSink *OpsSystemLogSink } func NewOpsService( @@ -50,8 +67,9 @@ func NewOpsService( openAIGatewayService *OpenAIGatewayService, geminiCompatService *GeminiMessagesCompatService, antigravityGatewayService *AntigravityGatewayService, + systemLogSink *OpsSystemLogSink, ) *OpsService { - return &OpsService{ + svc := &OpsService{ opsRepo: opsRepo, settingRepo: settingRepo, cfg: cfg, @@ -64,7 +82,10 @@ func NewOpsService( openAIGatewayService: openAIGatewayService, geminiCompatService: geminiCompatService, antigravityGatewayService: antigravityGatewayService, + systemLogSink: systemLogSink, } + svc.applyRuntimeLogConfigOnStartup(context.Background()) + return svc } func (s *OpsService) RequireMonitoringEnabled(ctx context.Context) error { @@ -127,12 +148,7 @@ func (s *OpsService) RecordError(ctx context.Context, entry *OpsInsertErrorLogIn // Sanitize + trim request body (errors only). 
if len(rawRequestBody) > 0 { - sanitized, truncated, bytesLen := sanitizeAndTrimRequestBody(rawRequestBody, opsMaxStoredRequestBodyBytes) - if sanitized != "" { - entry.RequestBodyJSON = &sanitized - } - entry.RequestBodyTruncated = truncated - entry.RequestBodyBytes = &bytesLen + entry.RequestBodyJSON, entry.RequestBodyTruncated, entry.RequestBodyBytes = PrepareOpsRequestBodyForQueue(rawRequestBody) } // Sanitize + truncate error_body to avoid storing sensitive data. diff --git a/backend/internal/service/ops_service_prepare_queue_test.go b/backend/internal/service/ops_service_prepare_queue_test.go new file mode 100644 index 00000000..d6f32c2d --- /dev/null +++ b/backend/internal/service/ops_service_prepare_queue_test.go @@ -0,0 +1,60 @@ +package service + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrepareOpsRequestBodyForQueue_EmptyBody(t *testing.T) { + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(nil) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.Nil(t, requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_InvalidJSON(t *testing.T) { + raw := []byte("{invalid-json") + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.Nil(t, requestBodyJSON) + require.False(t, truncated) + require.NotNil(t, requestBodyBytes) + require.Equal(t, len(raw), *requestBodyBytes) +} + +func TestPrepareOpsRequestBodyForQueue_RedactSensitiveFields(t *testing.T) { + raw := []byte(`{ + "model":"claude-3-5-sonnet-20241022", + "api_key":"sk-test-123", + "headers":{"authorization":"Bearer secret-token"}, + "messages":[{"role":"user","content":"hello"}] + }`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.False(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + + var body 
map[string]any + require.NoError(t, json.Unmarshal([]byte(*requestBodyJSON), &body)) + require.Equal(t, "[REDACTED]", body["api_key"]) + headers, ok := body["headers"].(map[string]any) + require.True(t, ok) + require.Equal(t, "[REDACTED]", headers["authorization"]) +} + +func TestPrepareOpsRequestBodyForQueue_LargeBodyTruncated(t *testing.T) { + largeMsg := strings.Repeat("x", opsMaxStoredRequestBodyBytes*2) + raw := []byte(`{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":"` + largeMsg + `"}]}`) + + requestBodyJSON, truncated, requestBodyBytes := PrepareOpsRequestBodyForQueue(raw) + require.NotNil(t, requestBodyJSON) + require.NotNil(t, requestBodyBytes) + require.True(t, truncated) + require.Equal(t, len(raw), *requestBodyBytes) + require.LessOrEqual(t, len(*requestBodyJSON), opsMaxStoredRequestBodyBytes) + require.Contains(t, *requestBodyJSON, "request_body_truncated") +} diff --git a/backend/internal/service/ops_settings_models.go b/backend/internal/service/ops_settings_models.go index ecc62220..8b5359e3 100644 --- a/backend/internal/service/ops_settings_models.go +++ b/backend/internal/service/ops_settings_models.go @@ -68,6 +68,20 @@ type OpsMetricThresholds struct { UpstreamErrorRatePercentMax *float64 `json:"upstream_error_rate_percent_max,omitempty"` // 上游错误率高于此值变红 } +type OpsRuntimeLogConfig struct { + Level string `json:"level"` + EnableSampling bool `json:"enable_sampling"` + SamplingInitial int `json:"sampling_initial"` + SamplingNext int `json:"sampling_thereafter"` + Caller bool `json:"caller"` + StacktraceLevel string `json:"stacktrace_level"` + RetentionDays int `json:"retention_days"` + Source string `json:"source,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` + UpdatedByUserID int64 `json:"updated_by_user_id,omitempty"` + Extra map[string]any `json:"extra,omitempty"` +} + type OpsAlertRuntimeSettings struct { EvaluationIntervalSeconds int `json:"evaluation_interval_seconds"` diff --git 
a/backend/internal/service/ops_system_log_service.go b/backend/internal/service/ops_system_log_service.go new file mode 100644 index 00000000..f5a64803 --- /dev/null +++ b/backend/internal/service/ops_system_log_service.go @@ -0,0 +1,124 @@ +package service + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "log" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +func (s *OpsService) ListSystemLogs(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return nil, err + } + if s.opsRepo == nil { + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{}, + Total: 0, + Page: 1, + PageSize: 50, + }, nil + } + if filter == nil { + filter = &OpsSystemLogFilter{} + } + if filter.Page <= 0 { + filter.Page = 1 + } + if filter.PageSize <= 0 { + filter.PageSize = 50 + } + if filter.PageSize > 200 { + filter.PageSize = 200 + } + + result, err := s.opsRepo.ListSystemLogs(ctx, filter) + if err != nil { + return nil, infraerrors.InternalServer("OPS_SYSTEM_LOG_LIST_FAILED", "Failed to list system logs").WithCause(err) + } + return result, nil +} + +func (s *OpsService) CleanupSystemLogs(ctx context.Context, filter *OpsSystemLogCleanupFilter, operatorID int64) (int64, error) { + if err := s.RequireMonitoringEnabled(ctx); err != nil { + return 0, err + } + if s.opsRepo == nil { + return 0, infraerrors.ServiceUnavailable("OPS_REPO_UNAVAILABLE", "Ops repository not available") + } + if operatorID <= 0 { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_OPERATOR", "invalid operator") + } + if filter == nil { + filter = &OpsSystemLogCleanupFilter{} + } + if filter.EndTime != nil && filter.StartTime != nil && filter.StartTime.After(*filter.EndTime) { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_INVALID_RANGE", "invalid time range") + } + + deletedRows, err := s.opsRepo.DeleteSystemLogs(ctx, filter) + if err != nil { 
+ if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + if strings.Contains(strings.ToLower(err.Error()), "requires at least one filter") { + return 0, infraerrors.BadRequest("OPS_SYSTEM_LOG_CLEANUP_FILTER_REQUIRED", "cleanup requires at least one filter condition") + } + return 0, infraerrors.InternalServer("OPS_SYSTEM_LOG_CLEANUP_FAILED", "Failed to cleanup system logs").WithCause(err) + } + + if auditErr := s.opsRepo.InsertSystemLogCleanupAudit(ctx, &OpsSystemLogCleanupAudit{ + CreatedAt: time.Now().UTC(), + OperatorID: operatorID, + Conditions: marshalSystemLogCleanupConditions(filter), + DeletedRows: deletedRows, + }); auditErr != nil { + // 审计失败不影响主流程,避免运维清理被阻塞。 + log.Printf("[OpsSystemLog] cleanup audit failed: %v", auditErr) + } + return deletedRows, nil +} + +func marshalSystemLogCleanupConditions(filter *OpsSystemLogCleanupFilter) string { + if filter == nil { + return "{}" + } + payload := map[string]any{ + "level": strings.TrimSpace(filter.Level), + "component": strings.TrimSpace(filter.Component), + "request_id": strings.TrimSpace(filter.RequestID), + "client_request_id": strings.TrimSpace(filter.ClientRequestID), + "platform": strings.TrimSpace(filter.Platform), + "model": strings.TrimSpace(filter.Model), + "query": strings.TrimSpace(filter.Query), + } + if filter.UserID != nil { + payload["user_id"] = *filter.UserID + } + if filter.AccountID != nil { + payload["account_id"] = *filter.AccountID + } + if filter.StartTime != nil && !filter.StartTime.IsZero() { + payload["start_time"] = filter.StartTime.UTC().Format(time.RFC3339Nano) + } + if filter.EndTime != nil && !filter.EndTime.IsZero() { + payload["end_time"] = filter.EndTime.UTC().Format(time.RFC3339Nano) + } + raw, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(raw) +} + +func (s *OpsService) GetSystemLogSinkHealth() OpsSystemLogSinkHealth { + if s == nil || s.systemLogSink == nil { + return OpsSystemLogSinkHealth{} + } + return s.systemLogSink.Health() +} 
diff --git a/backend/internal/service/ops_system_log_service_test.go b/backend/internal/service/ops_system_log_service_test.go new file mode 100644 index 00000000..cc9ddefe --- /dev/null +++ b/backend/internal/service/ops_system_log_service_test.go @@ -0,0 +1,243 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestOpsServiceListSystemLogs_DefaultClampAndSuccess(t *testing.T) { + var gotFilter *OpsSystemLogFilter + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + gotFilter = filter + return &OpsSystemLogList{ + Logs: []*OpsSystemLog{{ID: 1, Level: "warn", Message: "x"}}, + Total: 1, + Page: filter.Page, + PageSize: filter.PageSize, + }, nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + + out, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{ + Page: 0, + PageSize: 999, + }) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if gotFilter == nil { + t.Fatalf("expected repository to receive filter") + } + if gotFilter.Page != 1 || gotFilter.PageSize != 200 { + t.Fatalf("filter normalized unexpectedly: page=%d pageSize=%d", gotFilter.Page, gotFilter.PageSize) + } + if out.Total != 1 || len(out.Logs) != 1 { + t.Fatalf("unexpected result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_MonitoringDisabled(t *testing.T) { + svc := NewOpsService( + &opsRepoMock{}, + nil, + &config.Config{Ops: config.OpsConfig{Enabled: false}}, + nil, nil, nil, nil, nil, nil, nil, nil, + ) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected disabled error") + } +} + +func TestOpsServiceListSystemLogs_NilRepoReturnsEmpty(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + out, err := 
svc.ListSystemLogs(context.Background(), nil) + if err != nil { + t.Fatalf("ListSystemLogs() error: %v", err) + } + if out == nil || out.Page != 1 || out.PageSize != 50 || out.Total != 0 || len(out.Logs) != 0 { + t.Fatalf("unexpected nil-repo result: %+v", out) + } +} + +func TestOpsServiceListSystemLogs_RepoErrorMapped(t *testing.T) { + repo := &opsRepoMock{ + ListSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogFilter) (*OpsSystemLogList, error) { + return nil, errors.New("db down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.ListSystemLogs(context.Background(), &OpsSystemLogFilter{}) + if err == nil { + t.Fatalf("expected mapped internal error") + } + if !strings.Contains(err.Error(), "OPS_SYSTEM_LOG_LIST_FAILED") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpsServiceCleanupSystemLogs_SuccessAndAudit(t *testing.T) { + var audit *OpsSystemLogCleanupAudit + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 3, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + audit = input + return nil + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + userID := int64(7) + now := time.Now().UTC() + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + Level: "warn", + RequestID: "req-1", + ClientRequestID: "creq-1", + UserID: &userID, + Query: "timeout", + } + + deleted, err := svc.CleanupSystemLogs(context.Background(), filter, 99) + if err != nil { + t.Fatalf("CleanupSystemLogs() error: %v", err) + } + if deleted != 3 { + t.Fatalf("deleted=%d, want 3", deleted) + } + if audit == nil { + t.Fatalf("expected cleanup audit") + } + if !strings.Contains(audit.Conditions, `"client_request_id":"creq-1"`) { + t.Fatalf("audit conditions should include client_request_id: %s", audit.Conditions) + } + if 
!strings.Contains(audit.Conditions, `"user_id":7`) { + t.Fatalf("audit conditions should include user_id: %s", audit.Conditions) + } +} + +func TestOpsServiceCleanupSystemLogs_RepoUnavailableAndInvalidOperator(t *testing.T) { + svc := NewOpsService(nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 1); err == nil { + t.Fatalf("expected repo unavailable error") + } + + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{RequestID: "r"}, 0); err == nil { + t.Fatalf("expected invalid operator error") + } +} + +func TestOpsServiceCleanupSystemLogs_FilterRequired(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("cleanup requires at least one filter condition") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{}, 1) + if err == nil { + t.Fatalf("expected filter required error") + } + if !strings.Contains(strings.ToLower(err.Error()), "filter") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestOpsServiceCleanupSystemLogs_InvalidRange(t *testing.T) { + repo := &opsRepoMock{} + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + start := time.Now().UTC() + end := start.Add(-time.Hour) + _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + StartTime: &start, + EndTime: &end, + }, 1) + if err == nil { + t.Fatalf("expected invalid range error") + } +} + +func TestOpsServiceCleanupSystemLogs_NoRowsAndInternalError(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + 
return 0, sql.ErrNoRows + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1) + if err != nil || deleted != 0 { + t.Fatalf("expected no rows shortcut, deleted=%d err=%v", deleted, err) + } + + repo.DeleteSystemLogsFn = func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 0, errors.New("boom") + } + if _, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "req-1", + }, 1); err == nil { + t.Fatalf("expected internal cleanup error") + } +} + +func TestOpsServiceCleanupSystemLogs_AuditFailureIgnored(t *testing.T) { + repo := &opsRepoMock{ + DeleteSystemLogsFn: func(ctx context.Context, filter *OpsSystemLogCleanupFilter) (int64, error) { + return 5, nil + }, + InsertSystemLogCleanupAuditFn: func(ctx context.Context, input *OpsSystemLogCleanupAudit) error { + return errors.New("audit down") + }, + } + svc := NewOpsService(repo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + deleted, err := svc.CleanupSystemLogs(context.Background(), &OpsSystemLogCleanupFilter{ + RequestID: "r1", + }, 1) + if err != nil || deleted != 5 { + t.Fatalf("audit failure should not break cleanup, deleted=%d err=%v", deleted, err) + } +} + +func TestMarshalSystemLogCleanupConditions_NilAndMarshalError(t *testing.T) { + if got := marshalSystemLogCleanupConditions(nil); got != "{}" { + t.Fatalf("nil filter should return {}, got %s", got) + } + + now := time.Now().UTC() + userID := int64(1) + filter := &OpsSystemLogCleanupFilter{ + StartTime: &now, + EndTime: &now, + UserID: &userID, + } + got := marshalSystemLogCleanupConditions(filter) + if !strings.Contains(got, `"start_time"`) || !strings.Contains(got, `"user_id":1`) { + t.Fatalf("unexpected marshal payload: %s", got) + } +} + +func TestOpsServiceGetSystemLogSinkHealth(t *testing.T) { + svc := NewOpsService(nil, nil, 
nil, nil, nil, nil, nil, nil, nil, nil, nil) + health := svc.GetSystemLogSinkHealth() + if health.QueueCapacity != 0 || health.QueueDepth != 0 { + t.Fatalf("unexpected health for nil sink: %+v", health) + } + + sink := NewOpsSystemLogSink(&opsRepoMock{}) + svc = NewOpsService(&opsRepoMock{}, nil, nil, nil, nil, nil, nil, nil, nil, nil, sink) + health = svc.GetSystemLogSinkHealth() + if health.QueueCapacity <= 0 { + t.Fatalf("expected non-zero queue capacity: %+v", health) + } +} diff --git a/backend/internal/service/ops_system_log_sink.go b/backend/internal/service/ops_system_log_sink.go new file mode 100644 index 00000000..c50a30d5 --- /dev/null +++ b/backend/internal/service/ops_system_log_sink.go @@ -0,0 +1,335 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +type OpsSystemLogSinkHealth struct { + QueueDepth int64 `json:"queue_depth"` + QueueCapacity int64 `json:"queue_capacity"` + DroppedCount uint64 `json:"dropped_count"` + WriteFailed uint64 `json:"write_failed_count"` + WrittenCount uint64 `json:"written_count"` + AvgWriteDelayMs uint64 `json:"avg_write_delay_ms"` + LastError string `json:"last_error"` +} + +type OpsSystemLogSink struct { + opsRepo OpsRepository + + queue chan *logger.LogEvent + + batchSize int + flushInterval time.Duration + + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + + droppedCount uint64 + writeFailed uint64 + writtenCount uint64 + totalDelayNs uint64 + + lastError atomic.Value +} + +func NewOpsSystemLogSink(opsRepo OpsRepository) *OpsSystemLogSink { + ctx, cancel := context.WithCancel(context.Background()) + s := &OpsSystemLogSink{ + opsRepo: opsRepo, + queue: make(chan *logger.LogEvent, 5000), + batchSize: 200, + flushInterval: time.Second, + ctx: ctx, + cancel: cancel, + } + s.lastError.Store("") + return s +} 
+ +func (s *OpsSystemLogSink) Start() { + if s == nil || s.opsRepo == nil { + return + } + s.wg.Add(1) + go s.run() +} + +func (s *OpsSystemLogSink) Stop() { + if s == nil { + return + } + s.cancel() + s.wg.Wait() +} + +func (s *OpsSystemLogSink) WriteLogEvent(event *logger.LogEvent) { + if s == nil || event == nil || !s.shouldIndex(event) { + return + } + if s.ctx != nil { + select { + case <-s.ctx.Done(): + return + default: + } + } + + select { + case s.queue <- event: + default: + atomic.AddUint64(&s.droppedCount, 1) + } +} + +func (s *OpsSystemLogSink) shouldIndex(event *logger.LogEvent) bool { + level := strings.ToLower(strings.TrimSpace(event.Level)) + switch level { + case "warn", "warning", "error", "fatal", "panic", "dpanic": + return true + } + + component := strings.ToLower(strings.TrimSpace(event.Component)) + // zap 的 LoggerName 往往为空或不等于业务组件名;业务组件名通常以字段 component 透传。 + if event.Fields != nil { + if fc := strings.ToLower(strings.TrimSpace(asString(event.Fields["component"]))); fc != "" { + component = fc + } + } + if strings.Contains(component, "http.access") { + return true + } + if strings.Contains(component, "audit") { + return true + } + return false +} + +func (s *OpsSystemLogSink) run() { + defer s.wg.Done() + + ticker := time.NewTicker(s.flushInterval) + defer ticker.Stop() + + batch := make([]*logger.LogEvent, 0, s.batchSize) + flush := func(baseCtx context.Context) { + if len(batch) == 0 { + return + } + started := time.Now() + inserted, err := s.flushBatch(baseCtx, batch) + delay := time.Since(started) + if err != nil { + atomic.AddUint64(&s.writeFailed, uint64(len(batch))) + s.lastError.Store(err.Error()) + _, _ = fmt.Fprintf(os.Stderr, "time=%s level=WARN msg=\"ops system log sink flush failed\" err=%v batch=%d\n", + time.Now().Format(time.RFC3339Nano), err, len(batch), + ) + } else { + atomic.AddUint64(&s.writtenCount, uint64(inserted)) + atomic.AddUint64(&s.totalDelayNs, uint64(delay.Nanoseconds())) + s.lastError.Store("") + } + batch = 
batch[:0] + } + drainAndFlush := func() { + for { + select { + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(context.Background()) + } + default: + flush(context.Background()) + return + } + } + } + + for { + select { + case <-s.ctx.Done(): + drainAndFlush() + return + case item := <-s.queue: + if item == nil { + continue + } + batch = append(batch, item) + if len(batch) >= s.batchSize { + flush(s.ctx) + } + case <-ticker.C: + flush(s.ctx) + } + } +} + +func (s *OpsSystemLogSink) flushBatch(baseCtx context.Context, batch []*logger.LogEvent) (int, error) { + inputs := make([]*OpsInsertSystemLogInput, 0, len(batch)) + for _, event := range batch { + if event == nil { + continue + } + createdAt := event.Time.UTC() + if createdAt.IsZero() { + createdAt = time.Now().UTC() + } + + fields := copyMap(event.Fields) + requestID := asString(fields["request_id"]) + clientRequestID := asString(fields["client_request_id"]) + platform := asString(fields["platform"]) + model := asString(fields["model"]) + component := strings.TrimSpace(event.Component) + if fieldComponent := asString(fields["component"]); fieldComponent != "" { + component = fieldComponent + } + if component == "" { + component = "app" + } + + userID := asInt64Ptr(fields["user_id"]) + accountID := asInt64Ptr(fields["account_id"]) + + // 统一脱敏后写入索引。 + message := logredact.RedactText(strings.TrimSpace(event.Message)) + redactedExtra := logredact.RedactMap(fields) + extraJSONBytes, _ := json.Marshal(redactedExtra) + extraJSON := string(extraJSONBytes) + if strings.TrimSpace(extraJSON) == "" { + extraJSON = "{}" + } + + inputs = append(inputs, &OpsInsertSystemLogInput{ + CreatedAt: createdAt, + Level: strings.ToLower(strings.TrimSpace(event.Level)), + Component: component, + Message: message, + RequestID: requestID, + ClientRequestID: clientRequestID, + UserID: userID, + AccountID: accountID, + Platform: platform, + Model: model, + 
ExtraJSON: extraJSON, + }) + } + + if len(inputs) == 0 { + return 0, nil + } + if baseCtx == nil || baseCtx.Err() != nil { + baseCtx = context.Background() + } + ctx, cancel := context.WithTimeout(baseCtx, 5*time.Second) + defer cancel() + inserted, err := s.opsRepo.BatchInsertSystemLogs(ctx, inputs) + if err != nil { + return 0, err + } + return int(inserted), nil +} + +func (s *OpsSystemLogSink) Health() OpsSystemLogSinkHealth { + if s == nil { + return OpsSystemLogSinkHealth{} + } + written := atomic.LoadUint64(&s.writtenCount) + totalDelay := atomic.LoadUint64(&s.totalDelayNs) + var avgDelay uint64 + if written > 0 { + avgDelay = (totalDelay / written) / uint64(time.Millisecond) + } + + lastErr, _ := s.lastError.Load().(string) + return OpsSystemLogSinkHealth{ + QueueDepth: int64(len(s.queue)), + QueueCapacity: int64(cap(s.queue)), + DroppedCount: atomic.LoadUint64(&s.droppedCount), + WriteFailed: atomic.LoadUint64(&s.writeFailed), + WrittenCount: written, + AvgWriteDelayMs: avgDelay, + LastError: strings.TrimSpace(lastErr), + } +} + +func copyMap(in map[string]any) map[string]any { + if len(in) == 0 { + return map[string]any{} + } + out := make(map[string]any, len(in)) + for k, v := range in { + out[k] = v + } + return out +} + +func asString(v any) string { + switch t := v.(type) { + case string: + return strings.TrimSpace(t) + case fmt.Stringer: + return strings.TrimSpace(t.String()) + default: + return "" + } +} + +func asInt64Ptr(v any) *int64 { + switch t := v.(type) { + case int: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case int64: + n := t + if n <= 0 { + return nil + } + return &n + case float64: + n := int64(t) + if n <= 0 { + return nil + } + return &n + case json.Number: + if n, err := t.Int64(); err == nil { + if n <= 0 { + return nil + } + return &n + } + case string: + raw := strings.TrimSpace(t) + if raw == "" { + return nil + } + if n, err := strconv.ParseInt(raw, 10, 64); err == nil { + if n <= 0 { + return nil + } + return 
&n + } + } + return nil +} diff --git a/backend/internal/service/ops_system_log_sink_test.go b/backend/internal/service/ops_system_log_sink_test.go new file mode 100644 index 00000000..12a2ec0c --- /dev/null +++ b/backend/internal/service/ops_system_log_sink_test.go @@ -0,0 +1,313 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +func TestOpsSystemLogSink_ShouldIndex(t *testing.T) { + sink := &OpsSystemLogSink{} + + cases := []struct { + name string + event *logger.LogEvent + want bool + }{ + { + name: "warn level", + event: &logger.LogEvent{Level: "warn", Component: "app"}, + want: true, + }, + { + name: "error level", + event: &logger.LogEvent{Level: "error", Component: "app"}, + want: true, + }, + { + name: "access component", + event: &logger.LogEvent{Level: "info", Component: "http.access"}, + want: true, + }, + { + name: "access component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "http.access"}, + }, + want: true, + }, + { + name: "audit component", + event: &logger.LogEvent{Level: "info", Component: "audit.log_config_change"}, + want: true, + }, + { + name: "audit component from fields (real zap path)", + event: &logger.LogEvent{ + Level: "info", + Component: "", + Fields: map[string]any{"component": "audit.log_config_change"}, + }, + want: true, + }, + { + name: "plain info", + event: &logger.LogEvent{Level: "info", Component: "app"}, + want: false, + }, + } + + for _, tc := range cases { + if got := sink.shouldIndex(tc.event); got != tc.want { + t.Fatalf("%s: shouldIndex()=%v, want %v", tc.name, got, tc.want) + } + } +} + +func TestOpsSystemLogSink_WriteLogEvent_ShouldDropWhenQueueFull(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 1), + } + + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", 
Component: "app"}) + sink.WriteLogEvent(&logger.LogEvent{Level: "warn", Component: "app"}) + + if got := len(sink.queue); got != 1 { + t.Fatalf("queue len = %d, want 1", got) + } + if dropped := atomic.LoadUint64(&sink.droppedCount); dropped != 1 { + t.Fatalf("droppedCount = %d, want 1", dropped) + } +} + +func TestOpsSystemLogSink_Health(t *testing.T) { + sink := &OpsSystemLogSink{ + queue: make(chan *logger.LogEvent, 10), + } + sink.lastError.Store("db timeout") + atomic.StoreUint64(&sink.droppedCount, 3) + atomic.StoreUint64(&sink.writeFailed, 2) + atomic.StoreUint64(&sink.writtenCount, 5) + atomic.StoreUint64(&sink.totalDelayNs, uint64(5000000)) // 5ms total -> avg 1ms + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + sink.queue <- &logger.LogEvent{Level: "warn", Component: "app"} + + health := sink.Health() + if health.QueueDepth != 2 { + t.Fatalf("queue depth = %d, want 2", health.QueueDepth) + } + if health.QueueCapacity != 10 { + t.Fatalf("queue capacity = %d, want 10", health.QueueCapacity) + } + if health.DroppedCount != 3 { + t.Fatalf("dropped = %d, want 3", health.DroppedCount) + } + if health.WriteFailed != 2 { + t.Fatalf("write failed = %d, want 2", health.WriteFailed) + } + if health.WrittenCount != 5 { + t.Fatalf("written = %d, want 5", health.WrittenCount) + } + if health.AvgWriteDelayMs != 1 { + t.Fatalf("avg delay ms = %d, want 1", health.AvgWriteDelayMs) + } + if health.LastError != "db timeout" { + t.Fatalf("last error = %q, want db timeout", health.LastError) + } +} + +func TestOpsSystemLogSink_StartStopAndFlushSuccess(t *testing.T) { + done := make(chan struct{}, 1) + var captured []*OpsInsertSystemLogInput + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + captured = append(captured, inputs...) 
+ select { + case done <- struct{}{}: + default: + } + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "http.access", + Message: `authorization="Bearer sk-test-123"`, + Fields: map[string]any{ + "component": "http.access", + "request_id": "req-1", + "client_request_id": "creq-1", + "user_id": "12", + "account_id": json.Number("34"), + "platform": "openai", + "model": "gpt-5", + }, + }) + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for sink flush") + } + + if len(captured) != 1 { + t.Fatalf("captured len = %d, want 1", len(captured)) + } + item := captured[0] + if item.RequestID != "req-1" || item.ClientRequestID != "creq-1" { + t.Fatalf("unexpected request ids: %+v", item) + } + if item.UserID == nil || *item.UserID != 12 { + t.Fatalf("unexpected user_id: %+v", item.UserID) + } + if item.AccountID == nil || *item.AccountID != 34 { + t.Fatalf("unexpected account_id: %+v", item.AccountID) + } + if strings.TrimSpace(item.Message) == "" { + t.Fatalf("message should not be empty") + } + health := sink.Health() + if health.WrittenCount == 0 { + t.Fatalf("written_count should be >0") + } +} + +func TestOpsSystemLogSink_FlushFailureUpdatesHealth(t *testing.T) { + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(_ context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + return 0, errors.New("db unavailable") + }, + } + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 1 + sink.flushInterval = 10 * time.Millisecond + sink.Start() + defer sink.Stop() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "boom", + Fields: map[string]any{}, + }) + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) 
{ + health := sink.Health() + if health.WriteFailed > 0 { + if !strings.Contains(health.LastError, "db unavailable") { + t.Fatalf("unexpected last error: %s", health.LastError) + } + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("write_failed_count not updated") +} + +func TestOpsSystemLogSink_StopFlushUsesActiveContextAndDrainsQueue(t *testing.T) { + var inserted int64 + var canceledCtxCalls int64 + repo := &opsRepoMock{ + BatchInsertSystemLogsFn: func(ctx context.Context, inputs []*OpsInsertSystemLogInput) (int64, error) { + if err := ctx.Err(); err != nil { + atomic.AddInt64(&canceledCtxCalls, 1) + return 0, err + } + atomic.AddInt64(&inserted, int64(len(inputs))) + return int64(len(inputs)), nil + }, + } + + sink := NewOpsSystemLogSink(repo) + sink.batchSize = 200 + sink.flushInterval = time.Hour + sink.Start() + + sink.WriteLogEvent(&logger.LogEvent{ + Time: time.Now().UTC(), + Level: "warn", + Component: "app", + Message: "pending-on-shutdown", + Fields: map[string]any{"component": "http.access"}, + }) + + sink.Stop() + + if got := atomic.LoadInt64(&inserted); got != 1 { + t.Fatalf("inserted = %d, want 1", got) + } + if got := atomic.LoadInt64(&canceledCtxCalls); got != 0 { + t.Fatalf("canceled ctx calls = %d, want 0", got) + } + health := sink.Health() + if health.WrittenCount != 1 { + t.Fatalf("written_count = %d, want 1", health.WrittenCount) + } +} + +type stringerValue string + +func (s stringerValue) String() string { return string(s) } + +func TestOpsSystemLogSink_HelperFunctions(t *testing.T) { + src := map[string]any{"a": 1} + cloned := copyMap(src) + src["a"] = 2 + v, ok := cloned["a"].(int) + if !ok || v != 1 { + t.Fatalf("copyMap should create copy") + } + if got := asString(stringerValue(" hello ")); got != "hello" { + t.Fatalf("asString stringer = %q", got) + } + if got := asString(fmt.Errorf("x")); got != "" { + t.Fatalf("asString error should be empty, got %q", got) + } + if got := asString(123); got != "" { + 
t.Fatalf("asString non-string should be empty, got %q", got) + } + + cases := []struct { + in any + want int64 + ok bool + }{ + {in: 5, want: 5, ok: true}, + {in: int64(6), want: 6, ok: true}, + {in: float64(7), want: 7, ok: true}, + {in: json.Number("8"), want: 8, ok: true}, + {in: "9", want: 9, ok: true}, + {in: "0", ok: false}, + {in: -1, ok: false}, + {in: "abc", ok: false}, + } + for _, tc := range cases { + got := asInt64Ptr(tc.in) + if tc.ok { + if got == nil || *got != tc.want { + t.Fatalf("asInt64Ptr(%v) = %+v, want %d", tc.in, got, tc.want) + } + } else if got != nil { + t.Fatalf("asInt64Ptr(%v) should be nil, got %d", tc.in, *got) + } + } +} diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 3514df79..23c154ce 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -21,11 +21,33 @@ const ( // This value is sanitized+trimmed before being persisted. OpsUpstreamRequestBodyKey = "ops_upstream_request_body" + // Optional stage latencies (milliseconds) for troubleshooting and alerting. 
+ OpsAuthLatencyMsKey = "ops_auth_latency_ms" + OpsRoutingLatencyMsKey = "ops_routing_latency_ms" + OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" + OpsResponseLatencyMsKey = "ops_response_latency_ms" + OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" + // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 OpsSkipPassthroughKey = "ops_skip_passthrough" ) +func setOpsUpstreamRequestBody(c *gin.Context, body []byte) { + if c == nil || len(body) == 0 { + return + } + // 热路径避免 string(body) 额外分配,按需在落库前再转换。 + c.Set(OpsUpstreamRequestBodyKey, body) +} + +func SetOpsLatencyMs(c *gin.Context, key string, value int64) { + if c == nil || strings.TrimSpace(key) == "" || value < 0 { + return + } + c.Set(key, value) +} + func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { if c == nil { return @@ -46,6 +68,10 @@ func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage type OpsUpstreamErrorEvent struct { AtUnixMs int64 `json:"at_unix_ms,omitempty"` + // Passthrough 表示本次请求是否命中“原样透传(仅替换认证)”分支。 + // 该字段用于排障与灰度评估;存入 JSON,不涉及 DB schema 变更。 + Passthrough bool `json:"passthrough,omitempty"` + // Context Platform string `json:"platform,omitempty"` AccountID int64 `json:"account_id,omitempty"` @@ -91,8 +117,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { // stored it on the context, attach it so ops can retry this specific attempt. 
if ev.UpstreamRequestBody == "" { if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok { - if s, ok := v.(string); ok { - ev.UpstreamRequestBody = strings.TrimSpace(s) + switch raw := v.(type) { + case string: + ev.UpstreamRequestBody = strings.TrimSpace(raw) + case []byte: + ev.UpstreamRequestBody = strings.TrimSpace(string(raw)) } } } diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go new file mode 100644 index 00000000..50ceaa0e --- /dev/null +++ b/backend/internal/service/ops_upstream_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`)) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "http_error", + Message: "upstream failed", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody) +} + +func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "request_error", + Message: "dial timeout", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody) +} diff --git a/backend/internal/service/parse_integral_number_unit.go 
b/backend/internal/service/parse_integral_number_unit.go new file mode 100644 index 00000000..c9c617b1 --- /dev/null +++ b/backend/internal/service/parse_integral_number_unit.go @@ -0,0 +1,51 @@ +//go:build unit + +package service + +import ( + "encoding/json" + "math" +) + +// parseIntegralNumber 将 JSON 解码后的数字安全转换为 int。 +// 仅接受“整数值”的输入,小数/NaN/Inf/越界值都会返回 false。 +// +// 说明: +// - 该函数当前仅用于 unit 测试中的 map-based 解析逻辑验证,因此放在 unit build tag 下, +// 避免在默认构建中触发 unused lint。 +func parseIntegralNumber(raw any) (int, bool) { + switch v := raw.(type) { + case float64: + if math.IsNaN(v) || math.IsInf(v, 0) || v != math.Trunc(v) { + return 0, false + } + if v > float64(math.MaxInt) || v < float64(math.MinInt) { + return 0, false + } + return int(v), true + case int: + return v, true + case int8: + return int(v), true + case int16: + return int(v), true + case int32: + return int(v), true + case int64: + if v > int64(math.MaxInt) || v < int64(math.MinInt) { + return 0, false + } + return int(v), true + case json.Number: + i64, err := v.Int64() + if err != nil { + return 0, false + } + if i64 > int64(math.MaxInt) || i64 < int64(math.MinInt) { + return 0, false + } + return int(i64), true + default: + return 0, false + } +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index a3a94189..41e8b5eb 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -6,7 +6,6 @@ import ( "encoding/hex" "encoding/json" "fmt" - "log" "os" "path/filepath" "regexp" @@ -15,8 +14,10 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" + "go.uber.org/zap" ) var ( @@ -86,12 +87,12 @@ func NewPricingService(cfg *config.Config, remoteClient PricingRemoteClient) *Pr func (s *PricingService) Initialize() error { // 确保数据目录存在 if err := 
os.MkdirAll(s.cfg.Pricing.DataDir, 0755); err != nil { - log.Printf("[Pricing] Failed to create data directory: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to create data directory: %v", err) } // 首次加载价格数据 if err := s.checkAndUpdatePricing(); err != nil { - log.Printf("[Pricing] Initial load failed, using fallback: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Initial load failed, using fallback: %v", err) if err := s.useFallbackPricing(); err != nil { return fmt.Errorf("failed to load pricing data: %w", err) } @@ -100,7 +101,7 @@ func (s *PricingService) Initialize() error { // 启动定时更新 s.startUpdateScheduler() - log.Printf("[Pricing] Service initialized with %d models", len(s.pricingData)) + logger.LegacyPrintf("service.pricing", "[Pricing] Service initialized with %d models", len(s.pricingData)) return nil } @@ -108,7 +109,7 @@ func (s *PricingService) Initialize() error { func (s *PricingService) Stop() { close(s.stopCh) s.wg.Wait() - log.Println("[Pricing] Service stopped") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Service stopped") } // startUpdateScheduler 启动定时更新调度器 @@ -129,7 +130,7 @@ func (s *PricingService) startUpdateScheduler() { select { case <-ticker.C: if err := s.syncWithRemote(); err != nil { - log.Printf("[Pricing] Sync failed: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Sync failed: %v", err) } case <-s.stopCh: return @@ -137,7 +138,7 @@ func (s *PricingService) startUpdateScheduler() { } }() - log.Printf("[Pricing] Update scheduler started (check every %v)", hashInterval) + logger.LegacyPrintf("service.pricing", "[Pricing] Update scheduler started (check every %v)", hashInterval) } // checkAndUpdatePricing 检查并更新价格数据 @@ -146,7 +147,7 @@ func (s *PricingService) checkAndUpdatePricing() error { // 检查本地文件是否存在 if _, err := os.Stat(pricingFile); os.IsNotExist(err) { - log.Println("[Pricing] Local pricing file not found, downloading...") + logger.LegacyPrintf("service.pricing", 
"%s", "[Pricing] Local pricing file not found, downloading...") return s.downloadPricingData() } @@ -160,9 +161,9 @@ func (s *PricingService) checkAndUpdatePricing() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] Local file is %v old, updating...", fileAge.Round(time.Hour)) if err := s.downloadPricingData(); err != nil { - log.Printf("[Pricing] Download failed, using existing file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Download failed, using existing file: %v", err) } } @@ -177,7 +178,7 @@ func (s *PricingService) syncWithRemote() error { // 计算本地文件哈希 localHash, err := s.computeFileHash(pricingFile) if err != nil { - log.Printf("[Pricing] Failed to compute local hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to compute local hash: %v", err) return s.downloadPricingData() } @@ -185,15 +186,15 @@ func (s *PricingService) syncWithRemote() error { if s.cfg.Pricing.HashURL != "" { remoteHash, err := s.fetchRemoteHash() if err != nil { - log.Printf("[Pricing] Failed to fetch remote hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to fetch remote hash: %v", err) return nil // 哈希获取失败不影响正常使用 } if remoteHash != localHash { - log.Println("[Pricing] Remote hash differs, downloading new version...") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Remote hash differs, downloading new version...") return s.downloadPricingData() } - log.Println("[Pricing] Hash check passed, no update needed") + logger.LegacyPrintf("service.pricing", "%s", "[Pricing] Hash check passed, no update needed") return nil } @@ -207,7 +208,7 @@ func (s *PricingService) syncWithRemote() error { maxAge := time.Duration(s.cfg.Pricing.UpdateIntervalHours) * time.Hour if fileAge > maxAge { - log.Printf("[Pricing] File is %v old, 
downloading...", fileAge.Round(time.Hour)) + logger.LegacyPrintf("service.pricing", "[Pricing] File is %v old, downloading...", fileAge.Round(time.Hour)) return s.downloadPricingData() } @@ -220,7 +221,7 @@ func (s *PricingService) downloadPricingData() error { if err != nil { return err } - log.Printf("[Pricing] Downloading from %s", remoteURL) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloading from %s", remoteURL) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() @@ -254,7 +255,7 @@ func (s *PricingService) downloadPricingData() error { // 保存到本地文件 pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, body, 0644); err != nil { - log.Printf("[Pricing] Failed to save file: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save file: %v", err) } // 保存哈希 @@ -262,7 +263,7 @@ func (s *PricingService) downloadPricingData() error { hashStr := hex.EncodeToString(hash[:]) hashFile := s.getHashFilePath() if err := os.WriteFile(hashFile, []byte(hashStr+"\n"), 0644); err != nil { - log.Printf("[Pricing] Failed to save hash: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to save hash: %v", err) } // 更新内存数据 @@ -272,7 +273,7 @@ func (s *PricingService) downloadPricingData() error { s.localHash = hashStr s.mu.Unlock() - log.Printf("[Pricing] Downloaded %d models successfully", len(data)) + logger.LegacyPrintf("service.pricing", "[Pricing] Downloaded %d models successfully", len(data)) return nil } @@ -334,7 +335,7 @@ func (s *PricingService) parsePricingData(body []byte) (map[string]*LiteLLMModel } if skipped > 0 { - log.Printf("[Pricing] Skipped %d invalid entries", skipped) + logger.LegacyPrintf("service.pricing", "[Pricing] Skipped %d invalid entries", skipped) } if len(result) == 0 { @@ -373,7 +374,7 @@ func (s *PricingService) loadPricingData(filePath string) error { } s.mu.Unlock() - log.Printf("[Pricing] Loaded %d models from %s", len(pricingData), filePath) + 
logger.LegacyPrintf("service.pricing", "[Pricing] Loaded %d models from %s", len(pricingData), filePath) return nil } @@ -385,7 +386,7 @@ func (s *PricingService) useFallbackPricing() error { return fmt.Errorf("fallback file not found: %s", fallbackFile) } - log.Printf("[Pricing] Using fallback file: %s", fallbackFile) + logger.LegacyPrintf("service.pricing", "[Pricing] Using fallback file: %s", fallbackFile) // 复制到数据目录 data, err := os.ReadFile(fallbackFile) @@ -395,7 +396,7 @@ func (s *PricingService) useFallbackPricing() error { pricingFile := s.getPricingFilePath() if err := os.WriteFile(pricingFile, data, 0644); err != nil { - log.Printf("[Pricing] Failed to copy fallback: %v", err) + logger.LegacyPrintf("service.pricing", "[Pricing] Failed to copy fallback: %v", err) } return s.loadPricingData(fallbackFile) @@ -644,7 +645,7 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { for key, pricing := range s.pricingData { keyLower := strings.ToLower(key) if strings.Contains(keyLower, pattern) { - log.Printf("[Pricing] Fuzzy matched %s -> %s", model, key) + logger.LegacyPrintf("service.pricing", "[Pricing] Fuzzy matched %s -> %s", model, key) return pricing } } @@ -655,24 +656,36 @@ func (s *PricingService) matchByModelFamily(model string) *LiteLLMModelPricing { // matchOpenAIModel OpenAI 模型回退匹配策略 // 回退顺序: -// 1. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) -// 2. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) -// 3. gpt-5.3-codex -> gpt-5.2-codex -// 4. 最终回退到 DefaultTestModel (gpt-5.1-codex) +// 1. gpt-5.3-codex-spark* -> gpt-5.1-codex(按业务要求固定计费) +// 2. gpt-5.2-codex -> gpt-5.2(去掉后缀如 -codex, -mini, -max 等) +// 3. gpt-5.2-20251222 -> gpt-5.2(去掉日期版本号) +// 4. gpt-5.3-codex -> gpt-5.2-codex +// 5. 
最终回退到 DefaultTestModel (gpt-5.1-codex) func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { + if strings.HasPrefix(model, "gpt-5.3-codex-spark") { + if pricing, ok := s.pricingData["gpt-5.1-codex"]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing][SparkBilling] %s -> %s billing", model, "gpt-5.1-codex") + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.1-codex")) + return pricing + } + } + // 尝试的回退变体 variants := s.generateOpenAIModelVariants(model, openAIModelDatePattern) for _, variant := range variants { if pricing, ok := s.pricingData[variant]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, variant) + logger.With(zap.String("component", "service.pricing")). + Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, variant)) return pricing } } if strings.HasPrefix(model, "gpt-5.3-codex") { if pricing, ok := s.pricingData["gpt-5.2-codex"]; ok { - log.Printf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex") + logger.With(zap.String("component", "service.pricing")). 
+ Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.2-codex")) return pricing } } @@ -680,7 +693,7 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { - log.Printf("[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI fallback to default model %s -> %s", model, defaultModel) return pricing } diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go new file mode 100644 index 00000000..127ff342 --- /dev/null +++ b/backend/internal/service/pricing_service_test.go @@ -0,0 +1,53 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetModelPricing_Gpt53CodexSparkUsesGpt51CodexPricing(t *testing.T) { + sparkPricing := &LiteLLMModelPricing{InputCostPerToken: 1} + gpt53Pricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.1-codex": sparkPricing, + "gpt-5.3": gpt53Pricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex-spark") + require.Same(t, sparkPricing, got) +} + +func TestGetModelPricing_Gpt53CodexFallbackStillUsesGpt52Codex(t *testing.T) { + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) +} + +func TestGetModelPricing_OpenAIFallbackMatchedLoggedAsInfo(t *testing.T) { + logSink, restore := captureStructuredLog(t) + defer restore() + + gpt52CodexPricing := &LiteLLMModelPricing{InputCostPerToken: 2} + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + 
"gpt-5.2-codex": gpt52CodexPricing, + }, + } + + got := svc.GetModelPricing("gpt-5.3-codex") + require.Same(t, gpt52CodexPricing, got) + + require.True(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "info")) + require.False(t, logSink.ContainsMessageAtLevel("[Pricing] OpenAI fallback matched gpt-5.3-codex -> gpt-5.2-codex", "warn")) +} diff --git a/backend/internal/service/proxy.go b/backend/internal/service/proxy.go index 7eb7728f..fc449091 100644 --- a/backend/internal/service/proxy.go +++ b/backend/internal/service/proxy.go @@ -40,6 +40,11 @@ type ProxyWithAccountCount struct { CountryCode string Region string City string + QualityStatus string + QualityScore *int + QualityGrade string + QualitySummary string + QualityChecked *int64 } type ProxyAccountSummary struct { diff --git a/backend/internal/service/proxy_latency_cache.go b/backend/internal/service/proxy_latency_cache.go index 4a1cc77b..f54bff88 100644 --- a/backend/internal/service/proxy_latency_cache.go +++ b/backend/internal/service/proxy_latency_cache.go @@ -6,15 +6,21 @@ import ( ) type ProxyLatencyInfo struct { - Success bool `json:"success"` - LatencyMs *int64 `json:"latency_ms,omitempty"` - Message string `json:"message,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - Region string `json:"region,omitempty"` - City string `json:"city,omitempty"` - UpdatedAt time.Time `json:"updated_at"` + Success bool `json:"success"` + LatencyMs *int64 `json:"latency_ms,omitempty"` + Message string `json:"message,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Country string `json:"country,omitempty"` + CountryCode string `json:"country_code,omitempty"` + Region string `json:"region,omitempty"` + City string `json:"city,omitempty"` + QualityStatus string `json:"quality_status,omitempty"` + QualityScore *int 
`json:"quality_score,omitempty"` + QualityGrade string `json:"quality_grade,omitempty"` + QualitySummary string `json:"quality_summary,omitempty"` + QualityCheckedAt *int64 `json:"quality_checked_at,omitempty"` + QualityCFRay string `json:"quality_cf_ray,omitempty"` + UpdatedAt time.Time `json:"updated_at"` } type ProxyLatencyCache interface { diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index b1d767fc..fcc7c4a0 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -738,7 +738,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) if err := s.accountRepo.ClearAntigravityQuotaScopes(ctx, accountID); err != nil { return err } - return s.accountRepo.ClearModelRateLimits(ctx, accountID) + if err := s.accountRepo.ClearModelRateLimits(ctx, accountID); err != nil { + return err + } + // 清除限流时一并清理临时不可调度状态,避免周限/窗口重置后仍被本地临时状态阻断。 + if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil { + return err + } + if s.tempUnschedCache != nil { + if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil { + slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) + } + } + return nil } func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error { diff --git a/backend/internal/service/ratelimit_service_clear_test.go b/backend/internal/service/ratelimit_service_clear_test.go new file mode 100644 index 00000000..f48151ed --- /dev/null +++ b/backend/internal/service/ratelimit_service_clear_test.go @@ -0,0 +1,172 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type rateLimitClearRepoStub struct { + mockAccountRepoForGemini + clearRateLimitCalls int + clearAntigravityCalls int + clearModelRateLimitCalls int + 
clearTempUnschedCalls int + clearRateLimitErr error + clearAntigravityErr error + clearModelRateLimitErr error + clearTempUnschedulableErr error +} + +func (r *rateLimitClearRepoStub) ClearRateLimit(ctx context.Context, id int64) error { + r.clearRateLimitCalls++ + return r.clearRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error { + r.clearAntigravityCalls++ + return r.clearAntigravityErr +} + +func (r *rateLimitClearRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error { + r.clearModelRateLimitCalls++ + return r.clearModelRateLimitErr +} + +func (r *rateLimitClearRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error { + r.clearTempUnschedCalls++ + return r.clearTempUnschedulableErr +} + +type tempUnschedCacheRecorder struct { + deletedIDs []int64 + deleteErr error +} + +func (c *tempUnschedCacheRecorder) SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error { + return nil +} + +func (c *tempUnschedCacheRecorder) GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error) { + return nil, nil +} + +func (c *tempUnschedCacheRecorder) DeleteTempUnsched(ctx context.Context, accountID int64) error { + c.deletedIDs = append(c.deletedIDs, accountID) + return c.deleteErr +} + +func TestRateLimitService_ClearRateLimit_AlsoClearsTempUnschedulable(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 42) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{42}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearTempUnschedulableFailed(t *testing.T) { + repo 
:= &rateLimitClearRepoStub{ + clearTempUnschedulableErr: errors.New("clear temp unsched failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 7) + require.Error(t, err) + + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearRateLimitFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearRateLimitErr: errors.New("clear rate limit failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 11) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 0, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearAntigravityFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearAntigravityErr: errors.New("clear antigravity failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 12) + require.Error(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 0, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_ClearModelRateLimitsFailed(t *testing.T) { + repo := &rateLimitClearRepoStub{ + clearModelRateLimitErr: errors.New("clear model rate limits failed"), + } + cache := &tempUnschedCacheRecorder{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 13) + require.Error(t, err) + 
+ require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 0, repo.clearTempUnschedCalls) + require.Empty(t, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_CacheDeleteFailedShouldNotFail(t *testing.T) { + repo := &rateLimitClearRepoStub{} + cache := &tempUnschedCacheRecorder{ + deleteErr: errors.New("cache delete failed"), + } + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, cache) + + err := svc.ClearRateLimit(context.Background(), 14) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) + require.Equal(t, []int64{14}, cache.deletedIDs) +} + +func TestRateLimitService_ClearRateLimit_WithoutTempUnschedCache(t *testing.T) { + repo := &rateLimitClearRepoStub{} + svc := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + + err := svc.ClearRateLimit(context.Background(), 15) + require.NoError(t, err) + + require.Equal(t, 1, repo.clearRateLimitCalls) + require.Equal(t, 1, repo.clearAntigravityCalls) + require.Equal(t, 1, repo.clearModelRateLimitCalls) + require.Equal(t, 1, repo.clearTempUnschedCalls) +} diff --git a/backend/internal/service/scheduler_shuffle_test.go b/backend/internal/service/scheduler_shuffle_test.go index 78ac5f57..0d82b2f3 100644 --- a/backend/internal/service/scheduler_shuffle_test.go +++ b/backend/internal/service/scheduler_shuffle_test.go @@ -125,13 +125,13 @@ func TestShuffleWithinSortGroups_MixedGroups(t *testing.T) { // ============ shuffleWithinPriorityAndLastUsed 测试 ============ func TestShuffleWithinPriorityAndLastUsed_Empty(t *testing.T) { - shuffleWithinPriorityAndLastUsed(nil) - shuffleWithinPriorityAndLastUsed([]*Account{}) + shuffleWithinPriorityAndLastUsed(nil, false) + 
shuffleWithinPriorityAndLastUsed([]*Account{}, false) } func TestShuffleWithinPriorityAndLastUsed_SingleElement(t *testing.T) { accounts := []*Account{{ID: 1, Priority: 1}} - shuffleWithinPriorityAndLastUsed(accounts) + shuffleWithinPriorityAndLastUsed(accounts, false) require.Equal(t, int64(1), accounts[0].ID) } @@ -146,7 +146,7 @@ func TestShuffleWithinPriorityAndLastUsed_SameGroup_Shuffled(t *testing.T) { for i := 0; i < 100; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) seen[cpy[0].ID] = true } require.GreaterOrEqual(t, len(seen), 2, "same group should be shuffled") @@ -162,7 +162,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentPriority_OrderPreserved(t *te for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) @@ -182,7 +182,7 @@ func TestShuffleWithinPriorityAndLastUsed_DifferentLastUsedAt_OrderPreserved(t * for i := 0; i < 20; i++ { cpy := make([]*Account, len(accounts)) copy(cpy, accounts) - shuffleWithinPriorityAndLastUsed(cpy) + shuffleWithinPriorityAndLastUsed(cpy, false) require.Equal(t, int64(1), cpy[0].ID) require.Equal(t, int64(2), cpy[1].ID) require.Equal(t, int64(3), cpy[2].ID) diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 52d455b8..4d95743c 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -4,12 +4,13 @@ import ( "context" "encoding/json" "errors" - "log" + "log/slog" "strconv" "sync" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) var ( @@ -103,7 +104,7 @@ func (s 
*SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { cached, hit, err := s.cache.GetSnapshot(ctx, bucket) if err != nil { - log.Printf("[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache read failed: bucket=%s err=%v", bucket.String(), err) } else if hit { return derefAccounts(cached), useMixed, nil } @@ -123,7 +124,7 @@ func (s *SchedulerSnapshotService) ListSchedulableAccounts(ctx context.Context, if s.cache != nil { if err := s.cache.SetSnapshot(fallbackCtx, bucket, accounts); err != nil { - log.Printf("[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] cache write failed: bucket=%s err=%v", bucket.String(), err) } } @@ -137,7 +138,7 @@ func (s *SchedulerSnapshotService) GetAccount(ctx context.Context, accountID int if s.cache != nil { account, err := s.cache.GetAccount(ctx, accountID) if err != nil { - log.Printf("[Scheduler] account cache read failed: id=%d err=%v", accountID, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] account cache read failed: id=%d err=%v", accountID, err) } else if account != nil { return account, nil } @@ -167,17 +168,17 @@ func (s *SchedulerSnapshotService) runInitialRebuild() { defer cancel() buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return } } if err := s.rebuildBuckets(ctx, buckets, "startup"); err != nil { - log.Printf("[Scheduler] rebuild startup failed: %v", err) + 
logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild startup failed: %v", err) } } @@ -204,7 +205,7 @@ func (s *SchedulerSnapshotService) runFullRebuildWorker(interval time.Duration) select { case <-ticker.C: if err := s.triggerFullRebuild("interval"); err != nil { - log.Printf("[Scheduler] full rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] full rebuild failed: %v", err) } case <-s.stopCh: return @@ -221,13 +222,13 @@ func (s *SchedulerSnapshotService) pollOutbox() { watermark, err := s.cache.GetOutboxWatermark(ctx) if err != nil { - log.Printf("[Scheduler] outbox watermark read failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark read failed: %v", err) return } events, err := s.outboxRepo.ListAfter(ctx, watermark, 200) if err != nil { - log.Printf("[Scheduler] outbox poll failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox poll failed: %v", err) return } if len(events) == 0 { @@ -240,14 +241,14 @@ func (s *SchedulerSnapshotService) pollOutbox() { err := s.handleOutboxEvent(eventCtx, event) cancel() if err != nil { - log.Printf("[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox handle failed: id=%d type=%s err=%v", event.ID, event.EventType, err) return } } lastID := events[len(events)-1].ID if err := s.cache.SetOutboxWatermark(ctx, lastID); err != nil { - log.Printf("[Scheduler] outbox watermark write failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox watermark write failed: %v", err) } else { watermarkForCheck = lastID } @@ -444,14 +445,14 @@ func (s *SchedulerSnapshotService) rebuildBucket(ctx context.Context, bucket Sch accounts, err := s.loadAccountsFromDB(rebuildCtx, bucket, bucket.Mode == SchedulerModeMixed) if err != nil { - log.Printf("[Scheduler] rebuild failed: 
bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } if err := s.cache.SetSnapshot(rebuildCtx, bucket, accounts); err != nil { - log.Printf("[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] rebuild cache failed: bucket=%s reason=%s err=%v", bucket.String(), reason, err) return err } - log.Printf("[Scheduler] rebuild ok: bucket=%s reason=%s size=%d", bucket.String(), reason, len(accounts)) + slog.Debug("[Scheduler] rebuild ok", "bucket", bucket.String(), "reason", reason, "size", len(accounts)) return nil } @@ -464,13 +465,13 @@ func (s *SchedulerSnapshotService) triggerFullRebuild(reason string) error { buckets, err := s.cache.ListBuckets(ctx) if err != nil { - log.Printf("[Scheduler] list buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] list buckets failed: %v", err) return err } if len(buckets) == 0 { buckets, err = s.defaultBuckets(ctx) if err != nil { - log.Printf("[Scheduler] default buckets failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] default buckets failed: %v", err) return err } } @@ -484,7 +485,7 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc lag := time.Since(oldest.CreatedAt) if lagSeconds := int(lag.Seconds()); lagSeconds >= s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds && s.cfg.Gateway.Scheduling.OutboxLagWarnSeconds > 0 { - log.Printf("[Scheduler] outbox lag warning: %ds", lagSeconds) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag warning: %ds", lagSeconds) } if s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && int(lag.Seconds()) >= s.cfg.Gateway.Scheduling.OutboxLagRebuildSeconds { @@ -494,12 +495,12 @@ func (s *SchedulerSnapshotService) 
checkOutboxLag(ctx context.Context, oldest Sc s.lagMu.Unlock() if failures >= s.cfg.Gateway.Scheduling.OutboxLagRebuildFailures { - log.Printf("[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild triggered: lag=%s failures=%d", lag, failures) s.lagMu.Lock() s.lagFailures = 0 s.lagMu.Unlock() if err := s.triggerFullRebuild("outbox_lag"); err != nil { - log.Printf("[Scheduler] outbox lag rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox lag rebuild failed: %v", err) } } } else { @@ -517,9 +518,9 @@ func (s *SchedulerSnapshotService) checkOutboxLag(ctx context.Context, oldest Sc return } if maxID-watermark >= int64(threshold) { - log.Printf("[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild triggered: backlog=%d", maxID-watermark) if err := s.triggerFullRebuild("outbox_backlog"); err != nil { - log.Printf("[Scheduler] outbox backlog rebuild failed: %v", err) + logger.LegacyPrintf("service.scheduler_snapshot", "[Scheduler] outbox backlog rebuild failed: %v", err) } } } diff --git a/backend/internal/service/sora_account_service.go b/backend/internal/service/sora_account_service.go new file mode 100644 index 00000000..eccc1acf --- /dev/null +++ b/backend/internal/service/sora_account_service.go @@ -0,0 +1,40 @@ +package service + +import "context" + +// SoraAccountRepository Sora 账号扩展表仓储接口 +// 用于管理 sora_accounts 表,与 accounts 主表形成双表结构。 +// +// 设计说明: +// - sora_accounts 表存储 Sora 账号的 OAuth 凭证副本 +// - Sora gateway 优先读取此表的字段以获得更好的查询性能 +// - 主表 accounts 通过 credentials JSON 字段也存储相同信息 +// - Token 刷新时需要同时更新两个表以保持数据一致性 +type SoraAccountRepository interface { + // Upsert 创建或更新 Sora 账号扩展信息 + // accountID: 关联的 accounts.id + // updates: 要更新的字段,支持 access_token、refresh_token、session_token + // + // 如果记录不存在则创建,存在则更新。 + 
// 用于: + // 1. 创建 Sora 账号时初始化扩展表 + // 2. Token 刷新时同步更新扩展表 + Upsert(ctx context.Context, accountID int64, updates map[string]any) error + + // GetByAccountID 根据账号 ID 获取 Sora 扩展信息 + // 返回 nil, nil 表示记录不存在(非错误) + GetByAccountID(ctx context.Context, accountID int64) (*SoraAccount, error) + + // Delete 删除 Sora 账号扩展信息 + // 通常由外键 ON DELETE CASCADE 自动处理,此方法用于手动清理 + Delete(ctx context.Context, accountID int64) error +} + +// SoraAccount Sora 账号扩展信息 +// 对应 sora_accounts 表,存储 Sora 账号的 OAuth 凭证副本 +type SoraAccount struct { + AccountID int64 // 关联的 accounts.id + AccessToken string // OAuth access_token + RefreshToken string // OAuth refresh_token + SessionToken string // Session token(可选,用于 ST→AT 兜底) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go new file mode 100644 index 00000000..7cecfa03 --- /dev/null +++ b/backend/internal/service/sora_client.go @@ -0,0 +1,2123 @@ +package service + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "hash/fnv" + "io" + "log" + "math/rand" + "mime" + "mime/multipart" + "net/http" + "net/textproto" + "net/url" + "path" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/logredact" + "github.com/Wei-Shaw/sub2api/internal/util/soraerror" + "github.com/google/uuid" + "github.com/tidwall/gjson" + "golang.org/x/crypto/sha3" +) + +const ( + soraChatGPTBaseURL = "https://chatgpt.com" + soraSentinelFlow = "sora_2_create_task" + soraDefaultUserAgent = "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)" +) + +var ( + soraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" + soraOAuthTokenURL = "https://auth.openai.com/oauth/token" +) + +const ( + soraPowMaxIteration = 500000 +) + +var soraPowCores = []int{8, 16, 24, 32} + +var soraPowScripts = []string{ + 
"https://cdn.oaistatic.com/_next/static/cXh69klOLzS0Gy2joLDRS/_ssgManifest.js?dpl=453ebaec0d44c2decab71692e1bfe39be35a24b3", +} + +var soraPowDPL = []string{ + "prod-f501fe933b3edf57aea882da888e1a544df99840", +} + +var soraPowNavigatorKeys = []string{ + "registerProtocolHandler−function registerProtocolHandler() { [native code] }", + "storage−[object StorageManager]", + "locks−[object LockManager]", + "appCodeName−Mozilla", + "permissions−[object Permissions]", + "webdriver−false", + "vendor−Google Inc.", + "mediaDevices−[object MediaDevices]", + "cookieEnabled−true", + "product−Gecko", + "productSub−20030107", + "hardwareConcurrency−32", + "onLine−true", +} + +var soraPowDocumentKeys = []string{ + "_reactListeningo743lnnpvdg", + "location", +} + +var soraPowWindowKeys = []string{ + "0", "window", "self", "document", "name", "location", + "navigator", "screen", "innerWidth", "innerHeight", + "localStorage", "sessionStorage", "crypto", "performance", + "fetch", "setTimeout", "setInterval", "console", +} + +var soraDesktopUserAgents = []string{ + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36", + "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", + "Mozilla/5.0 (Windows NT 11.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36", +} + +var soraMobileUserAgents = []string{ + "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)", + 
"Sora/1.2026.007 (Android 14; SM-G998B; build 2600700)", + "Sora/1.2026.007 (Android 15; Pixel 8 Pro; build 2600700)", + "Sora/1.2026.007 (Android 14; Pixel 7; build 2600700)", + "Sora/1.2026.007 (Android 15; 2211133C; build 2600700)", + "Sora/1.2026.007 (Android 14; SM-S918B; build 2600700)", + "Sora/1.2026.007 (Android 15; OnePlus 12; build 2600700)", +} + +var soraRand = rand.New(rand.NewSource(time.Now().UnixNano())) +var soraRandMu sync.Mutex +var soraPerfStart = time.Now() +var soraPowTokenGenerator = soraGetPowToken + +// SoraClient 定义直连 Sora 的任务操作接口。 +type SoraClient interface { + Enabled() bool + UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) + CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) + CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) + CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) + UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) + GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) + DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) + UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) + FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) + SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error + DeleteCharacter(ctx context.Context, account *Account, characterID string) error + PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) + DeletePost(ctx context.Context, account *Account, postID string) error + GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) + EnhancePrompt(ctx context.Context, account 
*Account, prompt, expansionLevel string, durationS int) (string, error) + GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) + GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) +} + +// SoraImageRequest 图片生成请求参数 +type SoraImageRequest struct { + Prompt string + Width int + Height int + MediaID string +} + +// SoraVideoRequest 视频生成请求参数 +type SoraVideoRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string + RemixTargetID string + CameoIDs []string +} + +// SoraStoryboardRequest 分镜视频生成请求参数 +type SoraStoryboardRequest struct { + Prompt string + Orientation string + Frames int + Model string + Size string + MediaID string +} + +// SoraImageTaskStatus 图片任务状态 +type SoraImageTaskStatus struct { + ID string + Status string + ProgressPct float64 + URLs []string + ErrorMsg string +} + +// SoraVideoTaskStatus 视频任务状态 +type SoraVideoTaskStatus struct { + ID string + Status string + ProgressPct int + URLs []string + GenerationID string + ErrorMsg string +} + +// SoraCameoStatus 角色处理中间态 +type SoraCameoStatus struct { + Status string + StatusMessage string + DisplayNameHint string + UsernameHint string + ProfileAssetURL string + InstructionSetHint any + InstructionSet any +} + +// SoraCharacterFinalizeRequest 角色定稿请求参数 +type SoraCharacterFinalizeRequest struct { + CameoID string + Username string + DisplayName string + ProfileAssetPointer string + InstructionSet any +} + +// SoraUpstreamError 上游错误 +type SoraUpstreamError struct { + StatusCode int + Message string + Headers http.Header + Body []byte +} + +func (e *SoraUpstreamError) Error() string { + if e == nil { + return "sora upstream error" + } + if e.Message != "" { + return fmt.Sprintf("sora upstream error: %d %s", e.StatusCode, e.Message) + } + return fmt.Sprintf("sora upstream error: %d", e.StatusCode) +} + +// SoraDirectClient 直连 Sora 实现 +type SoraDirectClient struct { + cfg 
*config.Config + httpUpstream HTTPUpstream + tokenProvider *OpenAITokenProvider + accountRepo AccountRepository + soraAccountRepo SoraAccountRepository + baseURL string + challengeCooldownMu sync.RWMutex + challengeCooldowns map[string]soraChallengeCooldownEntry + sidecarSessionMu sync.RWMutex + sidecarSessions map[string]soraSidecarSessionEntry +} + +type soraRequestTraceContextKey struct{} + +type soraRequestTrace struct { + ID string + ProxyKey string + UAHash string +} + +// NewSoraDirectClient 创建 Sora 直连客户端 +func NewSoraDirectClient(cfg *config.Config, httpUpstream HTTPUpstream, tokenProvider *OpenAITokenProvider) *SoraDirectClient { + baseURL := "" + if cfg != nil { + rawBaseURL := strings.TrimRight(strings.TrimSpace(cfg.Sora.Client.BaseURL), "/") + baseURL = normalizeSoraBaseURL(rawBaseURL) + if rawBaseURL != "" && baseURL != rawBaseURL { + log.Printf("[SoraClient] normalized base_url from %s to %s", sanitizeSoraLogURL(rawBaseURL), sanitizeSoraLogURL(baseURL)) + } + } + return &SoraDirectClient{ + cfg: cfg, + httpUpstream: httpUpstream, + tokenProvider: tokenProvider, + baseURL: baseURL, + challengeCooldowns: make(map[string]soraChallengeCooldownEntry), + sidecarSessions: make(map[string]soraSidecarSessionEntry), + } +} + +func (c *SoraDirectClient) SetAccountRepositories(accountRepo AccountRepository, soraAccountRepo SoraAccountRepository) { + if c == nil { + return + } + c.accountRepo = accountRepo + c.soraAccountRepo = soraAccountRepo +} + +// Enabled 判断是否启用 Sora 直连 +func (c *SoraDirectClient) Enabled() bool { + if c == nil { + return false + } + if strings.TrimSpace(c.baseURL) != "" { + return true + } + if c.cfg == nil { + return false + } + return strings.TrimSpace(normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL)) != "" +} + +// PreflightCheck 在创建任务前执行账号能力预检。 +// 当前仅对视频模型执行 /nf/check 预检,用于提前识别额度耗尽或能力缺失。 +func (c *SoraDirectClient) PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error { + if 
modelCfg.Type != "video" { + return nil + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Accept", "application/json") + body, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/check"), headers, nil, false) + if err != nil { + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) && upstreamErr.StatusCode == http.StatusNotFound { + return &SoraUpstreamError{ + StatusCode: http.StatusForbidden, + Message: "当前账号未开通 Sora2 能力或无可用配额", + Headers: upstreamErr.Headers, + Body: upstreamErr.Body, + } + } + return err + } + + rateLimitReached := gjson.GetBytes(body, "rate_limit_and_credit_balance.rate_limit_reached").Bool() + remaining := gjson.GetBytes(body, "rate_limit_and_credit_balance.estimated_num_videos_remaining") + if rateLimitReached || (remaining.Exists() && remaining.Int() <= 0) { + msg := "当前账号 Sora2 可用配额不足" + if requestedModel != "" { + msg = fmt.Sprintf("当前账号 %s 可用配额不足", requestedModel) + } + return &SoraUpstreamError{ + StatusCode: http.StatusTooManyRequests, + Message: msg, + Headers: http.Header{}, + } + } + return nil +} + +func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, data []byte, filename string) (string, error) { + if len(data) == 0 { + return "", errors.New("empty image data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + if filename == "" { + filename = "image.png" + } + var body bytes.Buffer + writer := multipart.NewWriter(&body) + contentType := mime.TypeByExtension(path.Ext(filename)) + if contentType == "" { + contentType = "application/octet-stream" + } + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", fmt.Sprintf(`form-data; name="file"; 
filename="%s"`, filename)) + partHeader.Set("Content-Type", contentType) + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("file_name", filename); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/uploads"), headers, &body, false) + if err != nil { + return "", err + } + id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if id == "" { + return "", errors.New("upload response missing id") + } + return id, nil +} + +func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account, req SoraImageRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + operation := "simple_compose" + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + operation = "remix" + inpaintItems = append(inpaintItems, map[string]any{ + "type": "image", + "frame_index": 0, + "upload_media_id": req.MediaID, + }) + } + payload := map[string]any{ + "type": "image_gen", + "operation": operation, + "prompt": req.Prompt, + "width": req.Width, + "height": req.Height, + "n_variants": 1, + "n_frames": 1, + "inpaint_items": inpaintItems, + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, 
token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/video_gen"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if taskID == "" { + return "", errors.New("image task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account, req SoraVideoRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "model": model, + "inpaint_items": inpaintItems, + } + if strings.TrimSpace(req.RemixTargetID) != "" { + payload["remix_target_id"] = req.RemixTargetID + payload["cameo_ids"] = []string{} + payload["cameo_replacements"] = map[string]any{} + } else if len(req.CameoIDs) > 0 { + payload["cameo_ids"] = req.CameoIDs + payload["cameo_replacements"] = map[string]any{} + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", 
"https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if taskID == "" { + return "", errors.New("video task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + orientation := req.Orientation + if orientation == "" { + orientation = "landscape" + } + nFrames := req.Frames + if nFrames <= 0 { + nFrames = 450 + } + model := req.Model + if model == "" { + model = "sy_8" + } + size := req.Size + if size == "" { + size = "small" + } + + inpaintItems := []map[string]any{} + if strings.TrimSpace(req.MediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": req.MediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": req.Prompt, + "title": "Draft your video", + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "storyboard_id": nil, + "inpaint_items": inpaintItems, + "remix_target_id": nil, + "model": model, + "metadata": nil, + "style_id": nil, + "cameo_ids": nil, + "cameo_replacements": nil, + "audio_caption": nil, + "audio_transcript": nil, + "video_caption": nil, + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", 
"application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/nf/create/storyboard"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if taskID == "" { + return "", errors.New("storyboard task response missing id") + } + return taskID, nil +} + +func (c *SoraDirectClient) UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty video data") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="video.mp4"`) + partHeader.Set("Content-Type", "video/mp4") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("timestamps", "0,3"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/upload"), headers, &body, false) + if err != nil { + return "", err + } + cameoID := strings.TrimSpace(gjson.GetBytes(respBody, 
"id").String()) + if cameoID == "" { + return "", errors.New("character upload response missing id") + } + return cameoID, nil +} + +func (c *SoraDirectClient) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + c.buildURL("/project_y/cameos/in_progress/"+strings.TrimSpace(cameoID)), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return &SoraCameoStatus{ + Status: strings.TrimSpace(gjson.GetBytes(respBody, "status").String()), + StatusMessage: strings.TrimSpace(gjson.GetBytes(respBody, "status_message").String()), + DisplayNameHint: strings.TrimSpace(gjson.GetBytes(respBody, "display_name_hint").String()), + UsernameHint: strings.TrimSpace(gjson.GetBytes(respBody, "username_hint").String()), + ProfileAssetURL: strings.TrimSpace(gjson.GetBytes(respBody, "profile_asset_url").String()), + InstructionSetHint: gjson.GetBytes(respBody, "instruction_set_hint").Value(), + InstructionSet: gjson.GetBytes(respBody, "instruction_set").Value(), + }, nil +} + +func (c *SoraDirectClient) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Accept", "image/*,*/*;q=0.8") + + respBody, _, err := c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodGet, + strings.TrimSpace(imageURL), + headers, + nil, + false, + ) + if err != nil { + return nil, err + } + return respBody, nil +} + +func (c *SoraDirectClient) UploadCharacterImage(ctx 
context.Context, account *Account, data []byte) (string, error) { + if len(data) == 0 { + return "", errors.New("empty character image") + } + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + partHeader := make(textproto.MIMEHeader) + partHeader.Set("Content-Disposition", `form-data; name="file"; filename="profile.webp"`) + partHeader.Set("Content-Type", "image/webp") + part, err := writer.CreatePart(partHeader) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := writer.WriteField("use_case", "profile"); err != nil { + return "", err + } + if err := writer.Close(); err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", writer.FormDataContentType()) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/file/upload"), headers, &body, false) + if err != nil { + return "", err + } + assetPointer := strings.TrimSpace(gjson.GetBytes(respBody, "asset_pointer").String()) + if assetPointer == "" { + return "", errors.New("character image upload response missing asset_pointer") + } + return assetPointer, nil +} + +func (c *SoraDirectClient) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + payload := map[string]any{ + "cameo_id": req.CameoID, + "username": req.Username, + "display_name": req.DisplayName, + "profile_asset_pointer": req.ProfileAssetPointer, + "instruction_set": nil, + "safety_instruction_set": nil, + } + body, err := 
json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/characters/finalize"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + characterID := strings.TrimSpace(gjson.GetBytes(respBody, "character.character_id").String()) + if characterID == "" { + return "", errors.New("character finalize response missing character_id") + } + return characterID, nil +} + +func (c *SoraDirectClient) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + payload := map[string]any{"visibility": "public"} + body, err := json.Marshal(payload) + if err != nil { + return err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodPost, + c.buildURL("/project_y/cameos/by_id/"+strings.TrimSpace(cameoID)+"/update_v2"), + headers, + bytes.NewReader(body), + false, + ) + return err +} + +func (c *SoraDirectClient) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + 
c.buildURL("/project_y/characters/"+strings.TrimSpace(characterID)), + headers, + nil, + false, + ) + return err +} + +func (c *SoraDirectClient) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + ctx = c.withRequestTrace(ctx, account, proxyURL, userAgent) + payload := map[string]any{ + "attachments_to_create": []map[string]any{ + { + "generation_id": generationID, + "kind": "sora", + }, + }, + "post_text": "", + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + sentinel, err := c.generateSentinelToken(ctx, account, token, userAgent, proxyURL) + if err != nil { + return "", err + } + headers.Set("openai-sentinel-token", sentinel) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/project_y/post"), headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + postID := strings.TrimSpace(gjson.GetBytes(respBody, "post.id").String()) + if postID == "" { + return "", errors.New("watermark-free publish response missing post.id") + } + return postID, nil +} + +func (c *SoraDirectClient) DeletePost(ctx context.Context, account *Account, postID string) error { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + _, _, err = c.doRequestWithProxy( + ctx, + account, + proxyURL, + http.MethodDelete, + c.buildURL("/project_y/post/"+strings.TrimSpace(postID)), + headers, + nil, + false, + ) + return err +} + +func (c 
*SoraDirectClient) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + parseURL = strings.TrimRight(strings.TrimSpace(parseURL), "/") + if parseURL == "" { + return "", errors.New("custom parse url is required") + } + if strings.TrimSpace(parseToken) == "" { + return "", errors.New("custom parse token is required") + } + shareURL := "https://sora.chatgpt.com/p/" + strings.TrimSpace(postID) + payload := map[string]any{ + "url": shareURL, + "token": strings.TrimSpace(parseToken), + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodPost, parseURL+"/get-sora-link", bytes.NewReader(body)) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/json") + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return "", err + } + defer func() { _ = resp.Body.Close() }() + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("custom parse failed: %d %s", resp.StatusCode, truncateForLog(raw, 256)) + } + downloadLink := strings.TrimSpace(gjson.GetBytes(raw, "download_link").String()) + if downloadLink == "" { + return "", errors.New("custom parse response missing download_link") + } + return downloadLink, nil +} + +func (c *SoraDirectClient) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return "", err + } + userAgent := 
c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + if strings.TrimSpace(expansionLevel) == "" { + expansionLevel = "medium" + } + if durationS <= 0 { + durationS = 10 + } + + payload := map[string]any{ + "prompt": prompt, + "expansion_level": expansionLevel, + "duration_s": durationS, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + + headers := c.buildBaseHeaders(token, userAgent) + headers.Set("Content-Type", "application/json") + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, c.buildURL("/editor/enhance_prompt"), headers, bytes.NewReader(body), false) + if err != nil { + return "", err + } + enhancedPrompt := strings.TrimSpace(gjson.GetBytes(respBody, "enhanced_prompt").String()) + if enhancedPrompt == "" { + return "", errors.New("enhance_prompt response missing enhanced_prompt") + } + return enhancedPrompt, nil +} + +func (c *SoraDirectClient) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + status, found, err := c.fetchRecentImageTask(ctx, account, taskID, c.recentTaskLimit()) + if err != nil { + return nil, err + } + if found { + return status, nil + } + maxLimit := c.recentTaskLimitMax() + if maxLimit > 0 && maxLimit != c.recentTaskLimit() { + status, found, err = c.fetchRecentImageTask(ctx, account, taskID, maxLimit) + if err != nil { + return nil, err + } + if found { + return status, nil + } + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Account, taskID string, limit int) (*SoraImageTaskStatus, bool, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, false, err + } + userAgent := c.taskUserAgent() + proxyURL := 
c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + if limit <= 0 { + limit = 20 + } + endpoint := fmt.Sprintf("/v2/recent_tasks?limit=%d", limit) + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL(endpoint), headers, nil, false) + if err != nil { + return nil, false, err + } + var found *SoraImageTaskStatus + gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool { + if item.Get("id").String() != taskID { + return true // continue + } + status := strings.TrimSpace(item.Get("status").String()) + progress := item.Get("progress_pct").Float() + var urls []string + item.Get("generations").ForEach(func(_, gen gjson.Result) bool { + if u := strings.TrimSpace(gen.Get("url").String()); u != "" { + urls = append(urls, u) + } + return true + }) + found = &SoraImageTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + URLs: urls, + } + return false // break + }) + if found != nil { + return found, true, nil + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil +} + +func (c *SoraDirectClient) recentTaskLimit() int { + if c == nil || c.cfg == nil { + return 20 + } + if c.cfg.Sora.Client.RecentTaskLimit > 0 { + return c.cfg.Sora.Client.RecentTaskLimit + } + return 20 +} + +func (c *SoraDirectClient) recentTaskLimitMax() int { + if c == nil || c.cfg == nil { + return 0 + } + if c.cfg.Sora.Client.RecentTaskLimitMax > 0 { + return c.cfg.Sora.Client.RecentTaskLimitMax + } + return 0 +} + +func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + token, err := c.getAccessToken(ctx, account) + if err != nil { + return nil, err + } + userAgent := c.taskUserAgent() + proxyURL := c.resolveProxyURL(account) + headers := c.buildBaseHeaders(token, userAgent) + + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/nf/pending/v2"), headers, 
nil, false) + if err != nil { + return nil, err + } + // 搜索 pending 列表(JSON 数组) + pendingResult := gjson.ParseBytes(respBody) + if pendingResult.IsArray() { + var pendingFound *SoraVideoTaskStatus + pendingResult.ForEach(func(_, task gjson.Result) bool { + if task.Get("id").String() != taskID { + return true + } + progress := 0 + if v := task.Get("progress_pct"); v.Exists() { + progress = int(v.Float() * 100) + } + status := strings.TrimSpace(task.Get("status").String()) + pendingFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + } + return false + }) + if pendingFound != nil { + return pendingFound, nil + } + } + + respBody, _, err = c.doRequestWithProxy(ctx, account, proxyURL, http.MethodGet, c.buildURL("/project_y/profile/drafts?limit=15"), headers, nil, false) + if err != nil { + return nil, err + } + var draftFound *SoraVideoTaskStatus + gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool { + if draft.Get("task_id").String() != taskID { + return true + } + generationID := strings.TrimSpace(draft.Get("id").String()) + kind := strings.TrimSpace(draft.Get("kind").String()) + reason := strings.TrimSpace(draft.Get("reason_str").String()) + if reason == "" { + reason = strings.TrimSpace(draft.Get("markdown_reason_str").String()) + } + urlStr := strings.TrimSpace(draft.Get("downloadable_url").String()) + if urlStr == "" { + urlStr = strings.TrimSpace(draft.Get("url").String()) + } + + if kind == "sora_content_violation" || reason != "" || urlStr == "" { + msg := reason + if msg == "" { + msg = "Content violates guardrails" + } + draftFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + GenerationID: generationID, + ErrorMsg: msg, + } + } else { + draftFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: "completed", + GenerationID: generationID, + URLs: []string{urlStr}, + } + } + return false + }) + if draftFound != nil { + return draftFound, nil + } + + return &SoraVideoTaskStatus{ID: 
taskID, Status: "processing"}, nil +} + +func (c *SoraDirectClient) buildURL(endpoint string) string { + base := strings.TrimRight(strings.TrimSpace(c.baseURL), "/") + if base == "" && c != nil && c.cfg != nil { + base = normalizeSoraBaseURL(c.cfg.Sora.Client.BaseURL) + c.baseURL = base + } + if base == "" { + return endpoint + } + if strings.HasPrefix(endpoint, "/") { + return base + endpoint + } + return base + "/" + endpoint +} + +func (c *SoraDirectClient) defaultUserAgent() string { + if c == nil || c.cfg == nil { + return soraDefaultUserAgent + } + ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent) + if ua == "" { + return soraDefaultUserAgent + } + return ua +} + +func (c *SoraDirectClient) taskUserAgent() string { + if c != nil && c.cfg != nil { + if ua := strings.TrimSpace(c.cfg.Sora.Client.UserAgent); ua != "" { + return ua + } + } + if len(soraMobileUserAgents) > 0 { + return soraMobileUserAgents[soraRandInt(len(soraMobileUserAgents))] + } + if len(soraDesktopUserAgents) > 0 { + return soraDesktopUserAgents[soraRandInt(len(soraDesktopUserAgents))] + } + return soraDefaultUserAgent +} + +func (c *SoraDirectClient) resolveProxyURL(account *Account) string { + if account == nil || account.ProxyID == nil || account.Proxy == nil { + return "" + } + return strings.TrimSpace(account.Proxy.URL()) +} + +func (c *SoraDirectClient) getAccessToken(ctx context.Context, account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + allowProvider := c.allowOpenAITokenProvider(account) + var providerErr error + if allowProvider && c.tokenProvider != nil { + token, err := c.tokenProvider.GetAccessToken(ctx, account) + if err == nil && strings.TrimSpace(token) != "" { + c.logTokenSource(account, "openai_token_provider") + return token, nil + } + providerErr = err + if err != nil && c.debugEnabled() { + c.debugLogf( + "token_provider_failed account_id=%d platform=%s err=%s", + account.ID, + account.Platform, + 
logredact.RedactText(err.Error()), + ) + } + } + token := strings.TrimSpace(account.GetCredential("access_token")) + if token != "" { + expiresAt := account.GetCredentialAsTime("expires_at") + if expiresAt != nil && time.Until(*expiresAt) <= 2*time.Minute { + refreshed, refreshErr := c.recoverAccessToken(ctx, account, "access_token_expiring") + if refreshErr == nil && strings.TrimSpace(refreshed) != "" { + c.logTokenSource(account, "refresh_token_recovered") + return refreshed, nil + } + if refreshErr != nil && c.debugEnabled() { + c.debugLogf("token_refresh_before_use_failed account_id=%d err=%s", account.ID, logredact.RedactText(refreshErr.Error())) + } + } + c.logTokenSource(account, "account_credentials") + return token, nil + } + + recovered, recoverErr := c.recoverAccessToken(ctx, account, "access_token_missing") + if recoverErr == nil && strings.TrimSpace(recovered) != "" { + c.logTokenSource(account, "session_or_refresh_recovered") + return recovered, nil + } + if recoverErr != nil && c.debugEnabled() { + c.debugLogf("token_recover_failed account_id=%d platform=%s err=%s", account.ID, account.Platform, logredact.RedactText(recoverErr.Error())) + } + if providerErr != nil { + return "", providerErr + } + if c.tokenProvider != nil && !allowProvider { + c.logTokenSource(account, "account_credentials(provider_disabled)") + } + return "", errors.New("access_token not found") +} + +func (c *SoraDirectClient) recoverAccessToken(ctx context.Context, account *Account, reason string) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + + if sessionToken := strings.TrimSpace(account.GetCredential("session_token")); sessionToken != "" { + accessToken, expiresAt, err := c.exchangeSessionToken(ctx, account, sessionToken) + if err == nil && strings.TrimSpace(accessToken) != "" { + c.applyRecoveredToken(ctx, account, accessToken, "", expiresAt, sessionToken) + c.logTokenRecover(account, "session_token", reason, true, nil) + return 
accessToken, nil + } + c.logTokenRecover(account, "session_token", reason, false, err) + } + + refreshToken := strings.TrimSpace(account.GetCredential("refresh_token")) + if refreshToken == "" { + return "", errors.New("session_token/refresh_token not found") + } + accessToken, newRefreshToken, expiresAt, err := c.exchangeRefreshToken(ctx, account, refreshToken) + if err != nil { + c.logTokenRecover(account, "refresh_token", reason, false, err) + return "", err + } + if strings.TrimSpace(accessToken) == "" { + return "", errors.New("refreshed access_token is empty") + } + c.applyRecoveredToken(ctx, account, accessToken, newRefreshToken, expiresAt, "") + c.logTokenRecover(account, "refresh_token", reason, true, nil) + return accessToken, nil +} + +func (c *SoraDirectClient) exchangeSessionToken(ctx context.Context, account *Account, sessionToken string) (string, string, error) { + headers := http.Header{} + headers.Set("Cookie", "__Secure-next-auth.session-token="+sessionToken) + headers.Set("Accept", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", c.defaultUserAgent()) + body, _, err := c.doRequest(ctx, account, http.MethodGet, soraSessionAuthURL, headers, nil, false) + if err != nil { + return "", "", err + } + accessToken := strings.TrimSpace(gjson.GetBytes(body, "accessToken").String()) + if accessToken == "" { + return "", "", errors.New("session exchange missing accessToken") + } + expiresAt := strings.TrimSpace(gjson.GetBytes(body, "expires").String()) + return accessToken, expiresAt, nil +} + +func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Account, refreshToken string) (string, string, string, error) { + clientIDs := []string{ + strings.TrimSpace(account.GetCredential("client_id")), + openaioauth.SoraClientID, + openaioauth.ClientID, + } + tried := make(map[string]struct{}, len(clientIDs)) + var lastErr error + + for _, 
clientID := range clientIDs { + if clientID == "" { + continue + } + if _, ok := tried[clientID]; ok { + continue + } + tried[clientID] = struct{}{} + + formData := url.Values{} + formData.Set("client_id", clientID) + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") + headers := http.Header{} + headers.Set("Accept", "application/json") + headers.Set("Content-Type", "application/x-www-form-urlencoded") + headers.Set("User-Agent", c.defaultUserAgent()) + + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) + if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf("refresh_token_exchange_failed account_id=%d client_id=%s err=%s", account.ID, clientID, logredact.RedactText(err.Error())) + } + continue + } + accessToken := strings.TrimSpace(gjson.GetBytes(respBody, "access_token").String()) + if accessToken == "" { + lastErr = errors.New("oauth refresh response missing access_token") + continue + } + newRefreshToken := strings.TrimSpace(gjson.GetBytes(respBody, "refresh_token").String()) + expiresIn := gjson.GetBytes(respBody, "expires_in").Int() + expiresAt := "" + if expiresIn > 0 { + expiresAt = time.Now().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + } + return accessToken, newRefreshToken, expiresAt, nil + } + + if lastErr != nil { + return "", "", "", lastErr + } + return "", "", "", errors.New("no available client_id for refresh_token exchange") +} + +func (c *SoraDirectClient) applyRecoveredToken(ctx context.Context, account *Account, accessToken, refreshToken, expiresAt, sessionToken string) { + if account == nil { + return + } + if account.Credentials == nil { + account.Credentials = make(map[string]any) + } + if strings.TrimSpace(accessToken) != "" { + account.Credentials["access_token"] = accessToken + } 
+ if strings.TrimSpace(refreshToken) != "" { + account.Credentials["refresh_token"] = refreshToken + } + if strings.TrimSpace(expiresAt) != "" { + account.Credentials["expires_at"] = expiresAt + } + if strings.TrimSpace(sessionToken) != "" { + account.Credentials["session_token"] = sessionToken + } + + if c.accountRepo != nil { + if err := c.accountRepo.Update(ctx, account); err != nil { + if c.debugEnabled() { + c.debugLogf("persist_recovered_token_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } + } + } + c.updateSoraAccountExtension(ctx, account, accessToken, refreshToken, sessionToken) +} + +func (c *SoraDirectClient) updateSoraAccountExtension(ctx context.Context, account *Account, accessToken, refreshToken, sessionToken string) { + if c == nil || c.soraAccountRepo == nil || account == nil || account.ID <= 0 { + return + } + updates := make(map[string]any) + if strings.TrimSpace(accessToken) != "" && strings.TrimSpace(refreshToken) != "" { + updates["access_token"] = accessToken + updates["refresh_token"] = refreshToken + } + if strings.TrimSpace(sessionToken) != "" { + updates["session_token"] = sessionToken + } + if len(updates) == 0 { + return + } + if err := c.soraAccountRepo.Upsert(ctx, account.ID, updates); err != nil && c.debugEnabled() { + c.debugLogf("persist_sora_extension_failed account_id=%d err=%s", account.ID, logredact.RedactText(err.Error())) + } +} + +func (c *SoraDirectClient) logTokenRecover(account *Account, source, reason string, success bool, err error) { + if !c.debugEnabled() || account == nil { + return + } + if success { + c.debugLogf("token_recover_success account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + if err == nil { + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s", account.ID, account.Platform, source, reason) + return + } + c.debugLogf("token_recover_failed account_id=%d platform=%s source=%s reason=%s 
err=%s", account.ID, account.Platform, source, reason, logredact.RedactText(err.Error())) +} + +func (c *SoraDirectClient) allowOpenAITokenProvider(account *Account) bool { + if c == nil || c.tokenProvider == nil { + return false + } + if account != nil && account.Platform == PlatformSora { + return c.cfg != nil && c.cfg.Sora.Client.UseOpenAITokenProvider + } + return true +} + +func (c *SoraDirectClient) logTokenSource(account *Account, source string) { + if !c.debugEnabled() || account == nil { + return + } + c.debugLogf( + "token_selected account_id=%d platform=%s account_type=%s source=%s", + account.ID, + account.Platform, + account.Type, + source, + ) +} + +func (c *SoraDirectClient) buildBaseHeaders(token, userAgent string) http.Header { + headers := http.Header{} + if token != "" { + headers.Set("Authorization", "Bearer "+token) + } + if userAgent != "" { + headers.Set("User-Agent", userAgent) + } + if c != nil && c.cfg != nil { + for key, value := range c.cfg.Sora.Client.Headers { + if strings.EqualFold(key, "authorization") || strings.EqualFold(key, "openai-sentinel-token") { + continue + } + headers.Set(key, value) + } + } + return headers +} + +func (c *SoraDirectClient) doRequest(ctx context.Context, account *Account, method, urlStr string, headers http.Header, body io.Reader, allowRetry bool) ([]byte, http.Header, error) { + return c.doRequestWithProxy(ctx, account, c.resolveProxyURL(account), method, urlStr, headers, body, allowRetry) +} + +func (c *SoraDirectClient) doRequestWithProxy( + ctx context.Context, + account *Account, + proxyURL string, + method, + urlStr string, + headers http.Header, + body io.Reader, + allowRetry bool, +) ([]byte, http.Header, error) { + if strings.TrimSpace(urlStr) == "" { + return nil, nil, errors.New("empty upstream url") + } + proxyURL = strings.TrimSpace(proxyURL) + if proxyURL == "" { + proxyURL = c.resolveProxyURL(account) + } + if cooldownErr := c.checkCloudflareChallengeCooldown(account, proxyURL); cooldownErr 
!= nil { + return nil, nil, cooldownErr + } + traceID, traceProxyKey, traceUAHash := c.requestTraceFields(ctx, proxyURL, headers.Get("User-Agent")) + timeout := 0 + if c != nil && c.cfg != nil { + timeout = c.cfg.Sora.Client.TimeoutSeconds + } + if timeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + } + maxRetries := 0 + if allowRetry && c != nil && c.cfg != nil { + maxRetries = c.cfg.Sora.Client.MaxRetries + } + if maxRetries < 0 { + maxRetries = 0 + } + + var bodyBytes []byte + if body != nil { + b, err := io.ReadAll(body) + if err != nil { + return nil, nil, err + } + bodyBytes = b + } + + attempts := maxRetries + 1 + authRecovered := false + authRecoverExtraAttemptGranted := false + challengeRetried := false + sawCFChallenge := false + var lastErr error + for attempt := 1; attempt <= attempts; attempt++ { + if c.debugEnabled() { + c.debugLogf( + "request_start trace_id=%s method=%s url=%s attempt=%d/%d timeout_s=%d body_bytes=%d proxy_bound=%t proxy_key=%s ua_hash=%s headers=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + timeout, + len(bodyBytes), + proxyURL != "", + traceProxyKey, + traceUAHash, + formatSoraHeaders(headers), + ) + } + + var reader io.Reader + if bodyBytes != nil { + reader = bytes.NewReader(bodyBytes) + } + req, err := http.NewRequestWithContext(ctx, method, urlStr, reader) + if err != nil { + return nil, nil, err + } + req.Header = headers.Clone() + start := time.Now() + + resp, err := c.doHTTP(req, proxyURL, account) + if err != nil { + lastErr = err + if c.debugEnabled() { + c.debugLogf( + "request_transport_error trace_id=%s method=%s url=%s attempt=%d/%d err=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + logredact.RedactText(err.Error()), + ) + } + if attempt < attempts && allowRetry { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s 
reason=transport_error next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), attempt+1, attempts) + } + c.sleepRetry(attempt) + continue + } + return nil, nil, err + } + + respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + if readErr != nil { + return nil, resp.Header, readErr + } + + if c.cfg != nil && c.cfg.Sora.Client.Debug { + c.debugLogf( + "response_received trace_id=%s method=%s url=%s attempt=%d/%d status=%d cost=%s resp_bytes=%d resp_headers=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + time.Since(start), + len(respBody), + formatSoraHeaders(resp.Header), + ) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + isCFChallenge := soraerror.IsCloudflareChallengeResponse(resp.StatusCode, resp.Header, respBody) + if isCFChallenge { + sawCFChallenge = true + c.recordCloudflareChallengeCooldown(account, proxyURL, resp.StatusCode, resp.Header, respBody) + if allowRetry && attempt < attempts && !challengeRetried { + challengeRetried = true + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=cloudflare_challenge status=%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } + c.sleepRetry(attempt) + continue + } + } + if !isCFChallenge && !authRecovered && shouldAttemptSoraTokenRecover(resp.StatusCode, urlStr) && account != nil { + if recovered, recoverErr := c.recoverAccessToken(ctx, account, fmt.Sprintf("upstream_status_%d", resp.StatusCode)); recoverErr == nil && strings.TrimSpace(recovered) != "" { + headers.Set("Authorization", "Bearer "+recovered) + authRecovered = true + if attempt == attempts && !authRecoverExtraAttemptGranted { + attempts++ + authRecoverExtraAttemptGranted = true + } + if c.debugEnabled() { + c.debugLogf("request_retry_with_recovered_token trace_id=%s method=%s url=%s status=%d", traceID, method, 
sanitizeSoraLogURL(urlStr), resp.StatusCode) + } + continue + } else if recoverErr != nil && c.debugEnabled() { + c.debugLogf("request_recover_token_failed trace_id=%s method=%s url=%s status=%d err=%s", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, logredact.RedactText(recoverErr.Error())) + } + } + if c.debugEnabled() { + c.debugLogf( + "response_non_success trace_id=%s method=%s url=%s attempt=%d/%d status=%d body=%s", + traceID, + method, + sanitizeSoraLogURL(urlStr), + attempt, + attempts, + resp.StatusCode, + summarizeSoraResponseBody(respBody, 512), + ) + } + upstreamErr := c.buildUpstreamError(resp.StatusCode, resp.Header, respBody, urlStr) + lastErr = upstreamErr + if isCFChallenge { + return nil, resp.Header, upstreamErr + } + if allowRetry && attempt < attempts && (resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode >= 500) { + if c.debugEnabled() { + c.debugLogf("request_retry_scheduled trace_id=%s method=%s url=%s reason=status_%d next_attempt=%d/%d", traceID, method, sanitizeSoraLogURL(urlStr), resp.StatusCode, attempt+1, attempts) + } + c.sleepRetry(attempt) + continue + } + return nil, resp.Header, upstreamErr + } + if sawCFChallenge { + c.clearCloudflareChallengeCooldown(account, proxyURL) + } + return respBody, resp.Header, nil + } + if lastErr != nil { + return nil, nil, lastErr + } + return nil, nil, errors.New("upstream retries exhausted") +} + +func shouldAttemptSoraTokenRecover(statusCode int, rawURL string) bool { + switch statusCode { + case http.StatusUnauthorized, http.StatusForbidden: + parsed, err := url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return false + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return false + } + // 避免在 ST->AT 转换接口上递归触发 token 恢复导致死循环。 + path := strings.ToLower(strings.TrimSpace(parsed.Path)) + if path == "/api/auth/session" { + return false + } + return true + default: + return false + } +} + +func (c 
*SoraDirectClient) doHTTP(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + if c != nil && c.cfg != nil && c.cfg.Sora.Client.CurlCFFISidecar.Enabled { + resp, err := c.doHTTPViaCurlCFFISidecar(req, proxyURL, account) + if err != nil { + return nil, err + } + return resp, nil + } + + enableTLS := c == nil || c.cfg == nil || !c.cfg.Sora.Client.DisableTLSFingerprint + if c.httpUpstream != nil { + accountID := int64(0) + accountConcurrency := 0 + if account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + return c.httpUpstream.DoWithTLS(req, proxyURL, accountID, accountConcurrency, enableTLS) + } + return http.DefaultClient.Do(req) +} + +func (c *SoraDirectClient) sleepRetry(attempt int) { + backoff := time.Duration(attempt*attempt) * time.Second + if backoff > 10*time.Second { + backoff = 10 * time.Second + } + time.Sleep(backoff) +} + +func (c *SoraDirectClient) buildUpstreamError(status int, headers http.Header, body []byte, requestURL string) error { + msg := strings.TrimSpace(extractUpstreamErrorMessage(body)) + msg = sanitizeUpstreamErrorMessage(msg) + if status == http.StatusNotFound && strings.Contains(strings.ToLower(msg), "not found") { + if hint := soraBaseURLNotFoundHint(requestURL); hint != "" { + msg = strings.TrimSpace(msg + " " + hint) + } + } + if msg == "" { + msg = truncateForLog(body, 256) + } + return &SoraUpstreamError{ + StatusCode: status, + Message: msg, + Headers: headers, + Body: body, + } +} + +func normalizeSoraBaseURL(raw string) string { + trimmed := strings.TrimRight(strings.TrimSpace(raw), "/") + if trimmed == "" { + return "" + } + parsed, err := url.Parse(trimmed) + if err != nil || parsed.Scheme == "" || parsed.Host == "" { + return trimmed + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return trimmed + } + pathVal := strings.TrimRight(strings.TrimSpace(parsed.Path), "/") + switch pathVal { + case "", "/": 
+ parsed.Path = "/backend" + case "/backend-api": + parsed.Path = "/backend" + } + return strings.TrimRight(parsed.String(), "/") +} + +func soraBaseURLNotFoundHint(requestURL string) string { + parsed, err := url.Parse(strings.TrimSpace(requestURL)) + if err != nil || parsed.Host == "" { + return "" + } + host := strings.ToLower(parsed.Hostname()) + if host != "sora.chatgpt.com" && host != "chatgpt.com" { + return "" + } + pathVal := strings.TrimSpace(parsed.Path) + if strings.HasPrefix(pathVal, "/backend/") || pathVal == "/backend" { + return "" + } + return "(请检查 sora.client.base_url,建议配置为 https://sora.chatgpt.com/backend)" +} + +func (c *SoraDirectClient) generateSentinelToken(ctx context.Context, account *Account, accessToken, userAgent, proxyURL string) (string, error) { + reqID := uuid.NewString() + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" { + userAgent = c.taskUserAgent() + } + powToken := soraPowTokenGenerator(userAgent) + payload := map[string]any{ + "p": powToken, + "flow": soraSentinelFlow, + "id": reqID, + } + body, err := json.Marshal(payload) + if err != nil { + return "", err + } + headers := http.Header{} + headers.Set("Accept", "application/json, text/plain, */*") + headers.Set("Content-Type", "application/json") + headers.Set("Origin", "https://sora.chatgpt.com") + headers.Set("Referer", "https://sora.chatgpt.com/") + headers.Set("User-Agent", userAgent) + if accessToken != "" { + headers.Set("Authorization", "Bearer "+accessToken) + } + + urlStr := soraChatGPTBaseURL + "/backend-api/sentinel/req" + respBody, _, err := c.doRequestWithProxy(ctx, account, proxyURL, http.MethodPost, urlStr, headers, bytes.NewReader(body), true) + if err != nil { + return "", err + } + var resp map[string]any + if err := json.Unmarshal(respBody, &resp); err != nil { + return "", err + } + + sentinel := soraBuildSentinelToken(soraSentinelFlow, reqID, powToken, resp, userAgent) + if sentinel == "" { + return "", errors.New("failed to build sentinel 
token") + } + return sentinel, nil +} + +func soraGetPowToken(userAgent string) string { + configList := soraBuildPowConfig(userAgent) + seed := strconv.FormatFloat(soraRandFloat(), 'f', -1, 64) + difficulty := "0fffff" + solution, _ := soraSolvePow(seed, difficulty, configList) + return "gAAAAAC" + solution +} + +func soraRandFloat() float64 { + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Float64() +} + +func soraRandInt(max int) int { + if max <= 1 { + return 0 + } + soraRandMu.Lock() + defer soraRandMu.Unlock() + return soraRand.Intn(max) +} + +func soraBuildPowConfig(userAgent string) []any { + userAgent = strings.TrimSpace(userAgent) + if userAgent == "" && len(soraDesktopUserAgents) > 0 { + userAgent = soraDesktopUserAgents[0] + } + screenVal := soraStableChoiceInt([]int{ + 1920 + 1080, + 2560 + 1440, + 1920 + 1200, + 2560 + 1600, + }, userAgent+"|screen") + perfMs := float64(time.Since(soraPerfStart).Milliseconds()) + wallMs := float64(time.Now().UnixNano()) / 1e6 + diff := wallMs - perfMs + return []any{ + screenVal, + soraPowParseTime(), + 4294705152, + 0, + userAgent, + soraStableChoice(soraPowScripts, userAgent+"|script"), + soraStableChoice(soraPowDPL, userAgent+"|dpl"), + "en-US", + "en-US,es-US,en,es", + 0, + soraStableChoice(soraPowNavigatorKeys, userAgent+"|navigator"), + soraStableChoice(soraPowDocumentKeys, userAgent+"|document"), + soraStableChoice(soraPowWindowKeys, userAgent+"|window"), + perfMs, + uuid.NewString(), + "", + soraStableChoiceInt(soraPowCores, userAgent+"|cores"), + diff, + } +} + +func soraStableChoice(items []string, seed string) string { + if len(items) == 0 { + return "" + } + idx := soraStableIndex(seed, len(items)) + return items[idx] +} + +func soraStableChoiceInt(items []int, seed string) int { + if len(items) == 0 { + return 0 + } + idx := soraStableIndex(seed, len(items)) + return items[idx] +} + +func soraStableIndex(seed string, size int) int { + if size <= 0 { + return 0 + } + h := fnv.New32a() + 
_, _ = h.Write([]byte(seed)) + return int(h.Sum32() % uint32(size)) +} + +func soraPowParseTime() string { + loc := time.FixedZone("EST", -5*3600) + return time.Now().In(loc).Format("Mon Jan 02 2006 15:04:05 GMT-0700 (Eastern Standard Time)") +} + +func soraSolvePow(seed, difficulty string, configList []any) (string, bool) { + diffLen := len(difficulty) / 2 + target, err := hexDecodeString(difficulty) + if err != nil { + return "", false + } + seedBytes := []byte(seed) + + part1 := mustMarshalJSON(configList[:3]) + part2 := mustMarshalJSON(configList[4:9]) + part3 := mustMarshalJSON(configList[10:]) + + staticPart1 := append(part1[:len(part1)-1], ',') + staticPart2 := append([]byte(","), append(part2[1:len(part2)-1], ',')...) + staticPart3 := append([]byte(","), part3[1:]...) + + for i := 0; i < soraPowMaxIteration; i++ { + dynamicI := []byte(strconv.Itoa(i)) + dynamicJ := []byte(strconv.Itoa(i >> 1)) + finalJSON := make([]byte, 0, len(staticPart1)+len(dynamicI)+len(staticPart2)+len(dynamicJ)+len(staticPart3)) + finalJSON = append(finalJSON, staticPart1...) + finalJSON = append(finalJSON, dynamicI...) + finalJSON = append(finalJSON, staticPart2...) + finalJSON = append(finalJSON, dynamicJ...) + finalJSON = append(finalJSON, staticPart3...) 
+ + b64 := base64.StdEncoding.EncodeToString(finalJSON) + hash := sha3.Sum512(append(seedBytes, []byte(b64)...)) + if bytes.Compare(hash[:diffLen], target[:diffLen]) <= 0 { + return b64, true + } + } + + errorToken := "wQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("\"%s\"", seed))) + return errorToken, false +} + +func soraBuildSentinelToken(flow, reqID, powToken string, resp map[string]any, userAgent string) string { + finalPow := powToken + proof, _ := resp["proofofwork"].(map[string]any) + if required, _ := proof["required"].(bool); required { + seed, _ := proof["seed"].(string) + difficulty, _ := proof["difficulty"].(string) + if seed != "" && difficulty != "" { + configList := soraBuildPowConfig(userAgent) + solution, _ := soraSolvePow(seed, difficulty, configList) + finalPow = "gAAAAAB" + solution + } + } + if !strings.HasSuffix(finalPow, "~S") { + finalPow += "~S" + } + turnstile, _ := resp["turnstile"].(map[string]any) + tokenPayload := map[string]any{ + "p": finalPow, + "t": safeMapString(turnstile, "dx"), + "c": safeString(resp["token"]), + "id": reqID, + "flow": flow, + } + encoded, _ := json.Marshal(tokenPayload) + return string(encoded) +} + +func safeMapString(m map[string]any, key string) string { + if m == nil { + return "" + } + if v, ok := m[key]; ok { + return safeString(v) + } + return "" +} + +func safeString(v any) string { + switch val := v.(type) { + case string: + return val + default: + return fmt.Sprintf("%v", val) + } +} + +func mustMarshalJSON(v any) []byte { + b, _ := json.Marshal(v) + return b +} + +func hexDecodeString(s string) ([]byte, error) { + dst := make([]byte, len(s)/2) + _, err := hex.Decode(dst, []byte(s)) + return dst, err +} + +func (c *SoraDirectClient) withRequestTrace(ctx context.Context, account *Account, proxyURL, userAgent string) context.Context { + if ctx == nil { + ctx = context.Background() + } + if existing, ok := 
ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && existing != nil && existing.ID != "" { + return ctx + } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + seed := fmt.Sprintf("%d|%s|%s|%d", accountID, normalizeSoraProxyKey(proxyURL), strings.TrimSpace(userAgent), time.Now().UnixNano()) + trace := &soraRequestTrace{ + ID: "sora-" + soraHashForLog(seed), + ProxyKey: normalizeSoraProxyKey(proxyURL), + UAHash: soraHashForLog(strings.TrimSpace(userAgent)), + } + return context.WithValue(ctx, soraRequestTraceContextKey{}, trace) +} + +func (c *SoraDirectClient) requestTraceFields(ctx context.Context, proxyURL, userAgent string) (string, string, string) { + proxyKey := normalizeSoraProxyKey(proxyURL) + uaHash := soraHashForLog(strings.TrimSpace(userAgent)) + traceID := "" + if ctx != nil { + if trace, ok := ctx.Value(soraRequestTraceContextKey{}).(*soraRequestTrace); ok && trace != nil { + if strings.TrimSpace(trace.ID) != "" { + traceID = strings.TrimSpace(trace.ID) + } + if strings.TrimSpace(trace.ProxyKey) != "" { + proxyKey = strings.TrimSpace(trace.ProxyKey) + } + if strings.TrimSpace(trace.UAHash) != "" { + uaHash = strings.TrimSpace(trace.UAHash) + } + } + } + if traceID == "" { + traceID = "sora-" + soraHashForLog(fmt.Sprintf("%s|%d", proxyKey, time.Now().UnixNano())) + } + return traceID, proxyKey, uaHash +} + +func soraHashForLog(raw string) string { + h := fnv.New32a() + _, _ = h.Write([]byte(raw)) + return fmt.Sprintf("%08x", h.Sum32()) +} + +func sanitizeSoraLogURL(raw string) string { + parsed, err := url.Parse(raw) + if err != nil { + return raw + } + q := parsed.Query() + q.Del("sig") + q.Del("expires") + parsed.RawQuery = q.Encode() + return parsed.String() +} + +func (c *SoraDirectClient) debugEnabled() bool { + return c != nil && c.cfg != nil && c.cfg.Sora.Client.Debug +} + +func (c *SoraDirectClient) debugLogf(format string, args ...any) { + if !c.debugEnabled() { + return + } + log.Printf("[SoraClient] 
"+format, args...) +} + +func formatSoraHeaders(headers http.Header) string { + if len(headers) == 0 { + return "{}" + } + keys := make([]string, 0, len(headers)) + for key := range headers { + keys = append(keys, key) + } + sort.Strings(keys) + out := make(map[string]string, len(keys)) + for _, key := range keys { + values := headers.Values(key) + if len(values) == 0 { + continue + } + val := strings.Join(values, ",") + if isSensitiveHeader(key) { + out[key] = "***" + continue + } + out[key] = truncateForLog([]byte(logredact.RedactText(val)), 160) + } + encoded, err := json.Marshal(out) + if err != nil { + return "{}" + } + return string(encoded) +} + +func isSensitiveHeader(key string) bool { + k := strings.ToLower(strings.TrimSpace(key)) + switch k { + case "authorization", "openai-sentinel-token", "cookie", "set-cookie", "x-api-key": + return true + default: + return false + } +} + +func summarizeSoraResponseBody(body []byte, maxLen int) string { + if len(body) == 0 { + return "" + } + var text string + if json.Valid(body) { + text = logredact.RedactJSON(body) + } else { + text = logredact.RedactText(string(body)) + } + text = strings.TrimSpace(text) + if maxLen <= 0 || len(text) <= maxLen { + return text + } + return text[:maxLen] + "...(truncated)" +} diff --git a/backend/internal/service/sora_client_gjson_test.go b/backend/internal/service/sora_client_gjson_test.go new file mode 100644 index 00000000..d38cfa57 --- /dev/null +++ b/backend/internal/service/sora_client_gjson_test.go @@ -0,0 +1,515 @@ +//go:build unit + +package service + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ---------- + +// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中 +// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。 +func testParseUploadOrCreateTaskID(respBody []byte) (string, error) { + id := 
strings.TrimSpace(gjson.GetBytes(respBody, "id").String()) + if id == "" { + return "", assert.AnError // 占位错误,表示 "missing id" + } + return id, nil +} + +// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。 +func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) { + var found *SoraImageTaskStatus + gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool { + if item.Get("id").String() != taskID { + return true // continue + } + status := strings.TrimSpace(item.Get("status").String()) + progress := item.Get("progress_pct").Float() + var urls []string + item.Get("generations").ForEach(func(_, gen gjson.Result) bool { + if u := strings.TrimSpace(gen.Get("url").String()); u != "" { + urls = append(urls, u) + } + return true + }) + found = &SoraImageTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + URLs: urls, + } + return false // break + }) + if found != nil { + return found, true + } + return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false +} + +// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。 +func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) { + pendingResult := gjson.ParseBytes(respBody) + if !pendingResult.IsArray() { + return nil, false + } + var pendingFound *SoraVideoTaskStatus + pendingResult.ForEach(func(_, task gjson.Result) bool { + if task.Get("id").String() != taskID { + return true + } + progress := 0 + if v := task.Get("progress_pct"); v.Exists() { + progress = int(v.Float() * 100) + } + status := strings.TrimSpace(task.Get("status").String()) + pendingFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: status, + ProgressPct: progress, + } + return false + }) + if pendingFound != nil { + return pendingFound, true + } + return nil, false +} + +// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。 +func testParseGetVideoTaskDrafts(respBody []byte, 
taskID string) (*SoraVideoTaskStatus, bool) { + var draftFound *SoraVideoTaskStatus + gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool { + if draft.Get("task_id").String() != taskID { + return true + } + kind := strings.TrimSpace(draft.Get("kind").String()) + reason := strings.TrimSpace(draft.Get("reason_str").String()) + if reason == "" { + reason = strings.TrimSpace(draft.Get("markdown_reason_str").String()) + } + urlStr := strings.TrimSpace(draft.Get("downloadable_url").String()) + if urlStr == "" { + urlStr = strings.TrimSpace(draft.Get("url").String()) + } + + if kind == "sora_content_violation" || reason != "" || urlStr == "" { + msg := reason + if msg == "" { + msg = "Content violates guardrails" + } + draftFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: "failed", + ErrorMsg: msg, + } + } else { + draftFound = &SoraVideoTaskStatus{ + ID: taskID, + Status: "completed", + URLs: []string{urlStr}, + } + } + return false + }) + if draftFound != nil { + return draftFound, true + } + return nil, false +} + +// ===================== Test 1: TestSoraParseUploadResponse ===================== + +func TestSoraParseUploadResponse(t *testing.T) { + tests := []struct { + name string + body string + wantID string + wantErr bool + }{ + { + name: "正常 id", + body: `{"id":"file-abc123","status":"uploaded"}`, + wantID: "file-abc123", + }, + { + name: "空 id", + body: `{"id":"","status":"uploaded"}`, + wantErr: true, + }, + { + name: "无 id 字段", + body: `{"status":"uploaded"}`, + wantErr: true, + }, + { + name: "id 全为空白", + body: `{"id":" ","status":"uploaded"}`, + wantErr: true, + }, + { + name: "id 前后有空白", + body: `{"id":" file-trimmed ","status":"uploaded"}`, + wantID: "file-trimmed", + }, + { + name: "空 JSON 对象", + body: `{}`, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := testParseUploadOrCreateTaskID([]byte(tt.body)) + if tt.wantErr { + require.Error(t, err, "应返回错误") + return + } + 
require.NoError(t, err) + require.Equal(t, tt.wantID, id) + }) + } +} + +// ===================== Test 2: TestSoraParseCreateTaskResponse ===================== + +func TestSoraParseCreateTaskResponse(t *testing.T) { + tests := []struct { + name string + body string + wantID string + wantErr bool + }{ + { + name: "正常任务 id", + body: `{"id":"task-123"}`, + wantID: "task-123", + }, + { + name: "缺失 id", + body: `{"status":"created"}`, + wantErr: true, + }, + { + name: "空 id", + body: `{"id":" "}`, + wantErr: true, + }, + { + name: "id 为数字(gjson 转字符串)", + body: `{"id":123}`, + wantID: "123", + }, + { + name: "id 含特殊字符", + body: `{"id":"task-abc-def-456-ghi"}`, + wantID: "task-abc-def-456-ghi", + }, + { + name: "额外字段不影响解析", + body: `{"id":"task-999","type":"image_gen","extra":"data"}`, + wantID: "task-999", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + id, err := testParseUploadOrCreateTaskID([]byte(tt.body)) + if tt.wantErr { + require.Error(t, err, "应返回错误") + return + } + require.NoError(t, err) + require.Equal(t, tt.wantID, id) + }) + } +} + +// ===================== Test 3: TestSoraParseFetchRecentImageTask ===================== + +func TestSoraParseFetchRecentImageTask(t *testing.T) { + tests := []struct { + name string + body string + taskID string + wantFound bool + wantStatus string + wantProgress float64 + wantURLs []string + }{ + { + name: "匹配已完成任务", + body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`, + taskID: "task-1", + wantFound: true, + wantStatus: "completed", + wantProgress: 1.0, + wantURLs: []string{"https://example.com/img.png"}, + }, + { + name: "匹配处理中任务", + body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`, + taskID: "task-2", + wantFound: true, + wantStatus: "processing", + wantProgress: 0.5, + wantURLs: nil, + }, + { + name: "无匹配任务", + body: 
`{"task_responses":[{"id":"other","status":"completed"}]}`, + taskID: "task-1", + wantFound: false, + wantStatus: "processing", + }, + { + name: "空 task_responses", + body: `{"task_responses":[]}`, + taskID: "task-1", + wantFound: false, + wantStatus: "processing", + }, + { + name: "缺少 task_responses 字段", + body: `{"other":"data"}`, + taskID: "task-1", + wantFound: false, + wantStatus: "processing", + }, + { + name: "多个任务中精准匹配", + body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`, + taskID: "task-b", + wantFound: true, + wantStatus: "processing", + wantProgress: 0.3, + wantURLs: nil, + }, + { + name: "多个 generations", + body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`, + taskID: "task-m", + wantFound: true, + wantStatus: "completed", + wantProgress: 1.0, + wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID) + require.Equal(t, tt.wantFound, found, "found 不匹配") + require.NotNil(t, status) + require.Equal(t, tt.taskID, status.ID) + require.Equal(t, tt.wantStatus, status.Status) + if tt.wantFound { + require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配") + require.Equal(t, tt.wantURLs, status.URLs) + } + }) + } +} + +// ===================== Test 4: TestSoraParseGetVideoTaskPending ===================== + +func TestSoraParseGetVideoTaskPending(t *testing.T) { + tests := []struct { + name string + body string + taskID string + wantFound bool + wantStatus string + wantProgress int + }{ + { + name: "匹配 pending 任务", + body: 
`[{"id":"task-1","status":"processing","progress_pct":0.5}]`, + taskID: "task-1", + wantFound: true, + wantStatus: "processing", + wantProgress: 50, + }, + { + name: "进度为 0", + body: `[{"id":"task-2","status":"queued","progress_pct":0}]`, + taskID: "task-2", + wantFound: true, + wantStatus: "queued", + wantProgress: 0, + }, + { + name: "进度为 1(100%)", + body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`, + taskID: "task-3", + wantFound: true, + wantStatus: "completing", + wantProgress: 100, + }, + { + name: "空数组", + body: `[]`, + taskID: "task-1", + wantFound: false, + }, + { + name: "无匹配 id", + body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`, + taskID: "task-1", + wantFound: false, + }, + { + name: "多个任务精准匹配", + body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`, + taskID: "task-c", + wantFound: true, + wantStatus: "processing", + wantProgress: 80, + }, + { + name: "非数组 JSON", + body: `{"id":"task-1","status":"processing"}`, + taskID: "task-1", + wantFound: false, + }, + { + name: "无 progress_pct 字段", + body: `[{"id":"task-4","status":"pending"}]`, + taskID: "task-4", + wantFound: true, + wantStatus: "pending", + wantProgress: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID) + require.Equal(t, tt.wantFound, found, "found 不匹配") + if tt.wantFound { + require.NotNil(t, status) + require.Equal(t, tt.taskID, status.ID) + require.Equal(t, tt.wantStatus, status.Status) + require.Equal(t, tt.wantProgress, status.ProgressPct) + } + }) + } +} + +// ===================== Test 5: TestSoraParseGetVideoTaskDrafts ===================== + +func TestSoraParseGetVideoTaskDrafts(t *testing.T) { + tests := []struct { + name string + body string + taskID string + wantFound bool + wantStatus string + wantURLs []string + 
wantErr string + }{ + { + name: "正常完成的视频", + body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`, + taskID: "task-1", + wantFound: true, + wantStatus: "completed", + wantURLs: []string{"https://example.com/video.mp4"}, + }, + { + name: "使用 url 字段回退", + body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`, + taskID: "task-2", + wantFound: true, + wantStatus: "completed", + wantURLs: []string{"https://example.com/fallback.mp4"}, + }, + { + name: "内容违规", + body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`, + taskID: "task-3", + wantFound: true, + wantStatus: "failed", + wantErr: "Content policy violation", + }, + { + name: "内容违规 - markdown_reason_str 回退", + body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`, + taskID: "task-4", + wantFound: true, + wantStatus: "failed", + wantErr: "Markdown reason", + }, + { + name: "内容违规 - 无 reason 使用默认消息", + body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`, + taskID: "task-5", + wantFound: true, + wantStatus: "failed", + wantErr: "Content violates guardrails", + }, + { + name: "有 reason_str 但非 violation kind(仍判定失败)", + body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`, + taskID: "task-6", + wantFound: true, + wantStatus: "failed", + wantErr: "Some error occurred", + }, + { + name: "空 URL 判定为失败", + body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`, + taskID: "task-7", + wantFound: true, + wantStatus: "failed", + wantErr: "Content violates guardrails", + }, + { + name: "无匹配 task_id", + body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`, + taskID: "task-1", + wantFound: false, + }, + { + name: "空 items", + body: `{"items":[]}`, + taskID: "task-1", + 
wantFound: false, + }, + { + name: "缺少 items 字段", + body: `{"other":"data"}`, + taskID: "task-1", + wantFound: false, + }, + { + name: "多个 items 精准匹配", + body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`, + taskID: "task-b", + wantFound: true, + wantStatus: "failed", + wantErr: "Bad content", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID) + require.Equal(t, tt.wantFound, found, "found 不匹配") + if !tt.wantFound { + return + } + require.NotNil(t, status) + require.Equal(t, tt.taskID, status.ID) + require.Equal(t, tt.wantStatus, status.Status) + if tt.wantErr != "" { + require.Equal(t, tt.wantErr, status.ErrorMsg) + } + if tt.wantURLs != nil { + require.Equal(t, tt.wantURLs, status.URLs) + } + }) + } +} diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go new file mode 100644 index 00000000..cffe8a35 --- /dev/null +++ b/backend/internal/service/sora_client_test.go @@ -0,0 +1,1075 @@ +//go:build unit + +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestSoraDirectClient_DoRequestSuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{BaseURL: server.URL}, + }, + } + client := 
NewSoraDirectClient(cfg, nil, nil) + + body, _, err := client.doRequest(context.Background(), &Account{ID: 1}, http.MethodGet, server.URL, http.Header{}, nil, false) + require.NoError(t, err) + require.Contains(t, string(body), "ok") +} + +func TestSoraDirectClient_BuildBaseHeaders(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + Headers: map[string]string{ + "X-Test": "yes", + "Authorization": "should-ignore", + "openai-sentinel-token": "skip", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + + headers := client.buildBaseHeaders("token-123", "UA") + require.Equal(t, "Bearer token-123", headers.Get("Authorization")) + require.Equal(t, "UA", headers.Get("User-Agent")) + require.Equal(t, "yes", headers.Get("X-Test")) + require.Empty(t, headers.Get("openai-sentinel-token")) +} + +func TestSoraDirectClient_GetImageTaskFallbackLimit(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limit := r.URL.Query().Get("limit") + w.Header().Set("Content-Type", "application/json") + switch limit { + case "1": + _, _ = w.Write([]byte(`{"task_responses":[]}`)) + case "2": + _, _ = w.Write([]byte(`{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1,"generations":[{"url":"https://example.com/a.png"}]}]}`)) + default: + _, _ = w.Write([]byte(`{"task_responses":[]}`)) + } + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + RecentTaskLimit: 1, + RecentTaskLimitMax: 2, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{Credentials: map[string]any{"access_token": "token"}} + + status, err := client.GetImageTask(context.Background(), account, "task-1") + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + require.Equal(t, []string{"https://example.com/a.png"}, status.URLs) +} + +func 
TestNormalizeSoraBaseURL(t *testing.T) { + t.Parallel() + tests := []struct { + name string + raw string + want string + }{ + { + name: "empty", + raw: "", + want: "", + }, + { + name: "append_backend_for_sora_host", + raw: "https://sora.chatgpt.com", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "convert_backend_api_to_backend", + raw: "https://sora.chatgpt.com/backend-api", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_backend", + raw: "https://sora.chatgpt.com/backend", + want: "https://sora.chatgpt.com/backend", + }, + { + name: "keep_custom_host", + raw: "https://example.com/custom-path", + want: "https://example.com/custom-path", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := normalizeSoraBaseURL(tt.raw) + require.Equal(t, tt.want, got) + }) + } +} + +func TestSoraDirectClient_BuildURL_UsesNormalizedBaseURL(t *testing.T) { + t.Parallel() + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "https://sora.chatgpt.com/backend/video_gen", client.buildURL("/video_gen")) +} + +func TestSoraDirectClient_BuildUpstreamError_NotFoundHint(t *testing.T) { + t.Parallel() + client := NewSoraDirectClient(&config.Config{}, nil, nil) + + err := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/video_gen") + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Contains(t, upstreamErr.Message, "请检查 sora.client.base_url") + + errNoHint := client.buildUpstreamError(http.StatusNotFound, http.Header{}, []byte(`{"error":{"message":"Not found"}}`), "https://sora.chatgpt.com/backend/video_gen") + require.ErrorAs(t, errNoHint, &upstreamErr) + require.NotContains(t, upstreamErr.Message, "请检查 sora.client.base_url") +} + +func 
TestFormatSoraHeaders_RedactsSensitive(t *testing.T) { + t.Parallel() + headers := http.Header{} + headers.Set("Authorization", "Bearer secret-token") + headers.Set("openai-sentinel-token", "sentinel-secret") + headers.Set("X-Test", "ok") + + out := formatSoraHeaders(headers) + require.Contains(t, out, `"Authorization":"***"`) + require.Contains(t, out, `Sentinel-Token":"***"`) + require.Contains(t, out, `"X-Test":"ok"`) + require.NotContains(t, out, "secret-token") + require.NotContains(t, out, "sentinel-secret") +} + +func TestSummarizeSoraResponseBody_RedactsJSON(t *testing.T) { + t.Parallel() + body := []byte(`{"error":{"message":"bad"},"access_token":"abc123"}`) + out := summarizeSoraResponseBody(body, 512) + require.Contains(t, out, `"access_token":"***"`) + require.NotContains(t, out, "abc123") +} + +func TestSummarizeSoraResponseBody_Truncates(t *testing.T) { + t.Parallel() + body := []byte(strings.Repeat("x", 100)) + out := summarizeSoraResponseBody(body, 10) + require.Contains(t, out, "(truncated)") +} + +func TestSoraDirectClient_GetAccessToken_SoraDefaultUseCredentials(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sora-credential-token", token) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalled)) +} + +func TestSoraDirectClient_GetAccessToken_SoraCanEnableProvider(t *testing.T) { + t.Parallel() + cache := newOpenAITokenCacheStub() + account := &Account{ + ID: 2, + Platform: PlatformSora, + 
Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "sora-credential-token", + }, + } + cache.tokens[OpenAITokenCacheKey(account)] = "provider-token" + provider := NewOpenAITokenProvider(nil, cache, nil) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + UseOpenAITokenProvider: true, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, provider) + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "provider-token", token) + require.Greater(t, atomic.LoadInt32(&cache.getCalled), int32(0)) +} + +func TestSoraDirectClient_GetAccessToken_FromSessionToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=session-token") + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "accessToken": "session-access-token", + "expires": "2099-01-01T00:00:00Z", + }) + })) + defer server.Close() + + origin := soraSessionAuthURL + soraSessionAuthURL = server.URL + defer func() { soraSessionAuthURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 10, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "session_token": "session-token", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "session-access-token", token) + require.Equal(t, "session-access-token", account.GetCredential("access_token")) +} + +func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + 
require.Equal(t, "/oauth/token", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + require.NoError(t, r.ParseForm()) + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) + require.NotEmpty(t, r.FormValue("client_id")) + require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "access_token": "refresh-access-token", + "refresh_token": "refresh-token-new", + "expires_in": 3600, + }) + })) + defer server.Close() + + origin := soraOAuthTokenURL + soraOAuthTokenURL = server.URL + "/oauth/token" + defer func() { soraOAuthTokenURL = origin }() + + client := NewSoraDirectClient(&config.Config{}, nil, nil) + account := &Account{ + ID: 11, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "refresh-token-old", + }, + } + + token, err := client.getAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "refresh-access-token", token) + require.Equal(t, "refresh-token-new", account.GetCredential("refresh_token")) + require.NotNil(t, account.GetCredentialAsTime("expires_at")) +} + +func TestSoraDirectClient_PreflightCheck_VideoQuotaExceeded(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Equal(t, "/nf/check", r.URL.Path) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "rate_limit_and_credit_balance": map[string]any{ + "estimated_num_videos_remaining": 0, + "rate_limit_reached": true, + }, + }) + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: 
server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ + ID: 12, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "ok", + "expires_at": time.Now().Add(2 * time.Hour).Format(time.RFC3339), + }, + } + err := client.PreflightCheck(context.Background(), account, "sora2-landscape-10s", SoraModelConfig{Type: "video"}) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) +} + +func TestShouldAttemptSoraTokenRecover(t *testing.T) { + t.Parallel() + + require.True(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/backend/video_gen")) + require.True(t, shouldAttemptSoraTokenRecover(http.StatusForbidden, "https://chatgpt.com/backend/video_gen")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://sora.chatgpt.com/api/auth/session")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusUnauthorized, "https://auth.openai.com/oauth/token")) + require.False(t, shouldAttemptSoraTokenRecover(http.StatusTooManyRequests, "https://sora.chatgpt.com/backend/video_gen")) +} + +type soraClientRequestCall struct { + Path string + UserAgent string + ProxyURL string +} + +type soraClientRecordingUpstream struct { + calls []soraClientRequestCall +} + +func (u *soraClientRecordingUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (u *soraClientRecordingUpstream) DoWithTLS(req *http.Request, proxyURL string, _ int64, _ int, _ bool) (*http.Response, error) { + u.calls = append(u.calls, soraClientRequestCall{ + Path: req.URL.Path, + UserAgent: req.Header.Get("User-Agent"), + ProxyURL: proxyURL, + }) + switch req.URL.Path { + case "/backend-api/sentinel/req": + return newSoraClientMockResponse(http.StatusOK, 
`{"token":"sentinel-token","turnstile":{"dx":"ok"}}`), nil + case "/backend/nf/create": + return newSoraClientMockResponse(http.StatusOK, `{"id":"task-123"}`), nil + case "/backend/nf/create/storyboard": + return newSoraClientMockResponse(http.StatusOK, `{"id":"storyboard-123"}`), nil + case "/backend/uploads": + return newSoraClientMockResponse(http.StatusOK, `{"id":"upload-123"}`), nil + case "/backend/nf/check": + return newSoraClientMockResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":1,"rate_limit_reached":false}}`), nil + case "/backend/characters/upload": + return newSoraClientMockResponse(http.StatusOK, `{"id":"cameo-123"}`), nil + case "/backend/project_y/cameos/in_progress/cameo-123": + return newSoraClientMockResponse(http.StatusOK, `{"status":"finalized","status_message":"Completed","username_hint":"foo.bar","display_name_hint":"Bar","profile_asset_url":"https://example.com/avatar.webp"}`), nil + case "/backend/project_y/file/upload": + return newSoraClientMockResponse(http.StatusOK, `{"asset_pointer":"asset-123"}`), nil + case "/backend/characters/finalize": + return newSoraClientMockResponse(http.StatusOK, `{"character":{"character_id":"character-123"}}`), nil + case "/backend/project_y/post": + return newSoraClientMockResponse(http.StatusOK, `{"post":{"id":"s_post"}}`), nil + default: + return newSoraClientMockResponse(http.StatusOK, `{"ok":true}`), nil + } +} + +func newSoraClientMockResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Header: make(http.Header), + Body: io.NopCloser(strings.NewReader(body)), + } +} + +func TestSoraDirectClient_TaskUserAgent_DefaultMobileFallback(t *testing.T) { + client := NewSoraDirectClient(&config.Config{}, nil, nil) + ua := client.taskUserAgent() + require.NotEmpty(t, ua) + allowed := append([]string{}, soraMobileUserAgents...) + allowed = append(allowed, soraDesktopUserAgents...) 
+ require.Contains(t, allowed, ua) +} + +func TestSoraDirectClient_CreateVideoTask_UsesSameUserAgentAndProxyForSentinelAndCreate(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { + soraPowTokenGenerator = originPowTokenGenerator + }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(9) + account := &Account{ + ID: 21, + Platform: PlatformSora, + Type: AccountTypeOAuth, + Concurrency: 1, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateVideoTask(context.Background(), account, SoraVideoRequest{Prompt: "test"}) + require.NoError(t, err) + require.Equal(t, "task-123", taskID) + require.Len(t, upstream.calls, 2) + + sentinelCall := upstream.calls[0] + createCall := upstream.calls[1] + require.Equal(t, "/backend-api/sentinel/req", sentinelCall.Path) + require.Equal(t, "/backend/nf/create", createCall.Path) + require.Equal(t, "http://127.0.0.1:8080", sentinelCall.ProxyURL) + require.Equal(t, sentinelCall.ProxyURL, createCall.ProxyURL) + require.NotEmpty(t, sentinelCall.UserAgent) + require.Equal(t, sentinelCall.UserAgent, createCall.UserAgent) +} + +func TestSoraDirectClient_UploadImage_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(3) + account := &Account{ + ID: 31, + ProxyID: 
&proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + uploadID, err := client.UploadImage(context.Background(), account, []byte("mock-image"), "a.png") + require.NoError(t, err) + require.Equal(t, "upload-123", uploadID) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/uploads", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.NotEmpty(t, upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_PreflightCheck_UsesTaskUserAgentAndProxy(t *testing.T) { + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + proxyID := int64(7) + account := &Account{ + ID: 41, + ProxyID: &proxyID, + Proxy: &Proxy{ + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + }, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + err := client.PreflightCheck(context.Background(), account, "sora2", SoraModelConfig{Type: "video"}) + require.NoError(t, err) + require.Len(t, upstream.calls, 1) + require.Equal(t, "/backend/nf/check", upstream.calls[0].Path) + require.Equal(t, "http://127.0.0.1:8080", upstream.calls[0].ProxyURL) + require.NotEmpty(t, upstream.calls[0].UserAgent) +} + +func TestSoraDirectClient_CreateStoryboardTask(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: 
config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 51, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + taskID, err := client.CreateStoryboardTask(context.Background(), account, SoraStoryboardRequest{ + Prompt: "Shot 1:\nduration: 5sec\nScene: cat", + }) + require.NoError(t, err) + require.Equal(t, "storyboard-123", taskID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/nf/create/storyboard", upstream.calls[1].Path) +} + +func TestSoraDirectClient_GetVideoTask_ReturnsGenerationID(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/nf/pending/v2": + _, _ = w.Write([]byte(`[]`)) + case "/project_y/profile/drafts": + _, _ = w.Write([]byte(`{"items":[{"id":"gen_1","task_id":"task-1","kind":"video","downloadable_url":"https://example.com/v.mp4"}]}`)) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: server.URL, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{Credentials: map[string]any{"access_token": "token"}} + + status, err := client.GetVideoTask(context.Background(), account, "task-1") + require.NoError(t, err) + require.Equal(t, "completed", status.Status) + require.Equal(t, "gen_1", status.GenerationID) + require.Equal(t, []string{"https://example.com/v.mp4"}, status.URLs) +} + +func TestSoraDirectClient_PostVideoForWatermarkFree(t *testing.T) { + originPowTokenGenerator := soraPowTokenGenerator + soraPowTokenGenerator = func(_ string) string { return "gAAAAACmock" } + defer 
func() { soraPowTokenGenerator = originPowTokenGenerator }() + + upstream := &soraClientRecordingUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + account := &Account{ + ID: 52, + Credentials: map[string]any{ + "access_token": "access-token", + "expires_at": time.Now().Add(30 * time.Minute).Format(time.RFC3339), + }, + } + + postID, err := client.PostVideoForWatermarkFree(context.Background(), account, "gen_1") + require.NoError(t, err) + require.Equal(t, "s_post", postID) + require.Len(t, upstream.calls, 2) + require.Equal(t, "/backend-api/sentinel/req", upstream.calls[0].Path) + require.Equal(t, "/backend/project_y/post", upstream.calls[1].Path) +} + +type soraClientFallbackUpstream struct { + doWithTLSCalls int32 + respBody string + respStatusCode int + err error +} + +func (u *soraClientFallbackUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) { + return nil, errors.New("unexpected Do call") +} + +func (u *soraClientFallbackUpstream) DoWithTLS(_ *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) { + atomic.AddInt32(&u.doWithTLSCalls, 1) + if u.err != nil { + return nil, u.err + } + statusCode := u.respStatusCode + if statusCode <= 0 { + statusCode = http.StatusOK + } + body := u.respBody + if body == "" { + body = `{"ok":true}` + } + return newSoraClientMockResponse(statusCode, body), nil +} + +func TestSoraDirectClient_DoHTTP_UsesCurlCFFISidecarWhenEnabled(t *testing.T) { + var captured soraCurlCFFISidecarRequest + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodPost, r.Method) + require.Equal(t, "/request", r.URL.Path) + raw, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(raw, &captured)) + _ = 
json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "Content-Type": "application/json", + "X-Sidecar": []string{"yes"}, + }, + "body_base64": base64.StdEncoding.EncodeToString([]byte(`{"ok":true}`)), + }) + })) + defer sidecar.Close() + + upstream := &soraClientFallbackUpstream{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + TimeoutSeconds: 15, + SessionReuseEnabled: true, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodPost, "https://sora.chatgpt.com/backend/me", strings.NewReader("hello-sidecar")) + require.NoError(t, err) + req.Header.Set("User-Agent", "test-ua") + + resp, err := client.doHTTP(req, "http://127.0.0.1:18080", &Account{ID: 1}) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + require.JSONEq(t, `{"ok":true}`, string(body)) + require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) + require.Equal(t, "http://127.0.0.1:18080", captured.ProxyURL) + require.NotEmpty(t, captured.SessionKey) + require.Equal(t, "chrome131", captured.Impersonate) + require.Equal(t, "https://sora.chatgpt.com/backend/me", captured.URL) + decodedReqBody, err := base64.StdEncoding.DecodeString(captured.BodyBase64) + require.NoError(t, err) + require.Equal(t, "hello-sidecar", string(decodedReqBody)) +} + +func TestSoraDirectClient_DoHTTP_CurlCFFISidecarFailureReturnsError(t *testing.T) { + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte(`{"error":"boom"}`)) + })) + defer sidecar.Close() + + upstream := &soraClientFallbackUpstream{respBody: `{"fallback":true}`} + cfg 
:= &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + + _, err = client.doHTTP(req, "", &Account{ID: 2}) + require.Error(t, err) + require.Contains(t, err.Error(), "sora curl_cffi sidecar") + require.Equal(t, int32(0), atomic.LoadInt32(&upstream.doWithTLSCalls)) +} + +func TestSoraDirectClient_DoHTTP_CurlCFFISidecarDisabledUsesLegacyStack(t *testing.T) { + upstream := &soraClientFallbackUpstream{respBody: `{"legacy":true}`} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: false, + BaseURL: "http://127.0.0.1:18080", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, upstream, nil) + req, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + + resp, err := client.doHTTP(req, "", &Account{ID: 3}) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.JSONEq(t, `{"legacy":true}`, string(body)) + require.Equal(t, int32(1), atomic.LoadInt32(&upstream.doWithTLSCalls)) +} + +func TestConvertSidecarHeaderValue_NilAndSlice(t *testing.T) { + require.Nil(t, convertSidecarHeaderValue(nil)) + require.Equal(t, []string{"a", "b"}, convertSidecarHeaderValue([]any{"a", " ", "b"})) +} + +func TestSoraDirectClient_DoHTTP_SidecarSessionKeyStableForSameAccountProxy(t *testing.T) { + var captured []soraCurlCFFISidecarRequest + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + 
require.NoError(t, err) + var reqPayload soraCurlCFFISidecarRequest + require.NoError(t, json.Unmarshal(raw, &reqPayload)) + captured = append(captured, reqPayload) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "Content-Type": "application/json", + }, + "body": `{"ok":true}`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 1001} + + req1, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + _, err = client.doHTTP(req1, "http://127.0.0.1:18080", account) + require.NoError(t, err) + + req2, err := http.NewRequest(http.MethodGet, "https://sora.chatgpt.com/backend/me", nil) + require.NoError(t, err) + _, err = client.doHTTP(req2, "http://127.0.0.1:18080", account) + require.NoError(t, err) + + require.Len(t, captured, 2) + require.NotEmpty(t, captured[0].SessionKey) + require.Equal(t, captured[0].SessionKey, captured[1].SessionKey) +} + +func TestSoraDirectClient_DoRequestWithProxy_CloudflareChallengeSetsCooldownAfterSingleRetry(t *testing.T) { + var sidecarCalls int32 + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&sidecarCalls, 1) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusForbidden, + "headers": map[string]any{ + "cf-ray": "9d05d73dec4d8c8e-GRU", + "content-type": "text/html", + }, + "body": `Just a moment...`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + 
MaxRetries: 3, + CloudflareChallengeCooldownSeconds: 60, + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + headers := http.Header{} + + _, _, err := client.doRequestWithProxy( + context.Background(), + &Account{ID: 99}, + "http://127.0.0.1:18080", + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.Error(t, err) + var upstreamErr *SoraUpstreamError + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusForbidden, upstreamErr.StatusCode) + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "challenge should trigger exactly one same-proxy retry") + + _, _, err = client.doRequestWithProxy( + context.Background(), + &Account{ID: 99}, + "http://127.0.0.1:18080", + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.Error(t, err) + require.ErrorAs(t, err, &upstreamErr) + require.Equal(t, http.StatusTooManyRequests, upstreamErr.StatusCode) + require.Contains(t, upstreamErr.Message, "cooling down") + require.Contains(t, upstreamErr.Message, "cf-ray") + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls), "cooldown should block outbound request") +} + +func TestSoraDirectClient_DoRequestWithProxy_CloudflareRetrySuccessClearsCooldown(t *testing.T) { + var sidecarCalls int32 + sidecar := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + call := atomic.AddInt32(&sidecarCalls, 1) + if call == 1 { + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusForbidden, + "headers": map[string]any{ + "cf-ray": "9d05d73dec4d8c8e-GRU", + "content-type": "text/html", + }, + "body": `Just a moment...`, + }) + return + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "status_code": http.StatusOK, + "headers": map[string]any{ + "content-type": "application/json", + }, + 
"body": `{"ok":true}`, + }) + })) + defer sidecar.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + BaseURL: "https://sora.chatgpt.com/backend", + MaxRetries: 3, + CloudflareChallengeCooldownSeconds: 60, + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + BaseURL: sidecar.URL, + Impersonate: "chrome131", + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + headers := http.Header{} + account := &Account{ID: 109} + proxyURL := "http://127.0.0.1:18080" + + body, _, err := client.doRequestWithProxy( + context.Background(), + account, + proxyURL, + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.NoError(t, err) + require.Contains(t, string(body), `"ok":true`) + require.Equal(t, int32(2), atomic.LoadInt32(&sidecarCalls)) + + _, _, err = client.doRequestWithProxy( + context.Background(), + account, + proxyURL, + http.MethodGet, + "https://sora.chatgpt.com/backend/me", + headers, + nil, + true, + ) + require.NoError(t, err) + require.Equal(t, int32(3), atomic.LoadInt32(&sidecarCalls), "cooldown should be cleared after retry succeeds") +} + +func TestSoraComputeChallengeCooldownSeconds(t *testing.T) { + require.Equal(t, 0, soraComputeChallengeCooldownSeconds(0, 3)) + require.Equal(t, 10, soraComputeChallengeCooldownSeconds(10, 1)) + require.Equal(t, 20, soraComputeChallengeCooldownSeconds(10, 2)) + require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 4)) + require.Equal(t, 40, soraComputeChallengeCooldownSeconds(10, 9), "streak should cap at x4") + require.Equal(t, 3600, soraComputeChallengeCooldownSeconds(1200, 9), "cooldown should cap at 3600s") +} + +func TestSoraDirectClient_RecordCloudflareChallengeCooldown_EscalatesStreak(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CloudflareChallengeCooldownSeconds: 10, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + 
account := &Account{ID: 201} + proxyURL := "http://127.0.0.1:18080" + + client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8e-GRU"}}, nil) + client.recordCloudflareChallengeCooldown(account, proxyURL, http.StatusForbidden, http.Header{"Cf-Ray": []string{"9d05d73dec4d8c8f-GRU"}}, nil) + + key := soraAccountProxyKey(account, proxyURL) + entry, ok := client.challengeCooldowns[key] + require.True(t, ok) + require.Equal(t, 2, entry.ConsecutiveChallenges) + require.Equal(t, "9d05d73dec4d8c8f-GRU", entry.CFRay) + remain := int(entry.Until.Sub(entry.LastChallengeAt).Seconds()) + require.GreaterOrEqual(t, remain, 19) +} + +func TestSoraDirectClient_SidecarSessionKey_SkipsWhenAccountMissing(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + require.Equal(t, "", client.sidecarSessionKey(nil, "http://127.0.0.1:18080")) + require.Empty(t, client.sidecarSessions) +} + +func TestSoraDirectClient_SidecarSessionKey_PrunesExpiredAndRecreates(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 3600, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 123} + key := soraAccountProxyKey(account, "http://127.0.0.1:18080") + client.sidecarSessions[key] = soraSidecarSessionEntry{ + SessionKey: "sora-expired", + ExpiresAt: time.Now().Add(-time.Minute), + LastUsedAt: time.Now().Add(-2 * time.Minute), + } + + sessionKey := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + require.NotEmpty(t, sessionKey) + require.NotEqual(t, "sora-expired", sessionKey) + 
require.Len(t, client.sidecarSessions, 1) +} + +func TestSoraDirectClient_SidecarSessionKey_TTLZeroKeepsLongLivedSession(t *testing.T) { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + CurlCFFISidecar: config.SoraCurlCFFISidecarConfig{ + Enabled: true, + SessionReuseEnabled: true, + SessionTTLSeconds: 0, + }, + }, + }, + } + client := NewSoraDirectClient(cfg, nil, nil) + account := &Account{ID: 456} + + first := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + second := client.sidecarSessionKey(account, "http://127.0.0.1:18080") + require.NotEmpty(t, first) + require.Equal(t, first, second) + + key := soraAccountProxyKey(account, "http://127.0.0.1:18080") + entry, ok := client.sidecarSessions[key] + require.True(t, ok) + require.True(t, entry.ExpiresAt.After(time.Now().Add(300*24*time.Hour))) +} diff --git a/backend/internal/service/sora_curl_cffi_sidecar.go b/backend/internal/service/sora_curl_cffi_sidecar.go new file mode 100644 index 00000000..40f5c017 --- /dev/null +++ b/backend/internal/service/sora_curl_cffi_sidecar.go @@ -0,0 +1,260 @@ +package service + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/util/logredact" +) + +const soraCurlCFFISidecarDefaultTimeoutSeconds = 60 + +type soraCurlCFFISidecarRequest struct { + Method string `json:"method"` + URL string `json:"url"` + Headers map[string][]string `json:"headers,omitempty"` + BodyBase64 string `json:"body_base64,omitempty"` + ProxyURL string `json:"proxy_url,omitempty"` + SessionKey string `json:"session_key,omitempty"` + Impersonate string `json:"impersonate,omitempty"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` +} + +type soraCurlCFFISidecarResponse struct { + StatusCode int `json:"status_code"` + Status int `json:"status"` + Headers map[string]any `json:"headers"` + BodyBase64 string `json:"body_base64"` + Body 
string `json:"body"` + Error string `json:"error"` +} + +func (c *SoraDirectClient) doHTTPViaCurlCFFISidecar(req *http.Request, proxyURL string, account *Account) (*http.Response, error) { + if req == nil || req.URL == nil { + return nil, errors.New("request url is nil") + } + if c == nil || c.cfg == nil { + return nil, errors.New("sora curl_cffi sidecar config is nil") + } + if !c.cfg.Sora.Client.CurlCFFISidecar.Enabled { + return nil, errors.New("sora curl_cffi sidecar is disabled") + } + endpoint := c.curlCFFISidecarEndpoint() + if endpoint == "" { + return nil, errors.New("sora curl_cffi sidecar base_url is empty") + } + + bodyBytes, err := readAndRestoreRequestBody(req) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar read request body failed: %w", err) + } + + headers := make(map[string][]string, len(req.Header)+1) + for key, vals := range req.Header { + copied := make([]string, len(vals)) + copy(copied, vals) + headers[key] = copied + } + if strings.TrimSpace(req.Host) != "" { + if _, ok := headers["Host"]; !ok { + headers["Host"] = []string{req.Host} + } + } + + payload := soraCurlCFFISidecarRequest{ + Method: req.Method, + URL: req.URL.String(), + Headers: headers, + ProxyURL: strings.TrimSpace(proxyURL), + SessionKey: c.sidecarSessionKey(account, proxyURL), + Impersonate: c.curlCFFIImpersonate(), + TimeoutSeconds: c.curlCFFISidecarTimeoutSeconds(), + } + if len(bodyBytes) > 0 { + payload.BodyBase64 = base64.StdEncoding.EncodeToString(bodyBytes) + } + + encoded, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar marshal request failed: %w", err) + } + + sidecarReq, err := http.NewRequestWithContext(req.Context(), http.MethodPost, endpoint, bytes.NewReader(encoded)) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar build request failed: %w", err) + } + sidecarReq.Header.Set("Content-Type", "application/json") + sidecarReq.Header.Set("Accept", "application/json") + + httpClient 
:= &http.Client{Timeout: time.Duration(payload.TimeoutSeconds) * time.Second} + sidecarResp, err := httpClient.Do(sidecarReq) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar request failed: %w", err) + } + defer func() { + _ = sidecarResp.Body.Close() + }() + + sidecarRespBody, err := io.ReadAll(io.LimitReader(sidecarResp.Body, 8<<20)) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar read response failed: %w", err) + } + if sidecarResp.StatusCode != http.StatusOK { + redacted := truncateForLog([]byte(logredact.RedactText(string(sidecarRespBody))), 512) + return nil, fmt.Errorf("sora curl_cffi sidecar http status=%d body=%s", sidecarResp.StatusCode, redacted) + } + + var payloadResp soraCurlCFFISidecarResponse + if err := json.Unmarshal(sidecarRespBody, &payloadResp); err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar parse response failed: %w", err) + } + if msg := strings.TrimSpace(payloadResp.Error); msg != "" { + return nil, fmt.Errorf("sora curl_cffi sidecar upstream error: %s", msg) + } + statusCode := payloadResp.StatusCode + if statusCode <= 0 { + statusCode = payloadResp.Status + } + if statusCode <= 0 { + return nil, errors.New("sora curl_cffi sidecar response missing status code") + } + + responseBody := []byte(payloadResp.Body) + if strings.TrimSpace(payloadResp.BodyBase64) != "" { + decoded, err := base64.StdEncoding.DecodeString(payloadResp.BodyBase64) + if err != nil { + return nil, fmt.Errorf("sora curl_cffi sidecar decode body failed: %w", err) + } + responseBody = decoded + } + + respHeaders := make(http.Header) + for key, rawVal := range payloadResp.Headers { + for _, v := range convertSidecarHeaderValue(rawVal) { + respHeaders.Add(key, v) + } + } + + return &http.Response{ + StatusCode: statusCode, + Header: respHeaders, + Body: io.NopCloser(bytes.NewReader(responseBody)), + ContentLength: int64(len(responseBody)), + Request: req, + }, nil +} + +func readAndRestoreRequestBody(req *http.Request) 
([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, err + } + _ = req.Body.Close() + req.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + req.ContentLength = int64(len(bodyBytes)) + return bodyBytes, nil +} + +func (c *SoraDirectClient) curlCFFISidecarEndpoint() string { + if c == nil || c.cfg == nil { + return "" + } + raw := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.BaseURL) + if raw == "" { + return "" + } + parsed, err := url.Parse(raw) + if err != nil || strings.TrimSpace(parsed.Scheme) == "" || strings.TrimSpace(parsed.Host) == "" { + return raw + } + if path := strings.TrimSpace(parsed.Path); path == "" || path == "/" { + parsed.Path = "/request" + } + return parsed.String() +} + +func (c *SoraDirectClient) curlCFFISidecarTimeoutSeconds() int { + if c == nil || c.cfg == nil { + return soraCurlCFFISidecarDefaultTimeoutSeconds + } + timeoutSeconds := c.cfg.Sora.Client.CurlCFFISidecar.TimeoutSeconds + if timeoutSeconds <= 0 { + return soraCurlCFFISidecarDefaultTimeoutSeconds + } + return timeoutSeconds +} + +func (c *SoraDirectClient) curlCFFIImpersonate() string { + if c == nil || c.cfg == nil { + return "chrome131" + } + impersonate := strings.TrimSpace(c.cfg.Sora.Client.CurlCFFISidecar.Impersonate) + if impersonate == "" { + return "chrome131" + } + return impersonate +} + +func (c *SoraDirectClient) sidecarSessionReuseEnabled() bool { + if c == nil || c.cfg == nil { + return true + } + return c.cfg.Sora.Client.CurlCFFISidecar.SessionReuseEnabled +} + +func (c *SoraDirectClient) sidecarSessionTTLSeconds() int { + if c == nil || c.cfg == nil { + return 3600 + } + ttl := c.cfg.Sora.Client.CurlCFFISidecar.SessionTTLSeconds + if ttl < 0 { + return 3600 + } + return ttl +} + +func convertSidecarHeaderValue(raw any) []string { + switch val := raw.(type) { + case nil: + return nil + case string: + if strings.TrimSpace(val) == "" { + return nil + } + return 
[]string{val} + case []any: + out := make([]string, 0, len(val)) + for _, item := range val { + s := strings.TrimSpace(fmt.Sprint(item)) + if s != "" { + out = append(out, s) + } + } + return out + case []string: + out := make([]string, 0, len(val)) + for _, item := range val { + if strings.TrimSpace(item) != "" { + out = append(out, item) + } + } + return out + default: + s := strings.TrimSpace(fmt.Sprint(val)) + if s == "" { + return nil + } + return []string{s} + } +} diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go new file mode 100644 index 00000000..ac29ae0d --- /dev/null +++ b/backend/internal/service/sora_gateway_service.go @@ -0,0 +1,1464 @@ +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "math" + "mime" + "net" + "net/http" + "net/url" + "regexp" + "strconv" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" +) + +const soraImageInputMaxBytes = 20 << 20 +const soraImageInputMaxRedirects = 3 +const soraImageInputTimeout = 20 * time.Second +const soraVideoInputMaxBytes = 200 << 20 +const soraVideoInputMaxRedirects = 3 +const soraVideoInputTimeout = 60 * time.Second + +var soraImageSizeMap = map[string]string{ + "gpt-image": "360", + "gpt-image-landscape": "540", + "gpt-image-portrait": "540", +} + +var soraBlockedHostnames = map[string]struct{}{ + "localhost": {}, + "localhost.localdomain": {}, + "metadata.google.internal": {}, + "metadata.google.internal.": {}, +} + +var soraBlockedCIDRs = mustParseCIDRs([]string{ + "0.0.0.0/8", + "10.0.0.0/8", + "100.64.0.0/10", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.168.0.0/16", + "224.0.0.0/4", + "240.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", +}) + +// SoraGatewayService handles forwarding requests to Sora upstream. 
+type SoraGatewayService struct { + soraClient SoraClient + mediaStorage *SoraMediaStorage + rateLimitService *RateLimitService + cfg *config.Config +} + +type soraWatermarkOptions struct { + Enabled bool + ParseMethod string + ParseURL string + ParseToken string + FallbackOnFailure bool + DeletePost bool +} + +type soraCharacterOptions struct { + SetPublic bool + DeleteAfterGenerate bool +} + +type soraCharacterFlowResult struct { + CameoID string + CharacterID string + Username string + DisplayName string +} + +var soraStoryboardPattern = regexp.MustCompile(`\[\d+(?:\.\d+)?s\]`) +var soraStoryboardShotPattern = regexp.MustCompile(`\[(\d+(?:\.\d+)?)s\]\s*([^\[]+)`) +var soraRemixTargetPattern = regexp.MustCompile(`s_[a-f0-9]{32}`) +var soraRemixTargetInURLPattern = regexp.MustCompile(`https://sora\.chatgpt\.com/p/s_[a-f0-9]{32}`) + +type soraPreflightChecker interface { + PreflightCheck(ctx context.Context, account *Account, requestedModel string, modelCfg SoraModelConfig) error +} + +func NewSoraGatewayService( + soraClient SoraClient, + mediaStorage *SoraMediaStorage, + rateLimitService *RateLimitService, + cfg *config.Config, +) *SoraGatewayService { + return &SoraGatewayService{ + soraClient: soraClient, + mediaStorage: mediaStorage, + rateLimitService: rateLimitService, + cfg: cfg, + } +} + +func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { + startTime := time.Now() + + if s.soraClient == nil || !s.soraClient.Enabled() { + if c != nil { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Sora 上游未配置", + }, + }) + } + return nil, errors.New("sora upstream not configured") + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body", clientStream) + return nil, fmt.Errorf("parse 
request: %w", err) + } + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + if strings.TrimSpace(reqModel) == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "model is required", clientStream) + return nil, errors.New("model is required") + } + + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != "" && mappedModel != reqModel { + reqModel = mappedModel + } + + modelCfg, ok := GetSoraModelConfig(reqModel) + if !ok { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "Unsupported Sora model", clientStream) + return nil, fmt.Errorf("unsupported model: %s", reqModel) + } + prompt, imageInput, videoInput, remixTargetID := extractSoraInput(reqBody) + prompt = strings.TrimSpace(prompt) + imageInput = strings.TrimSpace(imageInput) + videoInput = strings.TrimSpace(videoInput) + remixTargetID = strings.TrimSpace(remixTargetID) + + if videoInput != "" && modelCfg.Type != "video" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "video input only supports video models", clientStream) + return nil, errors.New("video input only supports video models") + } + if videoInput != "" && imageInput != "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "image input and video input cannot be used together", clientStream) + return nil, errors.New("image input and video input cannot be used together") + } + characterOnly := videoInput != "" && prompt == "" + if modelCfg.Type == "prompt_enhance" && prompt == "" { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + if modelCfg.Type != "prompt_enhance" && prompt == "" && !characterOnly { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", "prompt is required", clientStream) + return nil, errors.New("prompt is required") + } + + reqCtx, cancel := s.withSoraTimeout(ctx, reqStream) 
+ if cancel != nil { + defer cancel() + } + if checker, ok := s.soraClient.(soraPreflightChecker); ok && !characterOnly { + if err := checker.PreflightCheck(reqCtx, account, reqModel, modelCfg); err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + } + + if modelCfg.Type == "prompt_enhance" { + enhancedPrompt, err := s.soraClient.EnhancePrompt(reqCtx, account, prompt, modelCfg.ExpansionLevel, modelCfg.DurationS) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + content := strings.TrimSpace(enhancedPrompt) + if content == "" { + content = prompt + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + c.JSON(http.StatusOK, buildSoraNonStreamResponse(content, reqModel)) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + + characterOpts := parseSoraCharacterOptions(reqBody) + watermarkOpts := parseSoraWatermarkOptions(reqBody) + var characterResult *soraCharacterFlowResult + if videoInput != "" { + videoData, videoErr := decodeSoraVideoInput(reqCtx, videoInput) + if videoErr != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", videoErr.Error(), clientStream) + return nil, videoErr + } + characterResult, videoErr = s.createCharacterFromVideo(reqCtx, account, videoData, characterOpts) + if videoErr != nil { + return nil, s.handleSoraRequestError(ctx, account, videoErr, reqModel, c, clientStream) + } + if characterResult != nil && characterOpts.DeleteAfterGenerate && strings.TrimSpace(characterResult.CharacterID) != "" && !characterOnly { + characterID := strings.TrimSpace(characterResult.CharacterID) + defer func() { + 
cleanupCtx, cancelCleanup := context.WithTimeout(context.Background(), 15*time.Second) + defer cancelCleanup() + if err := s.soraClient.DeleteCharacter(cleanupCtx, account, characterID); err != nil { + log.Printf("[Sora] cleanup character failed, character_id=%s err=%v", characterID, err) + } + }() + } + if characterOnly { + content := "角色创建成功" + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + content = fmt.Sprintf("角色创建成功,角色名@%s", strings.TrimSpace(characterResult.Username)) + } + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + resp := buildSoraNonStreamResponse(content, reqModel) + if characterResult != nil { + resp["character_id"] = characterResult.CharacterID + resp["cameo_id"] = characterResult.CameoID + resp["character_username"] = characterResult.Username + resp["character_display_name"] = characterResult.DisplayName + } + c.JSON(http.StatusOK, resp) + } + return &ForwardResult{ + RequestID: "", + Model: reqModel, + Stream: clientStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + Usage: ClaudeUsage{}, + MediaType: "prompt", + }, nil + } + if characterResult != nil && strings.TrimSpace(characterResult.Username) != "" { + prompt = fmt.Sprintf("@%s %s", characterResult.Username, prompt) + } + } + + var imageData []byte + imageFilename := "" + if imageInput != "" { + decoded, filename, err := decodeSoraImageInput(reqCtx, imageInput) + if err != nil { + s.writeSoraError(c, http.StatusBadRequest, "invalid_request_error", err.Error(), clientStream) + return nil, err + } + imageData = decoded + imageFilename = filename + } + + mediaID := "" + if len(imageData) > 0 { + uploadID, err := s.soraClient.UploadImage(reqCtx, account, imageData, imageFilename) + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, 
clientStream) + } + mediaID = uploadID + } + + taskID := "" + var err error + switch modelCfg.Type { + case "image": + taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ + Prompt: prompt, + Width: modelCfg.Width, + Height: modelCfg.Height, + MediaID: mediaID, + }) + case "video": + if remixTargetID == "" && isSoraStoryboardPrompt(prompt) { + taskID, err = s.soraClient.CreateStoryboardTask(reqCtx, account, SoraStoryboardRequest{ + Prompt: formatSoraStoryboardPrompt(prompt), + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + }) + } else { + taskID, err = s.soraClient.CreateVideoTask(reqCtx, account, SoraVideoRequest{ + Prompt: prompt, + Orientation: modelCfg.Orientation, + Frames: modelCfg.Frames, + Model: modelCfg.Model, + Size: modelCfg.Size, + MediaID: mediaID, + RemixTargetID: remixTargetID, + CameoIDs: extractSoraCameoIDs(reqBody), + }) + } + default: + err = fmt.Errorf("unsupported model type: %s", modelCfg.Type) + } + if err != nil { + return nil, s.handleSoraRequestError(ctx, account, err, reqModel, c, clientStream) + } + + if clientStream && c != nil { + s.prepareSoraStream(c, taskID) + } + + var mediaURLs []string + videoGenerationID := "" + mediaType := modelCfg.Type + imageCount := 0 + imageSize := "" + switch modelCfg.Type { + case "image": + urls, pollErr := s.pollImageTask(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + mediaURLs = urls + imageCount = len(urls) + imageSize = soraImageSizeFromModel(reqModel) + case "video": + videoStatus, pollErr := s.pollVideoTaskDetailed(reqCtx, c, account, taskID, clientStream) + if pollErr != nil { + return nil, s.handleSoraRequestError(ctx, account, pollErr, reqModel, c, clientStream) + } + if videoStatus != nil { + mediaURLs = videoStatus.URLs + videoGenerationID = 
strings.TrimSpace(videoStatus.GenerationID) + } + default: + mediaType = "prompt" + } + + watermarkPostID := "" + if modelCfg.Type == "video" && watermarkOpts.Enabled { + watermarkURL, postID, watermarkErr := s.resolveWatermarkFreeURL(reqCtx, account, videoGenerationID, watermarkOpts) + if watermarkErr != nil { + if !watermarkOpts.FallbackOnFailure { + return nil, s.handleSoraRequestError(ctx, account, watermarkErr, reqModel, c, clientStream) + } + log.Printf("[Sora] watermark-free fallback to original URL, task_id=%s err=%v", taskID, watermarkErr) + } else if strings.TrimSpace(watermarkURL) != "" { + mediaURLs = []string{strings.TrimSpace(watermarkURL)} + watermarkPostID = strings.TrimSpace(postID) + } + } + + finalURLs := s.normalizeSoraMediaURLs(mediaURLs) + if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { + stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) + if storeErr != nil { + // 存储失败时降级使用原始 URL,不中断用户请求 + log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr) + } else { + finalURLs = s.normalizeSoraMediaURLs(stored) + } + } + if watermarkPostID != "" && watermarkOpts.DeletePost { + if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { + log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) + } + } + + content := buildSoraContent(mediaType, finalURLs) + var firstTokenMs *int + if clientStream { + ms, streamErr := s.writeSoraStream(c, reqModel, content, startTime) + if streamErr != nil { + return nil, streamErr + } + firstTokenMs = ms + } else if c != nil { + response := buildSoraNonStreamResponse(content, reqModel) + if len(finalURLs) > 0 { + response["media_url"] = finalURLs[0] + if len(finalURLs) > 1 { + response["media_urls"] = finalURLs + } + } + c.JSON(http.StatusOK, response) + } + + return &ForwardResult{ + RequestID: taskID, + Model: reqModel, + Stream: clientStream, + Duration: 
// parseBoolWithDefault reads body[key] as a boolean, tolerating bool,
// integer, float and string encodings ("true"/"1"/"yes", "false"/"0"/"no",
// case-insensitive). Missing keys and unrecognized values fall back to def.
func parseBoolWithDefault(body map[string]any, key string, def bool) bool {
	if body == nil {
		return def
	}
	raw, present := body[key]
	if !present {
		return def
	}
	switch v := raw.(type) {
	case bool:
		return v
	case int:
		return v != 0
	case int32:
		return v != 0
	case int64:
		return v != 0
	case float64:
		return v != 0
	case string:
		switch strings.ToLower(strings.TrimSpace(v)) {
		case "true", "1", "yes":
			return true
		case "false", "0", "no":
			return false
		}
	}
	return def
}
err + } + assetPointer, err := s.soraClient.UploadCharacterImage(ctx, account, avatarData) + if err != nil { + return nil, err + } + instructionSet := cameoStatus.InstructionSetHint + if instructionSet == nil { + instructionSet = cameoStatus.InstructionSet + } + + characterID, err := s.soraClient.FinalizeCharacter(ctx, account, SoraCharacterFinalizeRequest{ + CameoID: strings.TrimSpace(cameoID), + Username: username, + DisplayName: displayName, + ProfileAssetPointer: assetPointer, + InstructionSet: instructionSet, + }) + if err != nil { + return nil, err + } + + if opts.SetPublic { + if err := s.soraClient.SetCharacterPublic(ctx, account, cameoID); err != nil { + return nil, err + } + } + + return &soraCharacterFlowResult{ + CameoID: strings.TrimSpace(cameoID), + CharacterID: strings.TrimSpace(characterID), + Username: strings.TrimSpace(username), + DisplayName: displayName, + }, nil +} + +func (s *SoraGatewayService) pollCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + timeout := 10 * time.Minute + interval := 5 * time.Second + maxAttempts := int(math.Ceil(timeout.Seconds() / interval.Seconds())) + if maxAttempts < 1 { + maxAttempts = 1 + } + + var lastErr error + consecutiveErrors := 0 + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetCameoStatus(ctx, account, cameoID) + if err != nil { + lastErr = err + consecutiveErrors++ + if consecutiveErrors >= 3 { + break + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + consecutiveErrors = 0 + if status == nil { + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + continue + } + currentStatus := strings.ToLower(strings.TrimSpace(status.Status)) + statusMessage := strings.TrimSpace(status.StatusMessage) + if currentStatus == "failed" { + if statusMessage == "" { 
+ statusMessage = "character creation failed" + } + return nil, errors.New(statusMessage) + } + if strings.EqualFold(statusMessage, "Completed") || currentStatus == "finalized" { + return status, nil + } + if attempt < maxAttempts-1 { + if sleepErr := sleepWithContext(ctx, interval); sleepErr != nil { + return nil, sleepErr + } + } + } + if lastErr != nil { + return nil, fmt.Errorf("poll cameo status failed: %w", lastErr) + } + return nil, errors.New("cameo processing timeout") +} + +func processSoraCharacterUsername(usernameHint string) string { + usernameHint = strings.TrimSpace(usernameHint) + if usernameHint == "" { + usernameHint = "character" + } + if strings.Contains(usernameHint, ".") { + parts := strings.Split(usernameHint, ".") + usernameHint = strings.TrimSpace(parts[len(parts)-1]) + } + if usernameHint == "" { + usernameHint = "character" + } + return fmt.Sprintf("%s%d", usernameHint, soraRandInt(900)+100) +} + +func (s *SoraGatewayService) resolveWatermarkFreeURL(ctx context.Context, account *Account, generationID string, opts soraWatermarkOptions) (string, string, error) { + generationID = strings.TrimSpace(generationID) + if generationID == "" { + return "", "", errors.New("generation id is required for watermark-free mode") + } + postID, err := s.soraClient.PostVideoForWatermarkFree(ctx, account, generationID) + if err != nil { + return "", "", err + } + postID = strings.TrimSpace(postID) + if postID == "" { + return "", "", errors.New("watermark-free publish returned empty post id") + } + + switch opts.ParseMethod { + case "custom": + urlVal, parseErr := s.soraClient.GetWatermarkFreeURLCustom(ctx, account, opts.ParseURL, opts.ParseToken, postID) + if parseErr != nil { + return "", postID, parseErr + } + return strings.TrimSpace(urlVal), postID, nil + case "", "third_party": + return fmt.Sprintf("https://oscdn2.dyysy.com/MP4/%s.mp4", postID), postID, nil + default: + return "", postID, fmt.Errorf("unsupported watermark parse method: %s", 
// buildSoraNonStreamResponse wraps generated content in an
// OpenAI-compatible chat.completion payload for non-streaming clients.
func buildSoraNonStreamResponse(content, model string) map[string]any {
	now := time.Now()
	message := map[string]any{
		"role":    "assistant",
		"content": content,
	}
	choice := map[string]any{
		"index":         0,
		"message":       message,
		"finish_reason": "stop",
	}
	return map[string]any{
		"id":      fmt.Sprintf("chatcmpl-%d", now.UnixNano()),
		"object":  "chat.completion",
		"created": now.Unix(),
		"model":   model,
		"choices": []any{choice},
	}
}
time.Second).Unix() + signature := SignSoraMediaURL(path, signingQuery, expires, signKey) + if signature != "" { + values.Set("expires", strconv.FormatInt(expires, 10)) + values.Set("sig", signature) + prefix = "/sora/media-signed" + } + } + + encoded := values.Encode() + if encoded == "" { + return prefix + path + } + return prefix + path + "?" + encoded +} + +func (s *SoraGatewayService) prepareSoraStream(c *gin.Context, requestID string) { + if c == nil { + return + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + if strings.TrimSpace(requestID) != "" { + c.Header("x-request-id", requestID) + } +} + +func (s *SoraGatewayService) writeSoraStream(c *gin.Context, model, content string, startTime time.Time) (*int, error) { + if c == nil { + return nil, nil + } + writer := c.Writer + flusher, _ := writer.(http.Flusher) + + chunk := map[string]any{ + "id": fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()), + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{ + "content": content, + }, + }, + }, + } + encoded, _ := json.Marshal(chunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", encoded); err != nil { + return nil, err + } + if flusher != nil { + flusher.Flush() + } + ms := int(time.Since(startTime).Milliseconds()) + finalChunk := map[string]any{ + "id": chunk["id"], + "object": "chat.completion.chunk", + "created": time.Now().Unix(), + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "stop", + }, + }, + } + finalEncoded, _ := json.Marshal(finalChunk) + if _, err := fmt.Fprintf(writer, "data: %s\n\n", finalEncoded); err != nil { + return &ms, err + } + if _, err := fmt.Fprint(writer, "data: [DONE]\n\n"); err != nil { + return &ms, err + } + if flusher != nil { + 
flusher.Flush() + } + return &ms, nil +} + +func (s *SoraGatewayService) writeSoraError(c *gin.Context, status int, errType, message string, stream bool) { + if c == nil { + return + } + if stream { + flusher, _ := c.Writer.(http.Flusher) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) + _, _ = fmt.Fprint(c.Writer, errorEvent) + _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") + if flusher != nil { + flusher.Flush() + } + return + } + c.JSON(status, gin.H{ + "error": gin.H{ + "type": errType, + "message": message, + }, + }) +} + +func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account *Account, err error, model string, c *gin.Context, stream bool) error { + if err == nil { + return nil + } + var upstreamErr *SoraUpstreamError + if errors.As(err, &upstreamErr) { + if s.rateLimitService != nil && account != nil { + s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) + } + if s.shouldFailoverUpstreamError(upstreamErr.StatusCode) { + var responseHeaders http.Header + if upstreamErr.Headers != nil { + responseHeaders = upstreamErr.Headers.Clone() + } + return &UpstreamFailoverError{ + StatusCode: upstreamErr.StatusCode, + ResponseBody: upstreamErr.Body, + ResponseHeaders: responseHeaders, + } + } + msg := upstreamErr.Message + if override := soraProErrorMessage(model, msg); override != "" { + msg = override + } + s.writeSoraError(c, upstreamErr.StatusCode, "upstream_error", msg, stream) + return err + } + if errors.Is(err, context.DeadlineExceeded) { + s.writeSoraError(c, http.StatusGatewayTimeout, "timeout_error", "Sora generation timeout", stream) + return err + } + s.writeSoraError(c, http.StatusBadGateway, "api_error", err.Error(), stream) + return 
err +} + +func (s *SoraGatewayService) pollImageTask(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) ([]string, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetImageTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "succeeded", "completed": + return status.URLs, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora image generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora image generation timeout") +} + +func (s *SoraGatewayService) pollVideoTaskDetailed(ctx context.Context, c *gin.Context, account *Account, taskID string, stream bool) (*SoraVideoTaskStatus, error) { + interval := s.pollInterval() + maxAttempts := s.pollMaxAttempts() + lastPing := time.Now() + for attempt := 0; attempt < maxAttempts; attempt++ { + status, err := s.soraClient.GetVideoTask(ctx, account, taskID) + if err != nil { + return nil, err + } + switch strings.ToLower(status.Status) { + case "completed", "succeeded": + return status, nil + case "failed": + if status.ErrorMsg != "" { + return nil, errors.New(status.ErrorMsg) + } + return nil, errors.New("sora video generation failed") + } + if stream { + s.maybeSendPing(c, &lastPing) + } + if err := sleepWithContext(ctx, interval); err != nil { + return nil, err + } + } + return nil, errors.New("sora video generation timeout") +} + +func (s *SoraGatewayService) pollInterval() time.Duration { + if s == nil || s.cfg == nil { + return 2 * time.Second + } + interval := s.cfg.Sora.Client.PollIntervalSeconds + if interval <= 0 { + interval = 2 + } + return time.Duration(interval) * time.Second +} + +func 
// buildSoraContent renders generated media URLs as chat message content:
// images become markdown image links (one per line), videos become an
// HTML snippet embedding the first URL, and anything else is empty.
//
// Bug fix: the video branch previously called
// fmt.Sprintf("```html\n\n```", urls[0]) — the format string had no %s
// verb, so the URL was never embedded and the output contained a
// "%!(EXTRA string=...)" artifact. The URL is now interpolated into a
// <video> tag inside the fenced block.
func buildSoraContent(mediaType string, urls []string) string {
	switch mediaType {
	case "image":
		parts := make([]string, 0, len(urls))
		for _, u := range urls {
			parts = append(parts, fmt.Sprintf("![image](%s)", u))
		}
		return strings.Join(parts, "\n")
	case "video":
		if len(urls) == 0 {
			return ""
		}
		return fmt.Sprintf("```html\n<video controls src=\"%s\"></video>\n```", urls[0])
	default:
		return ""
	}
}
// parseSoraMessageContent extracts plain text plus the first image and
// video references from an OpenAI-style message content value, which may
// be a bare string or a list of typed parts ("text", "image_url",
// "video_url"). Whitespace-only text parts are skipped; text parts are
// joined with newlines; only the first image/video reference is kept.
func parseSoraMessageContent(content any) (text, imageInput, videoInput string) {
	switch parts := content.(type) {
	case string:
		return parts, "", ""
	case []any:
		var sb strings.Builder
		appendText := func(t string) {
			if sb.Len() > 0 {
				_, _ = sb.WriteString("\n")
			}
			_, _ = sb.WriteString(t)
		}
		// firstRef pulls a URL out of either the object form
		// ({"url": ...}) or a bare string value; anything else yields "".
		firstRef := func(v any) string {
			if obj, ok := v.(map[string]any); ok {
				return fmt.Sprintf("%v", obj["url"])
			}
			if str, ok := v.(string); ok {
				return str
			}
			return ""
		}
		for _, raw := range parts {
			part, ok := raw.(map[string]any)
			if !ok {
				continue
			}
			kind, _ := part["type"].(string)
			switch kind {
			case "text":
				if txt, ok := part["text"].(string); ok && strings.TrimSpace(txt) != "" {
					appendText(txt)
				}
			case "image_url":
				if imageInput == "" {
					imageInput = firstRef(part["image_url"])
				}
			case "video_url":
				if videoInput == "" {
					videoInput = firstRef(part["video_url"])
				}
			}
		}
		return sb.String(), imageInput, videoInput
	default:
		return "", "", ""
	}
}
"", errors.New("empty image input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, "", errors.New("invalid data url") + } + meta := parts[0] + payload := parts[1] + decoded, err := decodeBase64WithLimit(payload, soraImageInputMaxBytes) + if err != nil { + return nil, "", err + } + ext := "" + if strings.HasPrefix(meta, "data:") { + metaParts := strings.SplitN(meta[5:], ";", 2) + if len(metaParts) > 0 { + if exts, err := mime.ExtensionsByType(metaParts[0]); err == nil && len(exts) > 0 { + ext = exts[0] + } + } + } + filename := "image" + ext + return decoded, filename, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraImageInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraImageInputMaxBytes) + if err != nil { + return nil, "", errors.New("invalid base64 image") + } + return decoded, "image.png", nil +} + +func decodeSoraVideoInput(ctx context.Context, input string) ([]byte, error) { + raw := strings.TrimSpace(input) + if raw == "" { + return nil, errors.New("empty video input") + } + if strings.HasPrefix(raw, "data:") { + parts := strings.SplitN(raw, ",", 2) + if len(parts) != 2 { + return nil, errors.New("invalid video data url") + } + decoded, err := decodeBase64WithLimit(parts[1], soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil + } + if strings.HasPrefix(raw, "http://") || strings.HasPrefix(raw, "https://") { + return downloadSoraVideoInput(ctx, raw) + } + decoded, err := decodeBase64WithLimit(raw, soraVideoInputMaxBytes) + if err != nil { + return nil, errors.New("invalid base64 video") + } + if len(decoded) == 0 { + return nil, errors.New("empty video data") + } + return decoded, nil +} + +func downloadSoraImageInput(ctx context.Context, rawURL string) ([]byte, string, 
error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, "", err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, "", err + } + client := &http.Client{ + Timeout: soraImageInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraImageInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, "", err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, "", fmt.Errorf("download image failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraImageInputMaxBytes)) + if err != nil { + return nil, "", err + } + ext := fileExtFromURL(parsed.String()) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + filename := "image" + ext + return data, filename, nil +} + +func downloadSoraVideoInput(ctx context.Context, rawURL string) ([]byte, error) { + parsed, err := validateSoraRemoteURL(rawURL) + if err != nil { + return nil, err + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, parsed.String(), nil) + if err != nil { + return nil, err + } + client := &http.Client{ + Timeout: soraVideoInputTimeout, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + if len(via) >= soraVideoInputMaxRedirects { + return errors.New("too many redirects") + } + return validateSoraRemoteURLValue(req.URL) + }, + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("download video failed: %d", resp.StatusCode) + } + data, err := io.ReadAll(io.LimitReader(resp.Body, soraVideoInputMaxBytes)) + if err != nil { + return nil, err + } + if len(data) == 0 { + return nil, 
// decodeBase64WithLimit decodes standard base64 while enforcing a hard
// cap on the decoded size. It reads at most maxBytes+1 bytes from the
// streaming decoder so oversized input is detected without buffering the
// whole payload; exceeding the cap or a malformed encoding is an error.
func decodeBase64WithLimit(encoded string, maxBytes int64) ([]byte, error) {
	if maxBytes <= 0 {
		return nil, errors.New("invalid max bytes limit")
	}
	reader := io.LimitReader(
		base64.NewDecoder(base64.StdEncoding, strings.NewReader(encoded)),
		maxBytes+1,
	)
	decoded, err := io.ReadAll(reader)
	switch {
	case err != nil:
		return nil, err
	case int64(len(decoded)) > maxBytes:
		return nil, fmt.Errorf("input exceeds %d bytes limit", maxBytes)
	default:
		return decoded, nil
	}
}
// mustParseCIDRs parses a list of CIDR strings into networks. Despite the
// "must" name it is deliberately lenient: entries that fail to parse are
// silently skipped rather than panicking, so a malformed entry cannot
// take the service down at startup.
func mustParseCIDRs(values []string) []*net.IPNet {
	nets := make([]*net.IPNet, 0, len(values))
	for _, raw := range values {
		if _, network, err := net.ParseCIDR(raw); err == nil {
			nets = append(nets, network)
		}
	}
	return nets
}
UploadCharacterVideo(ctx context.Context, account *Account, data []byte) (string, error) { + return "cameo-1", nil +} +func (s *stubSoraClientForPoll) GetCameoStatus(ctx context.Context, account *Account, cameoID string) (*SoraCameoStatus, error) { + return &SoraCameoStatus{ + Status: "finalized", + StatusMessage: "Completed", + DisplayNameHint: "Character", + UsernameHint: "user.character", + ProfileAssetURL: "https://example.com/avatar.webp", + }, nil +} +func (s *stubSoraClientForPoll) DownloadCharacterImage(ctx context.Context, account *Account, imageURL string) ([]byte, error) { + return []byte("avatar"), nil +} +func (s *stubSoraClientForPoll) UploadCharacterImage(ctx context.Context, account *Account, data []byte) (string, error) { + return "asset-pointer", nil +} +func (s *stubSoraClientForPoll) FinalizeCharacter(ctx context.Context, account *Account, req SoraCharacterFinalizeRequest) (string, error) { + return "character-1", nil +} +func (s *stubSoraClientForPoll) SetCharacterPublic(ctx context.Context, account *Account, cameoID string) error { + return nil +} +func (s *stubSoraClientForPoll) DeleteCharacter(ctx context.Context, account *Account, characterID string) error { + return nil +} +func (s *stubSoraClientForPoll) PostVideoForWatermarkFree(ctx context.Context, account *Account, generationID string) (string, error) { + s.postCalls++ + return "s_post", nil +} +func (s *stubSoraClientForPoll) DeletePost(ctx context.Context, account *Account, postID string) error { + s.deleteCalls++ + return nil +} +func (s *stubSoraClientForPoll) GetWatermarkFreeURLCustom(ctx context.Context, account *Account, parseURL, parseToken, postID string) (string, error) { + if s.parseErr != nil { + return "", s.parseErr + } + return "https://example.com/no-watermark.mp4", nil +} +func (s *stubSoraClientForPoll) EnhancePrompt(ctx context.Context, account *Account, prompt, expansionLevel string, durationS int) (string, error) { + if s.enhanced != "" { + return s.enhanced, 
s.enhanceErr + } + return "enhanced prompt", s.enhanceErr +} +func (s *stubSoraClientForPoll) GetImageTask(ctx context.Context, account *Account, taskID string) (*SoraImageTaskStatus, error) { + s.imageCalls++ + return s.imageStatus, nil +} +func (s *stubSoraClientForPoll) GetVideoTask(ctx context.Context, account *Account, taskID string) (*SoraVideoTaskStatus, error) { + s.videoCalls++ + return s.videoStatus, nil +} + +func TestSoraGatewayService_PollImageTaskCompleted(t *testing.T) { + client := &stubSoraClientForPoll{ + imageStatus: &SoraImageTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/a.png"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + urls, err := service.pollImageTask(context.Background(), nil, &Account{ID: 1}, "task", false) + require.NoError(t, err) + require.Equal(t, []string{"https://example.com/a.png"}, urls) + require.Equal(t, 1, client.imageCalls) +} + +func TestSoraGatewayService_ForwardPromptEnhance(t *testing.T) { + client := &stubSoraClientForPoll{ + enhanced: "cinematic prompt", + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + } + body := []byte(`{"model":"prompt-enhance-short-10s","messages":[{"role":"user","content":"cat running"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, "prompt-enhance-short-10s", result.Model) +} + +func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { + client := &stubSoraClientForPoll{ + 
videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"[5.0s]猫猫跳伞 [5.0s]猫猫落地"}],"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, client.storyboard) +} + +func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { + client := &stubSoraClientForPoll{} + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","video":"aGVsbG8=","stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "prompt", result.MediaType) + require.Equal(t, 0, client.videoCalls) +} + +func TestSoraGatewayService_ForwardWatermarkFallback(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + parseErr: errors.New("parse failed"), + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := 
[]byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_fallback_on_failure":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/original.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 0, client.deleteCalls) +} + +func TestSoraGatewayService_ForwardWatermarkCustomSuccessAndDelete(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/original.mp4"}, + GenerationID: "gen_1", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"stream":false,"watermark_free":true,"watermark_parse_method":"custom","watermark_parse_url":"https://parser.example.com","watermark_parse_token":"token","watermark_delete_post":true}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "https://example.com/no-watermark.mp4", result.MediaURL) + require.Equal(t, 1, client.postCalls) + require.Equal(t, 1, client.deleteCalls) +} + +func TestSoraGatewayService_PollVideoTaskFailed(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "reject", + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + 
PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + service := NewSoraGatewayService(client, nil, nil, cfg) + + status, err := service.pollVideoTaskDetailed(context.Background(), nil, &Account{ID: 1}, "task", false) + require.Error(t, err) + require.Nil(t, status) + require.Contains(t, err.Error(), "reject") + require.Equal(t, 1, client.videoCalls) +} + +func TestSoraGatewayService_BuildSoraMediaURLSigned(t *testing.T) { + cfg := &config.Config{ + Gateway: config.GatewayConfig{ + SoraMediaSigningKey: "test-key", + SoraMediaSignedURLTTLSeconds: 600, + }, + } + service := NewSoraGatewayService(nil, nil, nil, cfg) + + url := service.buildSoraMediaURL("/image/2025/01/01/a.png", "") + require.Contains(t, url, "/sora/media-signed") + require.Contains(t, url, "expires=") + require.Contains(t, url, "sig=") +} + +func TestNormalizeSoraMediaURLs_Empty(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + result := svc.normalizeSoraMediaURLs(nil) + require.Empty(t, result) + + result = svc.normalizeSoraMediaURLs([]string{}) + require.Empty(t, result) +} + +func TestNormalizeSoraMediaURLs_HTTPUrls(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "http://example.com/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Equal(t, urls, result) +} + +func TestNormalizeSoraMediaURLs_LocalPaths(t *testing.T) { + cfg := &config.Config{} + svc := NewSoraGatewayService(nil, nil, nil, cfg) + urls := []string{"/image/2025/01/a.png", "video/2025/01/b.mp4"} + result := svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) + require.Contains(t, result[0], "/sora/media") + require.Contains(t, result[1], "/sora/media") +} + +func TestNormalizeSoraMediaURLs_SkipsBlank(t *testing.T) { + svc := NewSoraGatewayService(nil, nil, nil, &config.Config{}) + urls := []string{"https://example.com/a.png", "", " ", "https://example.com/b.png"} + result := 
svc.normalizeSoraMediaURLs(urls) + require.Len(t, result, 2) +} + +func TestBuildSoraContent_Image(t *testing.T) { + content := buildSoraContent("image", []string{"https://a.com/1.png", "https://a.com/2.png"}) + require.Contains(t, content, "![image](https://a.com/1.png)") + require.Contains(t, content, "![image](https://a.com/2.png)") +} + +func TestBuildSoraContent_Video(t *testing.T) { + content := buildSoraContent("video", []string{"https://a.com/v.mp4"}) + require.Contains(t, content, "